mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 18:12:19 +00:00
Merge branch 'fix-ci-2' into fix-ci
This commit is contained in:
commit
82af46284f
@ -59,14 +59,14 @@ TORCH_VERSION = version.parse(torch.__version__)
|
||||
|
||||
def create_grad_scaler(device="cuda", **kwargs):
|
||||
"""
|
||||
Creates a GradScaler compatible with both torch < 2.0 and >= 2.0.
|
||||
Creates a GradScaler compatible with both torch < 2.3.0 and >= 2.3.0.
|
||||
Accepts all kwargs like: enabled, init_scale, growth_factor, etc.
|
||||
|
||||
/icefall/egs/librispeech/ASR/./zipformer/train.py:1451: FutureWarning:
|
||||
`torch.cuda.amp.GradScaler(args...)` is deprecated. Please use
|
||||
`torch.amp.GradScaler('cuda', args...)` instead.
|
||||
"""
|
||||
if TORCH_VERSION >= version.parse("2.0.0"):
|
||||
if TORCH_VERSION >= version.parse("2.3.0"):
|
||||
from torch.amp import GradScaler
|
||||
|
||||
return GradScaler(device=device, **kwargs)
|
||||
@ -85,7 +85,7 @@ def torch_autocast(device_type="cuda", **kwargs):
|
||||
Please use `torch.amp.autocast('cuda', args...)` instead.
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
"""
|
||||
if TORCH_VERSION >= version.parse("2.0.0"):
|
||||
if TORCH_VERSION >= version.parse("2.3.0"):
|
||||
# Use new unified API
|
||||
with torch.amp.autocast(device_type=device_type, **kwargs):
|
||||
yield
|
||||
|
@ -21,3 +21,4 @@ flake8==5.0.4
|
||||
|
||||
# cantonese word segment support
|
||||
pycantonese==3.4.0
|
||||
packaging
|
||||
|
Loading…
x
Reference in New Issue
Block a user