mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
small fixes
This commit is contained in:
parent
a91d890552
commit
633eec5445
@ -59,14 +59,14 @@ TORCH_VERSION = version.parse(torch.__version__)
|
|||||||
|
|
||||||
def create_grad_scaler(device="cuda", **kwargs):
|
def create_grad_scaler(device="cuda", **kwargs):
|
||||||
"""
|
"""
|
||||||
Creates a GradScaler compatible with both torch < 2.0 and >= 2.0.
|
Creates a GradScaler compatible with both torch < 2.3.0 and >= 2.3.0.
|
||||||
Accepts all kwargs like: enabled, init_scale, growth_factor, etc.
|
Accepts all kwargs like: enabled, init_scale, growth_factor, etc.
|
||||||
|
|
||||||
/icefall/egs/librispeech/ASR/./zipformer/train.py:1451: FutureWarning:
|
/icefall/egs/librispeech/ASR/./zipformer/train.py:1451: FutureWarning:
|
||||||
`torch.cuda.amp.GradScaler(args...)` is deprecated. Please use
|
`torch.cuda.amp.GradScaler(args...)` is deprecated. Please use
|
||||||
`torch.amp.GradScaler('cuda', args...)` instead.
|
`torch.amp.GradScaler('cuda', args...)` instead.
|
||||||
"""
|
"""
|
||||||
if TORCH_VERSION >= version.parse("2.0.0"):
|
if TORCH_VERSION >= version.parse("2.3.0"):
|
||||||
from torch.amp import GradScaler
|
from torch.amp import GradScaler
|
||||||
|
|
||||||
return GradScaler(device=device, **kwargs)
|
return GradScaler(device=device, **kwargs)
|
||||||
@ -85,7 +85,7 @@ def torch_autocast(device_type="cuda", **kwargs):
|
|||||||
Please use `torch.amp.autocast('cuda', args...)` instead.
|
Please use `torch.amp.autocast('cuda', args...)` instead.
|
||||||
with torch.cuda.amp.autocast(enabled=False):
|
with torch.cuda.amp.autocast(enabled=False):
|
||||||
"""
|
"""
|
||||||
if TORCH_VERSION >= version.parse("2.0.0"):
|
if TORCH_VERSION >= version.parse("2.3.0"):
|
||||||
# Use new unified API
|
# Use new unified API
|
||||||
with torch.amp.autocast(device_type=device_type, **kwargs):
|
with torch.amp.autocast(device_type=device_type, **kwargs):
|
||||||
yield
|
yield
|
||||||
|
@ -21,3 +21,4 @@ flake8==5.0.4
|
|||||||
|
|
||||||
# cantonese word segment support
|
# cantonese word segment support
|
||||||
pycantonese==3.4.0
|
pycantonese==3.4.0
|
||||||
|
packaging
|
||||||
|
Loading…
x
Reference in New Issue
Block a user