small fixes

This commit is contained in:
k2-fsa 2025-07-01 11:14:48 +08:00
parent a91d890552
commit 633eec5445
2 changed files with 4 additions and 3 deletions

View File

@@ -59,14 +59,14 @@ TORCH_VERSION = version.parse(torch.__version__)
 def create_grad_scaler(device="cuda", **kwargs):
     """
-    Creates a GradScaler compatible with both torch < 2.0 and >= 2.0.
+    Creates a GradScaler compatible with both torch < 2.3.0 and >= 2.3.0.
     Accepts all kwargs like: enabled, init_scale, growth_factor, etc.

    /icefall/egs/librispeech/ASR/./zipformer/train.py:1451: FutureWarning:
    `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use
    `torch.amp.GradScaler('cuda', args...)` instead.
     """
-    if TORCH_VERSION >= version.parse("2.0.0"):
+    if TORCH_VERSION >= version.parse("2.3.0"):
         from torch.amp import GradScaler

         return GradScaler(device=device, **kwargs)
@@ -85,7 +85,7 @@ def torch_autocast(device_type="cuda", **kwargs):
    Please use `torch.amp.autocast('cuda', args...)` instead.
      with torch.cuda.amp.autocast(enabled=False):
     """
-    if TORCH_VERSION >= version.parse("2.0.0"):
+    if TORCH_VERSION >= version.parse("2.3.0"):
         # Use new unified API
         with torch.amp.autocast(device_type=device_type, **kwargs):
             yield

View File

@@ -21,3 +21,4 @@ flake8==5.0.4

 # cantonese word segment support
 pycantonese==3.4.0
+packaging