diff --git a/.github/workflows/yesno.yml b/.github/workflows/yesno.yml index aef18e31d..a5832df9d 100644 --- a/.github/workflows/yesno.yml +++ b/.github/workflows/yesno.yml @@ -31,8 +31,8 @@ jobs: run: | # outputting for debugging purposes python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10" - # MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10") - MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10" --min-torch-version "2.5.0") + MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10") + # MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10" --min-torch-version "2.5.0") echo "::set-output name=matrix::${MATRIX}" yesno: needs: generate_build_matrix diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py index fa809b768..9060cdb26 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py @@ -79,7 +79,13 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -638,7 +644,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -912,7 +918,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/train.py b/egs/aishell/ASR/pruned_transducer_stateless2/train.py index 60f014c48..457b564fe 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless2/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless2/train.py @@ -72,7 +72,13 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -688,7 +694,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -989,7 +995,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. 
are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/model.py b/egs/aishell/ASR/pruned_transducer_stateless3/model.py index a4dda0d6d..3b9dad55e 100644 --- a/egs/aishell/ASR/pruned_transducer_stateless3/model.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/model.py @@ -23,7 +23,7 @@ import torch.nn as nn from encoder_interface import EncoderInterface from scaling import ScaledLinear -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class Transducer(nn.Module): @@ -184,7 +184,7 @@ class Transducer(nn.Module): lm = simple_lm_proj(decoder_out) am = simple_am_proj(encoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -219,7 +219,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/train.py b/egs/aishell/ASR/pruned_transducer_stateless3/train.py index 7c23041ca..ad9f40e25 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/train.py @@ -94,7 +94,13 @@ from icefall.checkpoint import ( from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -797,7 +803,7 @@ def train_one_epoch( aishell = is_aishell(batch["supervisions"]["cut"][0]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1202,7 +1208,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. 
- with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/train.py b/egs/aishell/ASR/pruned_transducer_stateless7/train.py index 2dc835f3b..85a51278b 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7/train.py @@ -94,6 +94,7 @@ from icefall.utils import ( filter_uneven_sized_batch, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -809,7 +810,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1206,7 +1207,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py index 811269989..a07216de8 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py @@ -87,6 +87,7 @@ from icefall.utils import ( setup_logger, str2bool, tokenize_by_CJK_char, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -802,7 +803,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1202,7 +1203,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py index f3b0f1e11..a8373d755 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py @@ -81,7 +81,13 @@ from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -812,7 +818,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1202,7 +1208,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( 
params=params, model=model, diff --git a/egs/aishell/ASR/whisper/train.py b/egs/aishell/ASR/whisper/train.py index d77f8c270..af4d6442e 100755 --- a/egs/aishell/ASR/whisper/train.py +++ b/egs/aishell/ASR/whisper/train.py @@ -81,6 +81,7 @@ from icefall.utils import ( filter_uneven_sized_batch, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -514,7 +515,7 @@ def compute_validation_loss( tot_loss = MetricsTracker() for batch_idx, batch in enumerate(valid_dl): - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, tokenizer=tokenizer, @@ -608,7 +609,7 @@ def train_one_epoch( ) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, tokenizer=tokenizer, diff --git a/egs/aishell/ASR/zipformer/train.py b/egs/aishell/ASR/zipformer/train.py index dddfe52fa..0c389db55 100755 --- a/egs/aishell/ASR/zipformer/train.py +++ b/egs/aishell/ASR/zipformer/train.py @@ -95,6 +95,7 @@ from icefall.utils import ( get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -910,7 +911,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1302,7 +1303,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/aishell/ASR/zipformer/train_bbpe.py b/egs/aishell/ASR/zipformer/train_bbpe.py index dbc262c5c..b9d7fe8ad 100755 --- a/egs/aishell/ASR/zipformer/train_bbpe.py +++ b/egs/aishell/ASR/zipformer/train_bbpe.py @@ -92,6 +92,7 @@ from icefall.utils import ( setup_logger, str2bool, tokenize_by_CJK_char, + torch_autocast, ) @@ -495,7 +496,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -895,7 +896,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py index 8c7448d4c..84cd2ffca 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py @@ -90,7 +90,13 @@ from icefall.checkpoint import ( from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -734,7 +740,7 @@ def train_one_epoch( 
batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1062,7 +1068,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py index a354f761e..ab97f8677 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py @@ -83,7 +83,13 @@ from icefall.checkpoint import ( from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -727,7 +733,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) # print(batch["supervisions"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1034,7 +1040,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py index 30154291d..172d94862 100644 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py @@ -79,7 +79,13 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -638,7 +644,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -912,7 +918,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. 
- with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py index 30879d8d2..855aeca12 100755 --- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py +++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py @@ -73,7 +73,13 @@ from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -782,7 +788,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1127,7 +1133,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/ami/ASR/pruned_transducer_stateless7/train.py b/egs/ami/ASR/pruned_transducer_stateless7/train.py index d62cdadb7..8922717ef 100755 --- a/egs/ami/ASR/pruned_transducer_stateless7/train.py +++ b/egs/ami/ASR/pruned_transducer_stateless7/train.py @@ -71,7 +71,13 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -773,7 +779,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1134,7 +1140,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/ami/SURT/dprnn_zipformer/train.py b/egs/ami/SURT/dprnn_zipformer/train.py index adc6a8495..3572acd04 100755 --- a/egs/ami/SURT/dprnn_zipformer/train.py +++ b/egs/ami/SURT/dprnn_zipformer/train.py @@ -76,7 +76,13 @@ from icefall.checkpoint import ( from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -1067,7 +1073,7 @@ def train_one_epoch( batch_size = 
batch["inputs"].shape[0] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, diff --git a/egs/ami/SURT/dprnn_zipformer/train_adapt.py b/egs/ami/SURT/dprnn_zipformer/train_adapt.py index ac5b0dadc..313a5c46a 100755 --- a/egs/ami/SURT/dprnn_zipformer/train_adapt.py +++ b/egs/ami/SURT/dprnn_zipformer/train_adapt.py @@ -76,7 +76,13 @@ from icefall.checkpoint import ( from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -1058,7 +1064,7 @@ def train_one_epoch( batch_size = batch["inputs"].shape[0] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, diff --git a/egs/audioset/AT/zipformer/train.py b/egs/audioset/AT/zipformer/train.py index 67c703364..caf8accb2 100644 --- a/egs/audioset/AT/zipformer/train.py +++ b/egs/audioset/AT/zipformer/train.py @@ -74,6 +74,7 @@ from icefall.utils import ( get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -799,7 +800,7 @@ def train_one_epoch( num_samples += batch_size try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1148,7 +1149,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py index 5e98084ec..7a859ff38 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py @@ -88,6 +88,7 @@ from icefall.utils import ( filter_uneven_sized_batch, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -825,7 +826,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1220,7 +1221,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py index 976004eca..fb812b391 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py @@ -90,6 +90,7 @@ from icefall.utils import ( filter_uneven_sized_batch, 
setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -895,7 +896,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1293,7 +1294,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py index 67e1a8133..f1e9b6d43 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py @@ -81,7 +81,13 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -840,7 +846,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1237,7 +1243,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/commonvoice/ASR/zipformer/train.py b/egs/commonvoice/ASR/zipformer/train.py index 271014db0..c6940def5 100755 --- a/egs/commonvoice/ASR/zipformer/train.py +++ b/egs/commonvoice/ASR/zipformer/train.py @@ -97,6 +97,7 @@ from icefall.utils import ( get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -969,7 +970,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1365,7 +1366,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/commonvoice/ASR/zipformer/train_char.py b/egs/commonvoice/ASR/zipformer/train_char.py index 0aa7856cc..f44232c0e 100755 --- a/egs/commonvoice/ASR/zipformer/train_char.py +++ b/egs/commonvoice/ASR/zipformer/train_char.py @@ -97,6 +97,7 @@ from icefall.utils import ( get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -604,7 +605,7 @@ def train_one_epoch( batch_size = 
len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -784,7 +785,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py index ef7ea9013..5862cd660 100755 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py @@ -83,7 +83,13 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] LOG_EPS = math.log(1e-10) @@ -838,7 +844,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1245,7 +1251,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py index a7772b62f..56371e59a 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py @@ -77,7 +77,13 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -675,7 +681,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -944,7 +950,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. 
- with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/gigaspeech/ASR/zipformer/train.py b/egs/gigaspeech/ASR/zipformer/train.py index 4c122effe..8cf8f9fc7 100755 --- a/egs/gigaspeech/ASR/zipformer/train.py +++ b/egs/gigaspeech/ASR/zipformer/train.py @@ -97,6 +97,7 @@ from icefall.utils import ( get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -958,7 +959,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1317,7 +1318,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/gigaspeech/KWS/zipformer/train.py b/egs/gigaspeech/KWS/zipformer/train.py index 39d8fc6cd..2d88b6e55 100755 --- a/egs/gigaspeech/KWS/zipformer/train.py +++ b/egs/gigaspeech/KWS/zipformer/train.py @@ -97,6 +97,7 @@ from icefall.utils import ( get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -961,7 +962,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1320,7 +1321,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/train.py index bf50bf5ea..63a38a4cc 100755 --- a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/train.py @@ -77,7 +77,13 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -805,7 +811,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1196,7 +1202,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/ksponspeech/ASR/zipformer/train.py 
b/egs/ksponspeech/ASR/zipformer/train.py index 485ea69c9..406749f22 100755 --- a/egs/ksponspeech/ASR/zipformer/train.py +++ b/egs/ksponspeech/ASR/zipformer/train.py @@ -92,6 +92,7 @@ from icefall.utils import ( get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -942,7 +943,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1333,7 +1334,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/conformer_ctc2/train.py b/egs/librispeech/ASR/conformer_ctc2/train.py index c4a13b101..112147249 100755 --- a/egs/librispeech/ASR/conformer_ctc2/train.py +++ b/egs/librispeech/ASR/conformer_ctc2/train.py @@ -82,6 +82,7 @@ from icefall.env import get_env_info from icefall.graph_compiler import CtcTrainingGraphCompiler from icefall.lexicon import Lexicon from icefall.utils import ( + torch_autocast, AttributeDict, MetricsTracker, encode_supervisions, @@ -676,7 +677,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1036,7 +1037,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. 
- with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/conformer_ctc3/train.py b/egs/librispeech/ASR/conformer_ctc3/train.py index a2f1125ca..baf75b1de 100755 --- a/egs/librispeech/ASR/conformer_ctc3/train.py +++ b/egs/librispeech/ASR/conformer_ctc3/train.py @@ -93,6 +93,7 @@ from icefall.env import get_env_info from icefall.graph_compiler import CtcTrainingGraphCompiler from icefall.lexicon import Lexicon from icefall.utils import ( + torch_autocast, AttributeDict, MetricsTracker, encode_supervisions, @@ -743,7 +744,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1073,7 +1074,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py index ca21bd6bf..be1407302 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py @@ -93,7 +93,13 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -772,7 +778,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1071,7 +1077,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py index 23ddb6bec..f4a980637 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py @@ -93,7 +93,13 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -772,7 +778,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ 
-1072,7 +1078,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/model.py b/egs/librispeech/ASR/lstm_transducer_stateless/model.py index e7bad7ed8..9f148b348 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/model.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/model.py @@ -23,7 +23,7 @@ import torch.nn as nn from encoder_interface import EncoderInterface from scaling import ScaledLinear -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class Transducer(nn.Module): @@ -156,7 +156,7 @@ class Transducer(nn.Module): lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -192,7 +192,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/train.py b/egs/librispeech/ASR/lstm_transducer_stateless/train.py index feb81d500..9087adfb9 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/train.py @@ -85,6 +85,7 @@ from icefall.utils import ( display_and_save_batch, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -763,7 +764,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1092,7 +1093,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/model.py b/egs/librispeech/ASR/lstm_transducer_stateless2/model.py index 4957d14b1..5aafe10af 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/model.py @@ -23,7 +23,7 @@ import torch.nn as nn from encoder_interface import EncoderInterface from scaling import ScaledLinear -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class Transducer(nn.Module): @@ -195,7 +195,7 @@ class Transducer(nn.Module): lm = simple_lm_proj(decoder_out) am = simple_am_proj(encoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -231,7 +231,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). 
logits = joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py index 4fc4fa7f8..934c0cfb0 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py @@ -88,6 +88,7 @@ from icefall.checkpoint import ( from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import ( + torch_autocast, AttributeDict, MetricsTracker, display_and_save_batch, @@ -848,7 +849,7 @@ def train_one_epoch( libri = is_libri(batch["supervisions"]["cut"][0]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1247,7 +1248,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py index 2c1cef3a3..ea4a682f1 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py @@ -80,6 +80,7 @@ from icefall.checkpoint import ( from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import ( + torch_autocast, AttributeDict, MetricsTracker, display_and_save_batch, @@ -793,7 +794,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1136,7 +1137,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned2_knowledge/model.py b/egs/librispeech/ASR/pruned2_knowledge/model.py index ca8c28af1..3b6ce9b89 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/model.py +++ b/egs/librispeech/ASR/pruned2_knowledge/model.py @@ -21,7 +21,7 @@ import torch.nn as nn from encoder_interface import EncoderInterface from scaling import ScaledLinear -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class Transducer(nn.Module): @@ -141,7 +141,7 @@ class Transducer(nn.Module): lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -176,7 +176,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). 
logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned2_knowledge/sampling.py b/egs/librispeech/ASR/pruned2_knowledge/sampling.py index 5b595c76c..7e86ceda3 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/sampling.py +++ b/egs/librispeech/ASR/pruned2_knowledge/sampling.py @@ -13,6 +13,8 @@ from torch import Tensor, nn from torch.cuda.amp import GradScaler, custom_bwd, custom_fwd from torch_scheduled_sampling import sample_combined +from icefall.utils import torch_autocast + # The main exports of this file are the module KnowledgeBaseLookup and the # function create_knowledge_base. @@ -337,7 +339,7 @@ def _test_knowledge_base_lookup_autocast(): for epoch in range(150): for n, (x, y) in enumerate(train_pairs): y_out = m(x) - with torch.cuda.amp.autocast(enabled=True): + with torch_autocast(enabled=True): loss = ((y_out - y) ** 2).mean() * 100.0 if n % 10 == 0 and epoch % 10 == 0: print(f"Epoch {epoch}, batch {n}, loss {loss.item()}") diff --git a/egs/librispeech/ASR/pruned2_knowledge/train.py b/egs/librispeech/ASR/pruned2_knowledge/train.py index 931341cc4..83d5446e4 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/train.py +++ b/egs/librispeech/ASR/pruned2_knowledge/train.py @@ -76,7 +76,13 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import save_checkpoint_with_global_batch_idx from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -650,7 +656,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -937,7 +943,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. 
- with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py index 2b872f1d5..192e104f7 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py @@ -68,7 +68,13 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) def add_model_arguments(parser: argparse.ArgumentParser): @@ -693,7 +699,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1004,7 +1010,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index 272d06c37..6a69332aa 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -23,7 +23,7 @@ import torch.nn as nn from encoder_interface import EncoderInterface from scaling import ScaledLinear -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class Transducer(nn.Module): @@ -157,7 +157,7 @@ class Transducer(nn.Module): lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -193,7 +193,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). 
logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 6c19f2cb0..a33b1f4c1 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -89,6 +89,7 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import ( + torch_autocast, AttributeDict, MetricsTracker, display_and_save_batch, @@ -759,7 +760,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1067,7 +1068,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py index d45f6dadc..fbc4db921 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py @@ -23,7 +23,7 @@ import torch.nn as nn from encoder_interface import EncoderInterface from scaling import ScaledLinear -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class Transducer(nn.Module): @@ -195,7 +195,7 @@ class Transducer(nn.Module): lm = simple_lm_proj(decoder_out) am = simple_am_proj(encoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -231,7 +231,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). 
logits = joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py index fdafa5a87..3dab2eb8b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py @@ -85,6 +85,7 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import ( + torch_autocast, AttributeDict, MetricsTracker, display_and_save_batch, @@ -827,7 +828,7 @@ def train_one_epoch( libri = is_libri(batch["supervisions"]["cut"][0]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1195,7 +1196,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index 875b03f7f..bfc9fc144 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -94,6 +94,7 @@ from icefall.checkpoint import ( from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import ( + torch_autocast, AttributeDict, MetricsTracker, display_and_save_batch, @@ -789,7 +790,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1116,7 +1117,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py index 66dc5f991..748f4826d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py @@ -82,6 +82,7 @@ from icefall.checkpoint import ( from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import ( + torch_autocast, AttributeDict, MetricsTracker, display_and_save_batch, @@ -814,7 +815,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1147,7 +1148,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git 
a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py index daadb70c9..a5d2457f9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py @@ -23,7 +23,7 @@ import torch.nn as nn from encoder_interface import EncoderInterface from scaling import ScaledLinear -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class Transducer(nn.Module): @@ -185,7 +185,7 @@ class Transducer(nn.Module): lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -220,7 +220,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py index 8f033cb9a..d3e373321 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py @@ -96,6 +96,7 @@ from icefall.env import get_env_info from icefall.utils import ( AttributeDict, MetricsTracker, + torch_autocast, display_and_save_batch, setup_logger, str2bool, @@ -781,7 +782,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1108,7 +1109,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py index e7546ec45..a530c74ae 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py @@ -88,6 +88,7 @@ from icefall.utils import ( filter_uneven_sized_batch, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -903,7 +904,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1319,7 +1320,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py index add0e6a18..ed990b689 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py @@ 
-23,7 +23,7 @@ import torch.nn as nn from encoder_interface import EncoderInterface from scaling import penalize_abs_values_gt -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class Transducer(nn.Module): @@ -150,7 +150,7 @@ class Transducer(nn.Module): # if self.training and random.random() < 0.25: # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -185,7 +185,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 30a737061..5a317083c 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -28,6 +28,8 @@ import torch.nn.functional as F from torch import Tensor from torch.nn import Embedding as ScaledEmbedding +from icefall.utils import torch_autocast + class ActivationBalancerFunction(torch.autograd.Function): @staticmethod @@ -289,7 +291,7 @@ class SoftmaxFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor): (ans,) = ctx.saved_tensors - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): ans_grad = ans_grad.to(torch.float32) ans = ans.to(torch.float32) x_grad = ans_grad * ans @@ -669,7 +671,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function): def backward(ctx, x_grad: Tensor): (x_orig,) = ctx.saved_tensors with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): x_detached = x_orig.to(torch.float32).detach() x_detached.requires_grad = True @@ -867,7 +869,7 @@ class MaxEig(torch.nn.Module): ): return _no_op(x) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): eps = 1.0e-20 orig_x = x x = x.to(torch.float32) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 436ec53b4..0c3852455 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -84,6 +84,7 @@ from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.utils import ( + torch_autocast, AttributeDict, MetricsTracker, filter_uneven_sized_batch, @@ -809,7 +810,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1206,7 +1207,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py 
index cbde2a2e4..ee05627ae 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -44,7 +44,7 @@ from scaling import ( from torch import Tensor, nn from icefall.dist import get_rank -from icefall.utils import is_jit_tracing, make_pad_mask +from icefall.utils import is_jit_tracing, make_pad_mask, torch_autocast class Zipformer(EncoderInterface): @@ -1421,7 +1421,7 @@ class RelPositionMultiheadAttention(nn.Module): bsz = n // num_heads with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): attn_weights = attn_weights.to(torch.float32) attn_output = attn_output.to(torch.float32) attn_weights_entropy = ( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py index a6e919e2f..f1ab2a3ec 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py @@ -22,7 +22,7 @@ import torch import torch.nn as nn from encoder_interface import EncoderInterface -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class Transducer(nn.Module): @@ -150,7 +150,7 @@ class Transducer(nn.Module): lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -185,7 +185,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py index b35e56abc..88fd43ee4 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py @@ -89,6 +89,7 @@ from icefall.utils import ( encode_supervisions, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -833,7 +834,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1228,7 +1229,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py index 0582b289f..bf0faf9f1 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py @@ -22,7 +22,7 @@ import torch import torch.nn as nn from encoder_interface import EncoderInterface -from icefall.utils import add_sos, make_pad_mask +from icefall.utils import add_sos, make_pad_mask, torch_autocast class Transducer(nn.Module): @@ -178,7 +178,7 @@ 
class Transducer(nn.Module): am = self.simple_am_proj(encoder_out_fr) lm = self.simple_lm_proj(decoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -213,7 +213,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py index c2d877a93..6112c4a0b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py @@ -85,6 +85,7 @@ from icefall.utils import ( encode_supervisions, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -822,7 +823,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1217,7 +1218,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py index 8bd00bbef..c6dd8dd4c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py @@ -82,7 +82,13 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -810,7 +816,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1224,7 +1230,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py index c7e45564f..640d72b67 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py @@ -44,7 +44,7 @@ from scaling import ( ) from torch import Tensor, nn -from icefall.utils import make_pad_mask, subsequent_chunk_mask +from 
icefall.utils import make_pad_mask, subsequent_chunk_mask, torch_autocast def stack_states(state_list: List[List[Tensor]]) -> List[Tensor]: @@ -2408,7 +2408,7 @@ class RelPositionMultiheadAttention(nn.Module): bsz = n // num_heads with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): attn_weights = attn_weights.to(torch.float32) attn_output = attn_output.to(torch.float32) attn_weights_entropy = ( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py index da5e144c9..3b5178851 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py @@ -86,7 +86,13 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -866,7 +872,7 @@ def train_one_epoch( libri = is_libri(batch["supervisions"]["cut"][0]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1320,7 +1326,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/model.py b/egs/librispeech/ASR/pruned_transducer_stateless8/model.py index 39a360796..e06594c27 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/model.py @@ -24,7 +24,7 @@ import torch.nn as nn from encoder_interface import EncoderInterface from scaling import penalize_abs_values_gt -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class Transducer(nn.Module): @@ -172,7 +172,7 @@ class Transducer(nn.Module): # if self.training and random.random() < 0.25: # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -207,7 +207,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). 
logits = joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py index 646f30ca1..32b899f52 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py @@ -91,7 +91,13 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -866,7 +872,7 @@ def train_one_epoch( libri = is_libri(batch["supervisions"]["cut"][0]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1321,7 +1327,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/train.py b/egs/librispeech/ASR/tiny_transducer_ctc/train.py index 1bfd071de..dea2b096a 100644 --- a/egs/librispeech/ASR/tiny_transducer_ctc/train.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/train.py @@ -70,6 +70,7 @@ from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.lexicon import UniqLexicon from icefall.utils import ( + torch_autocast, AttributeDict, MetricsTracker, encode_supervisions, @@ -809,7 +810,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1198,7 +1199,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/zipformer/finetune.py b/egs/librispeech/ASR/zipformer/finetune.py index 6d42fe5ae..c143fbd76 100755 --- a/egs/librispeech/ASR/zipformer/finetune.py +++ b/egs/librispeech/ASR/zipformer/finetune.py @@ -100,6 +100,7 @@ from icefall.utils import ( get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -1049,7 +1050,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1474,7 +1475,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with 
torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index f2791e51f..6ef250819 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -25,7 +25,7 @@ from encoder_interface import EncoderInterface from lhotse.dataset import SpecAugment from scaling import ScaledLinear -from icefall.utils import add_sos, make_pad_mask, time_warp +from icefall.utils import add_sos, make_pad_mask, time_warp, torch_autocast class AsrModel(nn.Module): @@ -285,7 +285,7 @@ class AsrModel(nn.Module): # if self.training and random.random() < 0.25: # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -320,7 +320,7 @@ class AsrModel(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 11375385e..22aa1b1ca 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -26,6 +26,8 @@ import torch.nn as nn from torch import Tensor from torch.cuda.amp import custom_bwd, custom_fwd +from icefall.utils import torch_autocast + def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor: max_value = torch.max(x, y) @@ -308,7 +310,7 @@ class SoftmaxFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor): (ans,) = ctx.saved_tensors - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): ans_grad = ans_grad.to(torch.float32) ans = ans.to(torch.float32) x_grad = ans_grad * ans @@ -761,7 +763,7 @@ class BalancerFunction(torch.autograd.Function): try: with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): x = x.to(torch.float32) x = x.detach() x.requires_grad = True @@ -1016,7 +1018,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function): try: with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): x_detached = x_orig.to(torch.float32).detach() x_detached.requires_grad = True @@ -1355,7 +1357,7 @@ class SwooshLFunction(torch.autograd.Function): coeff = -0.08 - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): with torch.enable_grad(): x = x.detach() x.requires_grad = True @@ -1432,7 +1434,7 @@ class SwooshRFunction(torch.autograd.Function): zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): with torch.enable_grad(): x = x.detach() x.requires_grad = True diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index f8864d58b..9e05ca638 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -96,6 +96,7 @@ from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.utils import ( + torch_autocast, AttributeDict, MetricsTracker, 
get_parameter_groups_with_lrs, @@ -1101,9 +1102,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast( - enabled=params.use_autocast, dtype=params.dtype - ): + with torch_autocast(enabled=params.use_autocast, dtype=params.dtype): loss, loss_info = compute_loss( params=params, model=model, @@ -1551,9 +1550,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast( - enabled=params.use_autocast, dtype=params.dtype - ): + with torch_autocast(enabled=params.use_autocast, dtype=params.dtype): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 2a0ae0129..e83a89400 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -47,6 +47,8 @@ from scaling import ( ) from torch import Tensor, nn +from icefall.utils import torch_autocast + class Zipformer2(EncoderInterface): """ @@ -1873,7 +1875,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): attn_weights = attn_weights.to(torch.float32) attn_weights_entropy = ( -((attn_weights + 1.0e-20).log() * attn_weights) diff --git a/egs/librispeech/ASR/zipformer_adapter/train.py b/egs/librispeech/ASR/zipformer_adapter/train.py index 3511590da..d744d59d2 100755 --- a/egs/librispeech/ASR/zipformer_adapter/train.py +++ b/egs/librispeech/ASR/zipformer_adapter/train.py @@ -89,6 +89,7 @@ from icefall.utils import ( get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -1052,7 +1053,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1498,7 +1499,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/zipformer_adapter/zipformer.py b/egs/librispeech/ASR/zipformer_adapter/zipformer.py index 8e2dfdd72..8bc163db5 100644 --- a/egs/librispeech/ASR/zipformer_adapter/zipformer.py +++ b/egs/librispeech/ASR/zipformer_adapter/zipformer.py @@ -50,6 +50,8 @@ from scaling import ( ) from torch import Tensor, nn +from icefall.utils import torch_autocast + class Zipformer2(EncoderInterface): """ @@ -1916,7 +1918,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): attn_weights = attn_weights.to(torch.float32) attn_weights_entropy = ( -((attn_weights + 1.0e-20).log() * attn_weights) diff --git a/egs/librispeech/ASR/zipformer_ctc/train.py b/egs/librispeech/ASR/zipformer_ctc/train.py index 60112a84e..7c772bb3b 100755 --- a/egs/librispeech/ASR/zipformer_ctc/train.py +++ b/egs/librispeech/ASR/zipformer_ctc/train.py @@ -65,7 +65,13 @@ from icefall.env import get_env_info from icefall.err import 
raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, LRScheduler] @@ -726,7 +732,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/zipformer_lora/finetune.py b/egs/librispeech/ASR/zipformer_lora/finetune.py index 3f36f229f..ca9002928 100755 --- a/egs/librispeech/ASR/zipformer_lora/finetune.py +++ b/egs/librispeech/ASR/zipformer_lora/finetune.py @@ -99,6 +99,7 @@ from icefall.utils import ( get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -1065,7 +1066,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1507,7 +1508,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/zipformer_lora/scaling.py b/egs/librispeech/ASR/zipformer_lora/scaling.py index 8d7aa8027..1347570df 100644 --- a/egs/librispeech/ASR/zipformer_lora/scaling.py +++ b/egs/librispeech/ASR/zipformer_lora/scaling.py @@ -27,6 +27,8 @@ import torch.nn.functional as F from torch import Tensor from torch.cuda.amp import custom_bwd, custom_fwd +from icefall.utils import torch_autocast + def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor: max_value = torch.max(x, y) @@ -307,7 +309,7 @@ class SoftmaxFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor): (ans,) = ctx.saved_tensors - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): ans_grad = ans_grad.to(torch.float32) ans = ans.to(torch.float32) x_grad = ans_grad * ans @@ -863,7 +865,7 @@ class BalancerFunction(torch.autograd.Function): try: with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): x = x.to(torch.float32) x = x.detach() x.requires_grad = True @@ -1118,7 +1120,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function): try: with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): x_detached = x_orig.to(torch.float32).detach() x_detached.requires_grad = True @@ -1457,7 +1459,7 @@ class SwooshLFunction(torch.autograd.Function): coeff = -0.08 - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): with torch.enable_grad(): x = x.detach() x.requires_grad = True @@ -1534,7 +1536,7 @@ class SwooshRFunction(torch.autograd.Function): zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): with torch.enable_grad(): x = x.detach() x.requires_grad = True diff --git 
a/egs/librispeech/ASR/zipformer_lora/train.py b/egs/librispeech/ASR/zipformer_lora/train.py index 9ab214e86..add50fa25 100755 --- a/egs/librispeech/ASR/zipformer_lora/train.py +++ b/egs/librispeech/ASR/zipformer_lora/train.py @@ -97,6 +97,7 @@ from icefall.utils import ( get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -947,7 +948,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1352,7 +1353,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/zipformer_lora/zipformer.py b/egs/librispeech/ASR/zipformer_lora/zipformer.py index 43865609a..b84b1c32a 100644 --- a/egs/librispeech/ASR/zipformer_lora/zipformer.py +++ b/egs/librispeech/ASR/zipformer_lora/zipformer.py @@ -49,6 +49,8 @@ from scaling import ( ) from torch import Tensor, nn +from icefall.utils import torch_autocast + class Zipformer2(EncoderInterface): """ @@ -1905,7 +1907,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): attn_weights = attn_weights.to(torch.float32) attn_weights_entropy = ( -((attn_weights + 1.0e-20).log() * attn_weights) diff --git a/egs/librispeech/ASR/zipformer_mmi/train.py b/egs/librispeech/ASR/zipformer_mmi/train.py index c1785a328..0b99ab64e 100755 --- a/egs/librispeech/ASR/zipformer_mmi/train.py +++ b/egs/librispeech/ASR/zipformer_mmi/train.py @@ -90,6 +90,7 @@ from icefall.utils import ( encode_supervisions, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -744,7 +745,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1138,7 +1139,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/SSL/hubert/finetune.py b/egs/librispeech/SSL/hubert/finetune.py index 17daa3c9d..0080513f3 100644 --- a/egs/librispeech/SSL/hubert/finetune.py +++ b/egs/librispeech/SSL/hubert/finetune.py @@ -86,6 +86,7 @@ from icefall.utils import ( get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -816,7 +817,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1207,7 +1208,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - 
with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/SSL/hubert/finetune_ce.py b/egs/librispeech/SSL/hubert/finetune_ce.py index 2723cc770..1ff2b03c0 100644 --- a/egs/librispeech/SSL/hubert/finetune_ce.py +++ b/egs/librispeech/SSL/hubert/finetune_ce.py @@ -81,6 +81,7 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.hooks import register_inf_check_hooks from icefall.utils import ( + torch_autocast, AttributeDict, MetricsTracker, get_parameter_groups_with_lrs, @@ -816,7 +817,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1207,7 +1208,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/SSL/hubert/model.py b/egs/librispeech/SSL/hubert/model.py index 46a968b69..2c2077376 100644 --- a/egs/librispeech/SSL/hubert/model.py +++ b/egs/librispeech/SSL/hubert/model.py @@ -24,7 +24,7 @@ import torch import torch.nn as nn from scaling import ScaledLinear -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class AsrModel(nn.Module): @@ -221,7 +221,7 @@ class AsrModel(nn.Module): # if self.training and random.random() < 0.25: # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -256,7 +256,7 @@ class AsrModel(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). 
logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/SSL/hubert/pretrain.py b/egs/librispeech/SSL/hubert/pretrain.py index f183d90fd..ef2e3ad78 100644 --- a/egs/librispeech/SSL/hubert/pretrain.py +++ b/egs/librispeech/SSL/hubert/pretrain.py @@ -75,6 +75,7 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.hooks import register_inf_check_hooks from icefall.utils import ( + torch_autocast, AttributeDict, MetricsTracker, get_parameter_groups_with_lrs, @@ -644,7 +645,7 @@ def train_one_epoch( batch_size = batch["kmeans"].shape[0] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1036,7 +1037,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/SSL/hubert/pretrain_ce.py b/egs/librispeech/SSL/hubert/pretrain_ce.py index 94948695d..12f95c16f 100644 --- a/egs/librispeech/SSL/hubert/pretrain_ce.py +++ b/egs/librispeech/SSL/hubert/pretrain_ce.py @@ -80,6 +80,7 @@ from icefall.utils import ( get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -644,7 +645,7 @@ def train_one_epoch( batch_size = batch["kmeans"].shape[0] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1036,7 +1037,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/SSL/zipformer/finetune.py b/egs/librispeech/SSL/zipformer/finetune.py index c907b41c5..5bebf60f0 100644 --- a/egs/librispeech/SSL/zipformer/finetune.py +++ b/egs/librispeech/SSL/zipformer/finetune.py @@ -81,6 +81,7 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.hooks import register_inf_check_hooks from icefall.utils import ( + torch_autocast, AttributeDict, MetricsTracker, get_parameter_groups_with_lrs, @@ -1115,7 +1116,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1504,7 +1505,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/SSL/zipformer/model.py b/egs/librispeech/SSL/zipformer/model.py index 46a968b69..2c2077376 100644 --- a/egs/librispeech/SSL/zipformer/model.py +++ b/egs/librispeech/SSL/zipformer/model.py @@ -24,7 +24,7 @@ import torch import torch.nn as nn from scaling 
import ScaledLinear -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class AsrModel(nn.Module): @@ -221,7 +221,7 @@ class AsrModel(nn.Module): # if self.training and random.random() < 0.25: # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -256,7 +256,7 @@ class AsrModel(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/SSL/zipformer/pretrain.py b/egs/librispeech/SSL/zipformer/pretrain.py index 937fb382e..d772f56d0 100644 --- a/egs/librispeech/SSL/zipformer/pretrain.py +++ b/egs/librispeech/SSL/zipformer/pretrain.py @@ -78,6 +78,7 @@ from icefall.utils import ( get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -944,7 +945,7 @@ def train_one_epoch( batch_size = batch["kmeans"].shape[0] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1334,7 +1335,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/SSL/zipformer/zipformer.py b/egs/librispeech/SSL/zipformer/zipformer.py index e9eff3357..5071a91a8 100644 --- a/egs/librispeech/SSL/zipformer/zipformer.py +++ b/egs/librispeech/SSL/zipformer/zipformer.py @@ -22,6 +22,7 @@ import math import random import warnings from typing import List, Optional, Tuple, Union +from icefall.utils import torch_autocast import torch from encoder_interface import EncoderInterface @@ -1849,7 +1850,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): attn_weights = attn_weights.to(torch.float32) attn_weights_entropy = ( -((attn_weights + 1.0e-20).log() * attn_weights) diff --git a/egs/librispeech/WSASR/conformer_ctc2/train.py b/egs/librispeech/WSASR/conformer_ctc2/train.py index 82c68803f..19cce1708 100755 --- a/egs/librispeech/WSASR/conformer_ctc2/train.py +++ b/egs/librispeech/WSASR/conformer_ctc2/train.py @@ -84,6 +84,7 @@ from icefall.utils import ( get_texts, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -757,7 +758,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1076,7 +1077,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. 
- with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/WSASR/conformer_ctc2/train_phone.py b/egs/librispeech/WSASR/conformer_ctc2/train_phone.py index b276d0587..a32183bf7 100755 --- a/egs/librispeech/WSASR/conformer_ctc2/train_phone.py +++ b/egs/librispeech/WSASR/conformer_ctc2/train_phone.py @@ -79,6 +79,7 @@ from icefall.env import get_env_info from icefall.lexicon import Lexicon from icefall.otc_phone_graph_compiler import OtcPhoneTrainingGraphCompiler from icefall.utils import ( + torch_autocast, AttributeDict, MetricsTracker, encode_supervisions_otc, @@ -758,7 +759,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1078,7 +1079,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/icefall/rnn_lm/train.py b/icefall/rnn_lm/train.py index 0178b80bf..023afb5a5 100755 --- a/icefall/rnn_lm/train.py +++ b/icefall/rnn_lm/train.py @@ -53,7 +53,13 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import save_checkpoint_with_global_batch_idx from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) def get_parser(): @@ -401,7 +407,7 @@ def compute_validation_loss( for batch_idx, batch in enumerate(valid_dl): x, y, sentence_lengths = batch - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( model=model, x=x, @@ -470,7 +476,7 @@ def train_one_epoch( params.batch_idx_train += 1 x, y, sentence_lengths = batch batch_size = x.size(0) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( model=model, x=x, diff --git a/icefall/utils.py b/icefall/utils.py index ffb926566..405b6b327 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -26,6 +26,7 @@ import pathlib import random import re import subprocess +import warnings from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass @@ -42,14 +43,40 @@ import torch import torch.distributed as dist import torch.nn as nn from lhotse.dataset.signal_transforms import time_warp as time_warp_impl +from packaging import version from pypinyin import lazy_pinyin, pinyin from pypinyin.contrib.tone_convert import to_finals, to_finals_tone, to_initials from torch.utils.tensorboard import SummaryWriter +from contextlib import contextmanager + from icefall.checkpoint import average_checkpoints Pathlike = Union[str, Path] +TORCH_VERSION = version.parse(torch.__version__) + + +@contextmanager +def torch_autocast(device_type="cuda", **kwargs): + """ + To fix the following warnings: + 
+      /icefall/egs/librispeech/ASR/zipformer/model.py:323:
+      FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated.
+      Please use `torch.amp.autocast('cuda', args...)` instead.
+        with torch.cuda.amp.autocast(enabled=False):
+    """
+    if TORCH_VERSION >= version.parse("2.0.0"):
+        # Use new unified API
+        with torch.amp.autocast(device_type=device_type, **kwargs):
+            yield
+    else:
+        # Suppress deprecation warning and use old CUDA-specific autocast
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore", category=FutureWarning)
+            with torch.cuda.amp.autocast(**kwargs):
+                yield
+
 # Pytorch issue: https://github.com/pytorch/pytorch/issues/47379
 # Fixed: https://github.com/pytorch/pytorch/pull/49853
@@ -1551,6 +1578,7 @@ def optim_step_and_measure_param_change(
     and the L2 norm of the original parameter. It is given by the formula:

     .. math::
+        \begin{aligned}
         \delta = \frac{\Vert\theta - \theta_{new}\Vert^2}{\Vert\theta\Vert^2}
         \end{aligned}
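A minimal usage sketch of the new torch_autocast helper defined in icefall/utils.py above. The toy linear layer, feature tensor, and loss here are illustrative placeholders rather than icefall code; the dtype keyword mirrors the zipformer/train.py call site, and on torch < 2.0.0 the same call falls back to the CUDA-specific autocast.

import torch

from icefall.utils import torch_autocast

# Illustrative stand-ins for a real icefall model and batch (not part of the patch).
toy_model = torch.nn.Linear(80, 500)
features = torch.randn(4, 80)

if torch.cuda.is_available():
    toy_model = toy_model.cuda()
    features = features.cuda()

# Drop-in replacement for the old `torch.cuda.amp.autocast(...)` call sites:
# on torch >= 2.0.0 this forwards to torch.amp.autocast(device_type="cuda", ...),
# otherwise it uses torch.cuda.amp.autocast(...) with the FutureWarning suppressed.
# Extra keyword arguments such as dtype are passed through unchanged.
with torch_autocast(enabled=torch.cuda.is_available(), dtype=torch.float16):
    loss = toy_model(features).sum()

loss.backward()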