diff --git a/icefall/utils.py b/icefall/utils.py index 4017d9e9e..8bcf075e6 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -59,14 +59,14 @@ TORCH_VERSION = version.parse(torch.__version__) def create_grad_scaler(device="cuda", **kwargs): """ - Creates a GradScaler compatible with both torch < 2.0 and >= 2.0. + Creates a GradScaler compatible with both torch < 2.3.0 and >= 2.3.0. Accepts all kwargs like: enabled, init_scale, growth_factor, etc. /icefall/egs/librispeech/ASR/./zipformer/train.py:1451: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead. """ - if TORCH_VERSION >= version.parse("2.0.0"): + if TORCH_VERSION >= version.parse("2.3.0"): from torch.amp import GradScaler return GradScaler(device=device, **kwargs) @@ -85,7 +85,7 @@ def torch_autocast(device_type="cuda", **kwargs): Please use `torch.amp.autocast('cuda', args...)` instead. with torch.cuda.amp.autocast(enabled=False): """ - if TORCH_VERSION >= version.parse("2.0.0"): + if TORCH_VERSION >= version.parse("2.3.0"): # Use new unified API with torch.amp.autocast(device_type=device_type, **kwargs): yield diff --git a/requirements.txt b/requirements.txt index d97263142..885bf2fc3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,3 +21,4 @@ flake8==5.0.4 # cantonese word segment support pycantonese==3.4.0 +packaging