diff --git a/.github/workflows/aishell.yml b/.github/workflows/aishell.yml
index 8b0599fca..4572c0c7f 100644
--- a/.github/workflows/aishell.yml
+++ b/.github/workflows/aishell.yml
@@ -17,7 +17,7 @@ concurrency:
 
 jobs:
   generate_build_matrix:
-    if: (github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa') && (github.event.label.name == 'ready' || github.event_name == 'push' || github.event_name == 'aishell')
+    if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa'
 
     # see https://github.com/pytorch/pytorch/pull/50633
     runs-on: ubuntu-latest
@@ -31,8 +31,8 @@ jobs:
         id: set-matrix
         run: |
           # outputting for debugging purposes
-          python ./.github/scripts/docker/generate_build_matrix.py
-          MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py)
+          python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10"
+          MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10")
           echo "::set-output name=matrix::${MATRIX}"
   aishell:
     needs: generate_build_matrix
diff --git a/.github/workflows/audioset.yml b/.github/workflows/audioset.yml
index 9c9446239..90d418513 100644
--- a/.github/workflows/audioset.yml
+++ b/.github/workflows/audioset.yml
@@ -30,8 +30,8 @@ jobs:
         id: set-matrix
         run: |
           # outputting for debugging purposes
-          python ./.github/scripts/docker/generate_build_matrix.py
-          MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py)
+          python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10"
+          MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10")
           echo "::set-output name=matrix::${MATRIX}"
 
   audioset:
@@ -83,7 +83,7 @@ jobs:
           ls -lh ./model-onnx/*
 
       - name: Upload model to huggingface
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
+        if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
         env:
           HF_TOKEN: ${{ secrets.HF_TOKEN }}
         uses: nick-fields/retry@v3
@@ -116,7 +116,7 @@ jobs:
           rm -rf huggingface
 
       - name: Prepare for release
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
+        if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
         shell: bash
         run: |
           d=sherpa-onnx-zipformer-audio-tagging-2024-04-09
@@ -125,7 +125,7 @@ jobs:
           ls -lh
 
       - name: Release exported onnx models
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
+        if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
         uses: svenstaro/upload-release-action@v2
         with:
           file_glob: true
diff --git a/.github/workflows/baker_zh.yml b/.github/workflows/baker_zh.yml
index 7805ab5ab..044919fd8 100644
--- a/.github/workflows/baker_zh.yml
+++ b/.github/workflows/baker_zh.yml
@@ -31,8 +31,8 @@ jobs:
         id: set-matrix
         run: |
           # outputting for debugging purposes
-          python ./.github/scripts/docker/generate_build_matrix.py --min-torch-version "2.3"
-          MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --min-torch-version "2.3")
+          python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10"
+          MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10")
           echo "::set-output name=matrix::${MATRIX}"
 
   baker_zh:
@@ -84,43 +84,43 @@ jobs:
           ls -lh
 
       - uses: actions/upload-artifact@v4
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+        if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
         with:
           name: generated-test-files-${{ matrix.python-version }}-${{ matrix.torch-version }}
           path: ./*.wav
 
       - uses: actions/upload-artifact@v4
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+        if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
         with:
           name: step-2
           path: ./model-steps-2.onnx
 
       - uses: actions/upload-artifact@v4
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+        if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
         with:
           name: step-3
           path: ./model-steps-3.onnx
 
       - uses: actions/upload-artifact@v4
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+        if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
         with:
           name: step-4
           path: ./model-steps-4.onnx
 
       - uses: actions/upload-artifact@v4
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+        if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
         with:
           name: step-5
           path: ./model-steps-5.onnx
 
       - uses: actions/upload-artifact@v4
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+        if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
         with:
           name: step-6
           path: ./model-steps-6.onnx
 
       - name: Upload models to huggingface
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
+        if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
         shell: bash
         env:
           HF_TOKEN: ${{ secrets.HF_TOKEN }}
@@ -141,7 +141,7 @@ jobs:
           popd
 
       - name: Release exported onnx models
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
+        if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
         uses: svenstaro/upload-release-action@v2
         with:
           file_glob: true
diff --git a/.github/workflows/librispeech.yml b/.github/workflows/librispeech.yml
index 6e087b10a..19037c11b 100644
--- a/.github/workflows/librispeech.yml
+++ b/.github/workflows/librispeech.yml
@@ -29,8 +29,9 @@ jobs:
         id: set-matrix
         run: |
           # outputting for debugging purposes
-          python ./.github/scripts/docker/generate_build_matrix.py
-          MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py)
+          python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10"
+          # MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10")
+          MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10" --min-torch-version "2.6.0")
           echo "::set-output name=matrix::${MATRIX}"
   librispeech:
     needs: generate_build_matrix
diff --git a/.github/workflows/ljspeech.yml b/.github/workflows/ljspeech.yml
index 31a65cd94..52a3b1a3f 100644
--- a/.github/workflows/ljspeech.yml
+++ b/.github/workflows/ljspeech.yml
@@ -30,8 +30,8 @@ jobs:
         id: set-matrix
         run: |
           # outputting for debugging purposes
-          python ./.github/scripts/docker/generate_build_matrix.py --min-torch-version "2.3"
-          MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --min-torch-version "2.3")
+          python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10"
+          MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10")
           echo "::set-output name=matrix::${MATRIX}"
 
   ljspeech:
@@ -83,13 +83,13 @@ jobs:
           ls -lh
 
       - uses: actions/upload-artifact@v4
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+        if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
         with:
           name: generated-test-files-${{ matrix.python-version }}-${{ matrix.torch-version }}
           path: ./*.wav
 
       - name: Release exported onnx models
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
+        if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
         uses: svenstaro/upload-release-action@v2
         with:
           file_glob: true
@@ -100,37 +100,37 @@ jobs:
           tag: tts-models
 
       - uses: actions/upload-artifact@v4
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+        if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
         with:
           name: step-2
           path: ./model-steps-2.onnx
 
       - uses: actions/upload-artifact@v4
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+        if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
         with:
           name: step-3
           path: ./model-steps-3.onnx
 
       - uses: actions/upload-artifact@v4
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+        if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
         with:
           name: step-4
           path: ./model-steps-4.onnx
 
       - uses: actions/upload-artifact@v4
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+        if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
         with:
           name: step-5
           path: ./model-steps-5.onnx
 
      - uses: actions/upload-artifact@v4
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+        if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
         with:
           name: step-6
           path: ./model-steps-6.onnx
 
       - name: Upload models to huggingface
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+        if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
         shell: bash
         env:
           HF_TOKEN: ${{ secrets.HF_TOKEN }}
@@ -155,7 +155,7 @@ jobs:
           popd
 
       - name: Release exported onnx models
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+        if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
         uses: svenstaro/upload-release-action@v2
         with:
           file_glob: true
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index c22f2edb5..ed0e62330 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -30,8 +30,8 @@ jobs:
         id: set-matrix
         run: |
           # outputting for debugging purposes
-          python ./.github/scripts/docker/generate_build_matrix.py
-          MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py)
+          python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10"
+          MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10")
           echo "::set-output name=matrix::${MATRIX}"
   test:
     needs: generate_build_matrix
diff --git a/.github/workflows/yesno.yml b/.github/workflows/yesno.yml
index a9d65516f..a5832df9d 100644
--- a/.github/workflows/yesno.yml
+++ b/.github/workflows/yesno.yml
@@ -30,8 +30,9 @@ jobs:
         id: set-matrix
         run: |
           # outputting for debugging purposes
-          python ./.github/scripts/docker/generate_build_matrix.py
-          MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py)
+          python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10"
+          MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10")
+          # MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10" --min-torch-version "2.5.0")
           echo "::set-output name=matrix::${MATRIX}"
   yesno:
     needs: generate_build_matrix
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py
index fa809b768..9060cdb26 100644
--- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py
@@ -79,7 +79,13 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
@@ -638,7 +644,7 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
 
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 model=model,
@@ -912,7 +918,7 @@ def scan_pessimistic_batches_for_oom(
         # warmup = 0.0 is so that the derivs for the pruned loss stay zero
         # (i.e. are not remembered by the decaying-average in adam), because
         # we want to avoid these params being subject to shrinkage in adam.
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, _ = compute_loss(
                 params=params,
                 model=model,
diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/train.py b/egs/aishell/ASR/pruned_transducer_stateless2/train.py
index 60f014c48..457b564fe 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless2/train.py
@@ -72,7 +72,13 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
@@ -688,7 +694,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
 
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -989,7 +995,7 @@ def scan_pessimistic_batches_for_oom(
         # warmup = 0.0 is so that the derivs for the pruned loss stay zero
         # (i.e. are not remembered by the decaying-average in adam), because
         # we want to avoid these params being subject to shrinkage in adam.
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, _ = compute_loss(
                 params=params,
                 model=model,
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/model.py b/egs/aishell/ASR/pruned_transducer_stateless3/model.py
index a4dda0d6d..3b9dad55e 100644
--- a/egs/aishell/ASR/pruned_transducer_stateless3/model.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/model.py
@@ -23,7 +23,7 @@ import torch.nn as nn
 from encoder_interface import EncoderInterface
 from scaling import ScaledLinear
 
-from icefall.utils import add_sos
+from icefall.utils import add_sos, torch_autocast
 
 
 class Transducer(nn.Module):
@@ -184,7 +184,7 @@ class Transducer(nn.Module):
         lm = simple_lm_proj(decoder_out)
         am = simple_am_proj(encoder_out)
 
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
                 am=am.float(),
@@ -219,7 +219,7 @@ class Transducer(nn.Module):
         # prior to do_rnnt_pruning (this is an optimization for speed).
         logits = joiner(am_pruned, lm_pruned, project_input=False)
 
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
                 logits=logits.float(),
                 symbols=y_padded,
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/train.py b/egs/aishell/ASR/pruned_transducer_stateless3/train.py
index 7c23041ca..ad9f40e25 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless3/train.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/train.py
@@ -94,7 +94,13 @@ from icefall.checkpoint import (
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
@@ -797,7 +803,7 @@ def train_one_epoch(
         aishell = is_aishell(batch["supervisions"]["cut"][0])
 
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1202,7 +1208,7 @@ def scan_pessimistic_batches_for_oom(
         # warmup = 0.0 is so that the derivs for the pruned loss stay zero
         # (i.e. are not remembered by the decaying-average in adam), because
         # we want to avoid these params being subject to shrinkage in adam.
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, _ = compute_loss(
                 params=params,
                 model=model,
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/train.py b/egs/aishell/ASR/pruned_transducer_stateless7/train.py
index 2dc835f3b..85a51278b 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless7/train.py
@@ -94,6 +94,7 @@ from icefall.utils import (
     filter_uneven_sized_batch,
     setup_logger,
     str2bool,
+    torch_autocast,
 )
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -809,7 +810,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
 
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1206,7 +1207,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py
index 811269989..a07216de8 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py
@@ -87,6 +87,7 @@ from icefall.utils import (
     setup_logger,
     str2bool,
     tokenize_by_CJK_char,
+    torch_autocast,
 )
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -802,7 +803,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
 
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1202,7 +1203,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py
index f3b0f1e11..a8373d755 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py
@@ -81,7 +81,13 @@ from icefall.env import get_env_info
 from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
@@ -812,7 +818,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
 
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1202,7 +1208,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
diff --git a/egs/aishell/ASR/whisper/train.py b/egs/aishell/ASR/whisper/train.py
index d77f8c270..af4d6442e 100755
--- a/egs/aishell/ASR/whisper/train.py
+++ b/egs/aishell/ASR/whisper/train.py
@@ -81,6 +81,7 @@ from icefall.utils import (
     filter_uneven_sized_batch,
     setup_logger,
     str2bool,
+    torch_autocast,
 )
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -514,7 +515,7 @@ def compute_validation_loss(
     tot_loss = MetricsTracker()
 
     for batch_idx, batch in enumerate(valid_dl):
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 tokenizer=tokenizer,
@@ -608,7 +609,7 @@ def train_one_epoch(
         )
 
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     tokenizer=tokenizer,
diff --git a/egs/aishell/ASR/zipformer/train.py b/egs/aishell/ASR/zipformer/train.py
index dddfe52fa..0c389db55 100755
--- a/egs/aishell/ASR/zipformer/train.py
+++ b/egs/aishell/ASR/zipformer/train.py
@@ -95,6 +95,7 @@ from icefall.utils import (
     get_parameter_groups_with_lrs,
     setup_logger,
     str2bool,
+    torch_autocast,
 )
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -910,7 +911,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
 
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1302,7 +1303,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
diff --git a/egs/aishell/ASR/zipformer/train_bbpe.py b/egs/aishell/ASR/zipformer/train_bbpe.py
index dbc262c5c..b9d7fe8ad 100755
--- a/egs/aishell/ASR/zipformer/train_bbpe.py
+++ b/egs/aishell/ASR/zipformer/train_bbpe.py
@@ -92,6 +92,7 @@ from icefall.utils import (
     setup_logger,
     str2bool,
     tokenize_by_CJK_char,
+    torch_autocast,
 )
 
 
@@ -495,7 +496,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
 
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -895,7 +896,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py
index 8c7448d4c..84cd2ffca 100755
--- a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py
@@ -90,7 +90,13 @@ from icefall.checkpoint import (
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
@@ -734,7 +740,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
 
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1062,7 +1068,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py
index a354f761e..ab97f8677 100755
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py
@@ -83,7 +83,13 @@ from icefall.checkpoint import (
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
@@ -727,7 +733,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
         # print(batch["supervisions"])
 
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 model=model,
@@ -1034,7 +1040,7 @@ def scan_pessimistic_batches_for_oom(
         # warmup = 0.0 is so that the derivs for the pruned loss stay zero
         # (i.e. are not remembered by the decaying-average in adam), because
         # we want to avoid these params being subject to shrinkage in adam.
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, _ = compute_loss(
                 params=params,
                 model=model,
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py
index 30154291d..172d94862 100644
--- a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py
@@ -79,7 +79,13 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
@@ -638,7 +644,7 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
 
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 model=model,
@@ -912,7 +918,7 @@ def scan_pessimistic_batches_for_oom(
         # warmup = 0.0 is so that the derivs for the pruned loss stay zero
         # (i.e. are not remembered by the decaying-average in adam), because
         # we want to avoid these params being subject to shrinkage in adam.
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, _ = compute_loss(
                 params=params,
                 model=model,
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py
index 30879d8d2..855aeca12 100755
--- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py
@@ -73,7 +73,13 @@ from icefall.env import get_env_info
 from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
@@ -782,7 +788,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
 
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1127,7 +1133,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/train.py b/egs/ami/ASR/pruned_transducer_stateless7/train.py
index d62cdadb7..8922717ef 100755
--- a/egs/ami/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/ami/ASR/pruned_transducer_stateless7/train.py
@@ -71,7 +71,13 @@ from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
@@ -773,7 +779,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
 
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1134,7 +1140,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
diff --git a/egs/ami/SURT/dprnn_zipformer/train.py b/egs/ami/SURT/dprnn_zipformer/train.py
index adc6a8495..3572acd04 100755
--- a/egs/ami/SURT/dprnn_zipformer/train.py
+++ b/egs/ami/SURT/dprnn_zipformer/train.py
@@ -76,7 +76,13 @@ from icefall.checkpoint import (
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.err import raise_grad_scale_is_too_small_error
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
@@ -1067,7 +1073,7 @@ def train_one_epoch(
         batch_size = batch["inputs"].shape[0]
 
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
diff --git a/egs/ami/SURT/dprnn_zipformer/train_adapt.py b/egs/ami/SURT/dprnn_zipformer/train_adapt.py
index ac5b0dadc..313a5c46a 100755
--- a/egs/ami/SURT/dprnn_zipformer/train_adapt.py
+++ b/egs/ami/SURT/dprnn_zipformer/train_adapt.py
@@ -76,7 +76,13 @@ from icefall.checkpoint import (
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.err import raise_grad_scale_is_too_small_error
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
@@ -1058,7 +1064,7 @@ def train_one_epoch(
         batch_size = batch["inputs"].shape[0]
 
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
diff --git a/egs/audioset/AT/zipformer/train.py b/egs/audioset/AT/zipformer/train.py
index 67c703364..caf8accb2 100644
--- a/egs/audioset/AT/zipformer/train.py
+++ b/egs/audioset/AT/zipformer/train.py
@@ -74,6 +74,7 @@ from icefall.utils import (
     get_parameter_groups_with_lrs,
     setup_logger,
     str2bool,
+    torch_autocast,
 )
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -799,7 +800,7 @@ def train_one_epoch(
         num_samples += batch_size
 
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1148,7 +1149,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
diff --git a/egs/baker_zh/TTS/local/compute_fbank_baker_zh.py b/egs/baker_zh/TTS/local/compute_fbank_baker_zh.py
index 0720158f2..deb344d14 100755
--- a/egs/baker_zh/TTS/local/compute_fbank_baker_zh.py
+++ b/egs/baker_zh/TTS/local/compute_fbank_baker_zh.py
@@ -73,6 +73,8 @@ def compute_fbank_baker_zh(num_jobs: int):
         f_min=0,
         f_max=8000,
     )
+    if not torch.cuda.is_available():
+        config.device = "cpu"
 
     prefix = "baker_zh"
     suffix = "jsonl.gz"
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py
index 5e98084ec..7a859ff38 100755
--- a/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py
@@ -88,6 +88,7 @@ from icefall.utils import (
     filter_uneven_sized_batch,
     setup_logger,
     str2bool,
+    torch_autocast,
 )
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -825,7 +826,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
 
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1220,7 +1221,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py
index 976004eca..fb812b391 100755
--- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py
@@ -90,6 +90,7 @@ from icefall.utils import (
     filter_uneven_sized_batch,
     setup_logger,
     str2bool,
+    torch_autocast,
 )
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -895,7 +896,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
 
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1293,7 +1294,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py
index 67e1a8133..f1e9b6d43 100755
--- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py
@@ -81,7 +81,13 @@ from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
@@ -840,7 +846,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
 
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1237,7 +1243,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
diff --git a/egs/commonvoice/ASR/zipformer/train.py b/egs/commonvoice/ASR/zipformer/train.py
index 271014db0..c6940def5 100755
--- a/egs/commonvoice/ASR/zipformer/train.py
+++ b/egs/commonvoice/ASR/zipformer/train.py
@@ -97,6 +97,7 @@ from icefall.utils import (
     get_parameter_groups_with_lrs,
     setup_logger,
     str2bool,
+    torch_autocast,
 )
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -969,7 +970,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
 
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1365,7 +1366,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
diff --git a/egs/commonvoice/ASR/zipformer/train_char.py b/egs/commonvoice/ASR/zipformer/train_char.py
index 0aa7856cc..f44232c0e 100755
--- a/egs/commonvoice/ASR/zipformer/train_char.py
+++ b/egs/commonvoice/ASR/zipformer/train_char.py
@@ -97,6 +97,7 @@ from icefall.utils import (
     get_parameter_groups_with_lrs,
     setup_logger,
     str2bool,
+    torch_autocast,
 )
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -604,7 +605,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
 
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -784,7 +785,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py
index ef7ea9013..5862cd660 100755
--- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py
+++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py
@@ -83,7 +83,13 @@ from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 LOG_EPS = math.log(1e-10)
 
@@ -838,7 +844,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
 
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1245,7 +1251,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py
index a7772b62f..56371e59a 100755
--- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py
@@ -77,7 +77,13 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
@@ -675,7 +681,7 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
 
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 model=model,
@@ -944,7 +950,7 @@ def scan_pessimistic_batches_for_oom(
         # warmup = 0.0 is so that the derivs for the pruned loss stay zero
         # (i.e. are not remembered by the decaying-average in adam), because
         # we want to avoid these params being subject to shrinkage in adam.
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, _ = compute_loss(
                 params=params,
                 model=model,
diff --git a/egs/gigaspeech/ASR/zipformer/train.py b/egs/gigaspeech/ASR/zipformer/train.py
index 4c122effe..8cf8f9fc7 100755
--- a/egs/gigaspeech/ASR/zipformer/train.py
+++ b/egs/gigaspeech/ASR/zipformer/train.py
@@ -97,6 +97,7 @@ from icefall.utils import (
     get_parameter_groups_with_lrs,
     setup_logger,
     str2bool,
+    torch_autocast,
 )
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -958,7 +959,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
 
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1317,7 +1318,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
diff --git a/egs/gigaspeech/KWS/zipformer/train.py b/egs/gigaspeech/KWS/zipformer/train.py
index 39d8fc6cd..2d88b6e55 100755
--- a/egs/gigaspeech/KWS/zipformer/train.py
+++ b/egs/gigaspeech/KWS/zipformer/train.py
@@ -97,6 +97,7 @@ from icefall.utils import (
     get_parameter_groups_with_lrs,
     setup_logger,
     str2bool,
+    torch_autocast,
 )
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -961,7 +962,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
 
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1320,7 +1321,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/train.py
index bf50bf5ea..63a38a4cc 100755
--- a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/train.py
+++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/train.py
@@ -77,7 +77,13 @@ from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
@@ -805,7 +811,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
 
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1196,7 +1202,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
diff --git a/egs/ksponspeech/ASR/zipformer/train.py b/egs/ksponspeech/ASR/zipformer/train.py
index 485ea69c9..406749f22 100755
--- a/egs/ksponspeech/ASR/zipformer/train.py
+++ b/egs/ksponspeech/ASR/zipformer/train.py
@@ -92,6 +92,7 @@ from icefall.utils import (
     get_parameter_groups_with_lrs,
     setup_logger,
     str2bool,
+    torch_autocast,
 )
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -942,7 +943,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
 
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1333,7 +1334,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py
index 7e0bf5b7b..fc866f83b 100755
--- a/egs/librispeech/ASR/conformer_ctc/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc/decode.py
@@ -667,7 +667,9 @@ def main():
         H = None
         bpe_model = None
         HLG = k2.Fsa.from_dict(
-            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+            torch.load(
+                f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False
+            )
         )
         assert HLG.requires_grad is False
 
@@ -707,7 +709,9 @@ def main():
             torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
         else:
             logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+            d = torch.load(
+                params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False
+            )
             G = k2.Fsa.from_dict(d)
 
         if params.method in [
diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py
index 38b60fcb9..5b3a021ad 100755
--- a/egs/librispeech/ASR/conformer_ctc/pretrained.py
+++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py
@@ -271,7 +271,7 @@ def main():
         use_feat_batchnorm=params.use_feat_batchnorm,
     )
 
-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()
@@ -351,7 +351,9 @@ def main():
         "attention-decoder",
     ]:
         logging.info(f"Loading HLG from {params.HLG}")
-        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        HLG = k2.Fsa.from_dict(
+            torch.load(params.HLG, map_location="cpu", weights_only=False)
+        )
         HLG = HLG.to(device)
         if not hasattr(HLG, "lm_scores"):
             # For whole-lattice-rescoring and attention-decoder
@@ -362,7 +364,9 @@ def main():
         "attention-decoder",
     ]:
         logging.info(f"Loading G from {params.G}")
-        G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+        G = k2.Fsa.from_dict(
+            torch.load(params.G, map_location="cpu", weights_only=False)
+        )
         # Add epsilon self-loops to G as we will compose
         # it with the whole lattice later
         G = G.to(device)
diff --git a/egs/librispeech/ASR/conformer_ctc2/decode.py b/egs/librispeech/ASR/conformer_ctc2/decode.py
index 0b271a51c..349e8f02d 100755
--- a/egs/librispeech/ASR/conformer_ctc2/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc2/decode.py
@@ -774,7 +774,9 @@ def main():
         H = None
         bpe_model = None
         HLG = k2.Fsa.from_dict(
-            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+            torch.load(
+                f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False
+            )
         )
         assert HLG.requires_grad is False
 
@@ -814,7 +816,9 @@ def main():
             torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
         else:
             logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+            d = torch.load(
+                params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False
+            )
             G = k2.Fsa.from_dict(d)
 
         if params.method in [
diff --git a/egs/librispeech/ASR/conformer_ctc2/train.py b/egs/librispeech/ASR/conformer_ctc2/train.py
index c4a13b101..14c132ada 100755
--- a/egs/librispeech/ASR/conformer_ctc2/train.py
+++ b/egs/librispeech/ASR/conformer_ctc2/train.py
@@ -65,7 +65,6 @@ from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from optim import Eden, Eve
 from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 
@@ -84,9 +83,11 @@ from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
+    create_grad_scaler,
     encode_supervisions,
     setup_logger,
     str2bool,
+    torch_autocast,
 )
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -420,7 +421,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -629,7 +630,7 @@ def train_one_epoch(
     scheduler: LRSchedulerType,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
-    scaler: GradScaler,
+    scaler: "GradScaler",
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -676,7 +677,7 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
 
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 model=model,
@@ -965,7 +966,7 @@ def run(rank, world_size, args):
             params=params,
         )
 
-    scaler = GradScaler(enabled=params.use_fp16)
+    scaler = create_grad_scaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1036,7 +1037,7 @@ def scan_pessimistic_batches_for_oom(
         # warmup = 0.0 is so that the derivs for the pruned loss stay zero
         # (i.e. are not remembered by the decaying-average in adam), because
         # we want to avoid these params being subject to shrinkage in adam.
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, _ = compute_loss(
                 params=params,
                 model=model,
diff --git a/egs/librispeech/ASR/conformer_ctc3/decode.py b/egs/librispeech/ASR/conformer_ctc3/decode.py
index e6327bb5e..cf58fd18d 100755
--- a/egs/librispeech/ASR/conformer_ctc3/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc3/decode.py
@@ -868,7 +868,9 @@ def main():
         H = None
         bpe_model = None
         HLG = k2.Fsa.from_dict(
-            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+            torch.load(
+                f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False
+            )
         )
         assert HLG.requires_grad is False
 
@@ -907,7 +909,9 @@ def main():
             torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
         else:
             logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+            d = torch.load(
+                params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False
+            )
             G = k2.Fsa.from_dict(d)
 
         if params.decoding_method == "whole-lattice-rescoring":
diff --git a/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py b/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py
index 19b26361e..f8e3fa43b 100755
--- a/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py
+++ b/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py
@@ -334,7 +334,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading HLG from {params.HLG}")
-        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        HLG = k2.Fsa.from_dict(
+            torch.load(params.HLG, map_location="cpu", weights_only=False)
+        )
         HLG = HLG.to(device)
         if not hasattr(HLG, "lm_scores"):
             # For whole-lattice-rescoring and attention-decoder
@@ -345,7 +347,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading G from {params.G}")
-        G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+        G = k2.Fsa.from_dict(
+            torch.load(params.G, map_location="cpu", weights_only=False)
+        )
         G = G.to(device)
         if params.method == "whole-lattice-rescoring":
             # Add epsilon self-loops to G as we will compose
diff --git a/egs/librispeech/ASR/conformer_ctc3/pretrained.py b/egs/librispeech/ASR/conformer_ctc3/pretrained.py
index a0cdfcf03..e528b2cb8 100755
--- a/egs/librispeech/ASR/conformer_ctc3/pretrained.py
+++ b/egs/librispeech/ASR/conformer_ctc3/pretrained.py
@@ -290,7 +290,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()
@@ -386,7 +388,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading HLG from {params.HLG}")
-        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        HLG = k2.Fsa.from_dict(
+            torch.load(params.HLG, map_location="cpu", weights_only=False)
+        )
         HLG = HLG.to(device)
         if not hasattr(HLG, "lm_scores"):
             # For whole-lattice-rescoring and attention-decoder
@@ -397,7 +399,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading G from {params.G}")
-        G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+        G = k2.Fsa.from_dict(
+            torch.load(params.G, map_location="cpu", weights_only=False)
+        )
         G = G.to(device)
         if params.method == "whole-lattice-rescoring":
             # Add epsilon self-loops to G as we will compose
diff --git a/egs/librispeech/ASR/conformer_ctc3/train.py b/egs/librispeech/ASR/conformer_ctc3/train.py
index a2f1125ca..64e77f421 100755
--- a/egs/librispeech/ASR/conformer_ctc3/train.py
+++ b/egs/librispeech/ASR/conformer_ctc3/train.py
@@ -76,7 +76,6 @@ from lhotse.utils import fix_random_seed
 from model import CTCModel
 from optim import Eden, Eve
 from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 
@@ -95,9 +94,11 @@ from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
+    create_grad_scaler,
     encode_supervisions,
     setup_logger,
     str2bool,
+    torch_autocast,
 )
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -493,7 +494,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -694,7 +695,7 @@ def train_one_epoch(
     graph_compiler: Union[BpeCtcTrainingGraphCompiler, CtcTrainingGraphCompiler],
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
-    scaler: GradScaler,
+    scaler: "GradScaler",
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -743,7 +744,7 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
 
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 model=model,
@@ -1004,7 +1005,7 @@ def run(rank, world_size, args):
             warmup=0.0 if params.start_epoch == 1 else 1.0,
         )
 
-    scaler = GradScaler(enabled=params.use_fp16)
+    scaler = create_grad_scaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1073,7 +1074,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
diff --git a/egs/librispeech/ASR/conformer_mmi/decode.py b/egs/librispeech/ASR/conformer_mmi/decode.py
index 74f6e73fa..01fcf0685 100755
--- a/egs/librispeech/ASR/conformer_mmi/decode.py
+++ b/egs/librispeech/ASR/conformer_mmi/decode.py
@@ -574,7 +574,9 @@ def main():
         H = None
         bpe_model = None
         HLG = k2.Fsa.from_dict(
-            torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
+            torch.load(
+                f"{params.lang_dir}/HLG.pt", map_location="cpu", weights_only=False
+            )
         )
         HLG = HLG.to(device)
         assert HLG.requires_grad is False
 
@@ -609,7 +611,9 @@ def main():
             torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
         else:
             logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
+            d = torch.load(
+                params.lm_dir / "G_4_gram.pt", map_location="cpu", weights_only=False
+            )
             G = k2.Fsa.from_dict(d).to(device)
 
         if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py
index ca21bd6bf..fc33f9512 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py @@ -80,7 +80,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -93,7 +92,14 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + create_grad_scaler, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -560,7 +566,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -727,7 +733,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -772,7 +778,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1002,7 +1008,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1071,7 +1077,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py index 23ddb6bec..b00cc6cc6 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py @@ -80,7 +80,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -93,7 +92,14 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + create_grad_scaler, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -560,7 +566,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] 
= None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -727,7 +733,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -772,7 +778,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1001,7 +1007,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1072,7 +1078,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py index d19d50ae6..ec39d5b36 100755 --- a/egs/librispeech/ASR/local/compile_hlg.py +++ b/egs/librispeech/ASR/local/compile_hlg.py @@ -72,11 +72,11 @@ def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa: max_token_id = max(lexicon.tokens) logging.info(f"Building ctc_topo. max_token_id: {max_token_id}") H = k2.ctc_topo(max_token_id) - L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) + L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt", weights_only=False)) if Path(f"data/lm/{lm}.pt").is_file(): logging.info(f"Loading pre-compiled {lm}") - d = torch.load(f"data/lm/{lm}.pt") + d = torch.load(f"data/lm/{lm}.pt", weights_only=False) G = k2.Fsa.from_dict(d) else: logging.info(f"Loading {lm}.fst.txt") diff --git a/egs/librispeech/ASR/local/compile_lg.py b/egs/librispeech/ASR/local/compile_lg.py index 709b14070..bd25cfa29 100755 --- a/egs/librispeech/ASR/local/compile_lg.py +++ b/egs/librispeech/ASR/local/compile_lg.py @@ -66,11 +66,11 @@ def compile_LG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa: An FSA representing LG. 
""" lexicon = Lexicon(lang_dir) - L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) + L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt", weights_only=False)) if Path(f"data/lm/{lm}.pt").is_file(): logging.info(f"Loading pre-compiled {lm}") - d = torch.load(f"data/lm/{lm}.pt") + d = torch.load(f"data/lm/{lm}.pt", weights_only=False) G = k2.Fsa.from_dict(d) else: logging.info(f"Loading {lm}.fst.txt") diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py index 856c9d945..8c75eb871 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py @@ -750,7 +750,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/model.py b/egs/librispeech/ASR/lstm_transducer_stateless/model.py index e7bad7ed8..9f148b348 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/model.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/model.py @@ -23,7 +23,7 @@ import torch.nn as nn from encoder_interface import EncoderInterface from scaling import ScaledLinear -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class Transducer(nn.Module): @@ -156,7 +156,7 @@ class Transducer(nn.Module): lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -192,7 +192,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). 
logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py index 42c3a5d7f..f29d1d9db 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py @@ -238,7 +238,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/train.py b/egs/librispeech/ASR/lstm_transducer_stateless/train.py index feb81d500..e23da3b56 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/train.py @@ -66,7 +66,6 @@ from lstm import RNN from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -82,9 +81,11 @@ from icefall.env import get_env_info from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, display_and_save_batch, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -521,7 +522,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
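The helpers torch_autocast and create_grad_scaler used throughout this patch are defined in icefall/utils.py and are not shown here. A minimal sketch of what such wrappers might look like, assuming their only job is to bridge the PyTorch 2.4+ deprecation of torch.cuda.amp.autocast and torch.cuda.amp.GradScaler in favor of the torch.amp equivalents (the real definitions may differ):

# Hedged sketch -- the actual implementations in icefall/utils.py may differ.
import torch

def torch_autocast(device_type: str = "cuda", **kwargs):
    # Prefer the new torch.amp API; fall back to the deprecated
    # torch.cuda.amp.autocast on older torch versions.
    if hasattr(torch, "amp") and hasattr(torch.amp, "autocast"):
        return torch.amp.autocast(device_type=device_type, **kwargs)
    return torch.cuda.amp.autocast(**kwargs)

def create_grad_scaler(device: str = "cuda", **kwargs):
    # Same bridging for the gradient scaler: torch.amp.GradScaler("cuda", ...)
    # replaces the deprecated torch.cuda.amp.GradScaler(...).
    if hasattr(torch, "amp") and hasattr(torch.amp, "GradScaler"):
        return torch.amp.GradScaler(device, **kwargs)
    return torch.cuda.amp.GradScaler(**kwargs)

Dropping the GradScaler import is also why the annotations become quoted: Optional["GradScaler"] is a string forward reference that is never evaluated at runtime, so the name no longer needs to resolve.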
@@ -717,7 +718,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -763,7 +764,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1023,7 +1024,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1092,7 +1093,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py index 1a724830b..cfbbb334c 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py @@ -935,7 +935,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/model.py b/egs/librispeech/ASR/lstm_transducer_stateless2/model.py index 4957d14b1..5aafe10af 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/model.py @@ -23,7 +23,7 @@ import torch.nn as nn from encoder_interface import EncoderInterface from scaling import ScaledLinear -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class Transducer(nn.Module): @@ -195,7 +195,7 @@ class Transducer(nn.Module): lm = simple_lm_proj(decoder_out) am = simple_am_proj(encoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -231,7 +231,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). 
logits = joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py index dcff088e2..888f9931e 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py @@ -241,7 +241,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py index 4fc4fa7f8..1b31b5485 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py @@ -74,7 +74,6 @@ from lstm import RNN from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -90,9 +89,11 @@ from icefall.env import get_env_info from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, display_and_save_batch, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -560,7 +561,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
@@ -772,7 +773,7 @@ def train_one_epoch( giga_train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, rng: random.Random, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -848,7 +849,7 @@ def train_one_epoch( libri = is_libri(batch["supervisions"]["cut"][0]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1176,7 +1177,7 @@ def run(rank, world_size, args): else: logging.info("Skip scan_pessimistic_batches_for_oom") - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1247,7 +1248,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py index a2b4f9e1a..e25b79e2e 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py @@ -815,7 +815,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py index e39637bd8..619e783b0 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py @@ -239,7 +239,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py index 2c1cef3a3..e169b499f 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py @@ -66,7 +66,6 @@ from lstm import RNN from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -82,9 +81,11 @@ from icefall.env import get_env_info from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, display_and_save_batch, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -551,7 +552,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = 
None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -747,7 +748,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -793,7 +794,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1067,7 +1068,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1136,7 +1137,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned2_knowledge/model.py b/egs/librispeech/ASR/pruned2_knowledge/model.py index ca8c28af1..3b6ce9b89 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/model.py +++ b/egs/librispeech/ASR/pruned2_knowledge/model.py @@ -21,7 +21,7 @@ import torch.nn as nn from encoder_interface import EncoderInterface from scaling import ScaledLinear -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class Transducer(nn.Module): @@ -141,7 +141,7 @@ class Transducer(nn.Module): lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -176,7 +176,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned2_knowledge/sampling.py b/egs/librispeech/ASR/pruned2_knowledge/sampling.py index 5b595c76c..5850555cd 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/sampling.py +++ b/egs/librispeech/ASR/pruned2_knowledge/sampling.py @@ -10,9 +10,11 @@ from typing import Optional, Tuple import torch from scaling import ScaledLinear from torch import Tensor, nn -from torch.cuda.amp import GradScaler, custom_bwd, custom_fwd +from torch.cuda.amp import custom_bwd, custom_fwd from torch_scheduled_sampling import sample_combined +from icefall.utils import create_grad_scaler, torch_autocast + # The main exports of this file are the module KnowledgeBaseLookup and the # function create_knowledge_base. 
@@ -330,14 +332,14 @@ def _test_knowledge_base_lookup_autocast(): optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04) m = m.to(device) - scaler = GradScaler(enabled=True) + scaler = create_grad_scaler(enabled=True) start = timeit.default_timer() for epoch in range(150): for n, (x, y) in enumerate(train_pairs): y_out = m(x) - with torch.cuda.amp.autocast(enabled=True): + with torch_autocast(enabled=True): loss = ((y_out - y) ** 2).mean() * 100.0 if n % 10 == 0 and epoch % 10 == 0: print(f"Epoch {epoch}, batch {n}, loss {loss.item()}") diff --git a/egs/librispeech/ASR/pruned2_knowledge/train.py b/egs/librispeech/ASR/pruned2_knowledge/train.py index 931341cc4..0611fd8cb 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/train.py +++ b/egs/librispeech/ASR/pruned2_knowledge/train.py @@ -66,7 +66,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -76,7 +75,14 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import save_checkpoint_with_global_batch_idx from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + create_grad_scaler, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -453,7 +459,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -608,7 +614,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, rank: int = 0, @@ -650,7 +656,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -868,7 +874,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -937,7 +943,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam.
- with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py index 2b872f1d5..2af8f3f4c 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py @@ -55,7 +55,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from noam import Noam from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -68,7 +67,14 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + create_grad_scaler, + setup_logger, + str2bool, + torch_autocast, +) def add_model_arguments(parser: argparse.ArgumentParser): @@ -496,7 +502,7 @@ def save_checkpoint( model_avg: Optional[nn.Module] = None, optimizer: Optional[torch.optim.Optimizer] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, and training stats to file. @@ -650,7 +656,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -693,7 +699,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -939,7 +945,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1004,7 +1010,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index 3c4500087..6d1da7440 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -741,7 +741,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index c57514193..5a4a74ebb 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ 
b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -754,7 +754,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index 272d06c37..6a69332aa 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -23,7 +23,7 @@ import torch.nn as nn from encoder_interface import EncoderInterface from scaling import ScaledLinear -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class Transducer(nn.Module): @@ -157,7 +157,7 @@ class Transducer(nn.Module): lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -193,7 +193,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py index 6923f4d40..e6ddcab25 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py @@ -265,7 +265,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 6c19f2cb0..ce6c89614 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -78,7 +78,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -91,9 +90,11 @@ from icefall.env import get_env_info from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, display_and_save_batch, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -523,7 +524,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
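The weights_only=False argument added to every torch.load call anticipates PyTorch 2.6, where the default flipped to weights_only=True and unpickling became restricted to tensors plus a small allowlist of types. The checkpoints and k2 FSA dictionaries loaded in these recipes carry arbitrary Python objects, so they need the old behavior. A minimal illustration with made-up names (only appropriate for files you trust):

import torch

class TrainerState:
    # Stand-in for the non-tensor objects real icefall checkpoints carry
    # (sampler state, params dicts, k2 FSA attributes, ...).
    def __init__(self, epoch: int):
        self.epoch = epoch

torch.save({"w": torch.zeros(2), "state": TrainerState(3)}, "/tmp/demo_ckpt.pt")

# Under torch >= 2.6 the new default weights_only=True raises an
# UnpicklingError on TrainerState; weights_only=False restores the
# pre-2.6 behavior for trusted files.
ckpt = torch.load("/tmp/demo_ckpt.pt", map_location="cpu", weights_only=False)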
@@ -716,7 +717,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, rank: int = 0, @@ -759,7 +760,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1000,7 +1001,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 0 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1067,7 +1068,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index 7c62bfa58..18a3792b0 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -921,7 +921,7 @@ def load_ngram_LM( if pt_file.is_file(): logging.info(f"Loading pre-compiled {pt_file}") - d = torch.load(pt_file, map_location=device) + d = torch.load(pt_file, map_location=device, weights_only=False) G = k2.Fsa.from_dict(d) G = k2.add_epsilon_self_loops(G) G = k2.arc_sort(G) @@ -1101,7 +1101,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale elif params.decoding_method in [ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py index d45f6dadc..fbc4db921 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py @@ -23,7 +23,7 @@ import torch.nn as nn from encoder_interface import EncoderInterface from scaling import ScaledLinear -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class Transducer(nn.Module): @@ -195,7 +195,7 @@ class Transducer(nn.Module): lm = simple_lm_proj(decoder_out) am = simple_am_proj(encoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -231,7 +231,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). 
logits = joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py index 05e6a6fba..19143fb5d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py @@ -274,7 +274,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py index fdafa5a87..50670d1b2 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py @@ -74,7 +74,6 @@ from librispeech import LibriSpeech from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -87,9 +86,11 @@ from icefall.env import get_env_info from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, display_and_save_batch, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -546,7 +547,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
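Note the direction of the change in the model.py hunks: there autocast is disabled, not enabled. k2.rnnt_loss_smoothed and k2.rnnt_loss_pruned receive .float() inputs inside torch_autocast(enabled=False) so that both losses are always computed in float32, even when the surrounding training step runs under fp16 autocast. A toy example of the nesting, assuming a CUDA device (names and shapes are illustrative):

import torch

x = torch.randn(4, 8, device="cuda")
with torch.amp.autocast("cuda", enabled=True):
    h = x @ x.t()  # matmul runs in half precision under autocast
    assert h.dtype == torch.float16
    # A nested region with autocast disabled forces full precision,
    # which is what the loss computations above rely on.
    with torch.amp.autocast("cuda", enabled=False):
        loss = (h.float() ** 2).mean()
        assert loss.dtype == torch.float32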
@@ -755,7 +756,7 @@ def train_one_epoch( giga_train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, rng: random.Random, - scaler: GradScaler, + scaler: "GradScaler", tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, rank: int = 0, @@ -827,7 +828,7 @@ def train_one_epoch( libri = is_libri(batch["supervisions"]["cut"][0]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1126,7 +1127,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 0 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1195,7 +1196,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index 5195a4ef6..925c01c7b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -913,7 +913,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index 875b03f7f..c35f52309 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -80,7 +80,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -96,9 +95,11 @@ from icefall.env import get_env_info from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, display_and_save_batch, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -548,7 +549,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
@@ -744,7 +745,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -789,7 +790,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1047,7 +1048,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1116,7 +1117,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index 7a3e63218..404d7a3d3 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -972,7 +972,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py index a9ce75a7b..9e2669379 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py @@ -238,7 +238,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py index 66dc5f991..6f9f92623 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py @@ -68,7 +68,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -84,9 +83,11 @@ from icefall.env import get_env_info from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, display_and_save_batch, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -571,7 +572,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: 
Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -768,7 +769,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -814,7 +815,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1078,7 +1079,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1147,7 +1148,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py index daadb70c9..a5d2457f9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py @@ -23,7 +23,7 @@ import torch.nn as nn from encoder_interface import EncoderInterface from scaling import ScaledLinear -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class Transducer(nn.Module): @@ -185,7 +185,7 @@ class Transducer(nn.Module): lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -220,7 +220,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). 
logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py index 8f033cb9a..35ee74f15 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py @@ -80,7 +80,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -96,9 +95,11 @@ from icefall.env import get_env_info from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, display_and_save_batch, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -519,7 +520,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -736,7 +737,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -781,7 +782,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1039,7 +1040,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1108,7 +1109,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py index 3bca7db2c..4f3fbaa81 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py @@ -348,7 +348,9 @@ class CodebookIndexExtractor: num_codebooks=self.params.num_codebooks, codebook_size=256, ) - quantizer.load_state_dict(torch.load(self.quantizer_file_path)) + quantizer.load_state_dict( + torch.load(self.quantizer_file_path, weights_only=False) + ) quantizer.to(self.params.device) return quantizer diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/compute_ali.py b/egs/librispeech/ASR/pruned_transducer_stateless7/compute_ali.py index 27ef0a244..949a497ce 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/compute_ali.py +++ 
b/egs/librispeech/ASR/pruned_transducer_stateless7/compute_ali.py @@ -289,7 +289,7 @@ def main(): logging.info("About to create model") model = get_transducer_model(params) - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py index eb8841cc4..048de7bb9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py @@ -910,7 +910,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py index 7095c3cc8..da1bf17fc 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py @@ -813,7 +813,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py index e7546ec45..d3d996b4a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py @@ -66,7 +66,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -85,9 +84,11 @@ from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, filter_uneven_sized_batch, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -635,7 +636,7 @@ def load_model_params( """ logging.info(f"Loading checkpoint from {ckpt}") - checkpoint = torch.load(ckpt, map_location="cpu") + checkpoint = torch.load(ckpt, map_location="cpu", weights_only=False) # if module list is empty, load the whole model from ckpt if not init_modules: @@ -678,7 +679,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
@@ -857,7 +858,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -903,7 +904,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1219,7 +1220,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1319,7 +1320,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py index add0e6a18..ed990b689 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py @@ -23,7 +23,7 @@ import torch.nn as nn from encoder_interface import EncoderInterface from scaling import penalize_abs_values_gt -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class Transducer(nn.Module): @@ -150,7 +150,7 @@ class Transducer(nn.Module): # if self.training and random.random() < 0.25: # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -185,7 +185,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). 
logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py index 4bf11ac24..fabda3aaa 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py @@ -247,7 +247,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 30a737061..5a317083c 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -28,6 +28,8 @@ import torch.nn.functional as F from torch import Tensor from torch.nn import Embedding as ScaledEmbedding +from icefall.utils import torch_autocast + class ActivationBalancerFunction(torch.autograd.Function): @staticmethod @@ -289,7 +291,7 @@ class SoftmaxFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor): (ans,) = ctx.saved_tensors - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): ans_grad = ans_grad.to(torch.float32) ans = ans.to(torch.float32) x_grad = ans_grad * ans @@ -669,7 +671,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function): def backward(ctx, x_grad: Tensor): (x_orig,) = ctx.saved_tensors with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): x_detached = x_orig.to(torch.float32).detach() x_detached.requires_grad = True @@ -867,7 +869,7 @@ class MaxEig(torch.nn.Module): ): return _no_op(x) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): eps = 1.0e-20 orig_x = x x = x.to(torch.float32) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 436ec53b4..f94da9788 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -67,7 +67,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -86,10 +85,12 @@ from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, filter_uneven_sized_batch, setup_logger, str2bool, symlink_or_copy, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -581,7 +582,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save 
model, optimizer, scheduler and training stats to file. @@ -763,7 +764,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -809,7 +810,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1106,7 +1107,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1206,7 +1207,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index cbde2a2e4..ee05627ae 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -44,7 +44,7 @@ from scaling import ( from torch import Tensor, nn from icefall.dist import get_rank -from icefall.utils import is_jit_tracing, make_pad_mask +from icefall.utils import is_jit_tracing, make_pad_mask, torch_autocast class Zipformer(EncoderInterface): @@ -1421,7 +1421,7 @@ class RelPositionMultiheadAttention(nn.Module): bsz = n // num_heads with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): attn_weights = attn_weights.to(torch.float32) attn_output = attn_output.to(torch.float32) attn_weights_entropy = ( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py index 629bec058..3b181bf23 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py @@ -633,7 +633,9 @@ def main(): H = None bpe_model = None HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + torch.load( + f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False + ) ) assert HLG.requires_grad is False @@ -672,7 +674,9 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + d = torch.load( + params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False + ) G = k2.Fsa.from_dict(d) if params.decoding_method == "whole-lattice-rescoring": diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py index 7641fa5af..9e16c3fd7 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py @@ -786,7 +786,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading 
{lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py index d1b7eec65..f7dd07f8d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py @@ -347,7 +347,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) if not hasattr(HLG, "lm_scores"): # For whole-lattice-rescoring and attention-decoder @@ -358,7 +360,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = k2.Fsa.from_dict( + torch.load(params.G, map_location="cpu", weights_only=False) + ) G = G.to(device) if params.method == "whole-lattice-rescoring": # Add epsilon self-loops to G as we will compose diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py index a6e919e2f..f1ab2a3ec 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py @@ -22,7 +22,7 @@ import torch import torch.nn as nn from encoder_interface import EncoderInterface -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class Transducer(nn.Module): @@ -150,7 +150,7 @@ class Transducer(nn.Module): lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -185,7 +185,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). 
logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py index 323ba2642..a13952dfa 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py @@ -247,7 +247,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py index 1e638aa7d..32242c94e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py @@ -286,7 +286,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() @@ -365,7 +365,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) if not hasattr(HLG, "lm_scores"): # For whole-lattice-rescoring and attention-decoder @@ -376,7 +378,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = k2.Fsa.from_dict( + torch.load(params.G, map_location="cpu", weights_only=False) + ) G = G.to(device) if params.method == "whole-lattice-rescoring": # Add epsilon self-loops to G as we will compose diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py index b35e56abc..a26f11c82 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py @@ -67,7 +67,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -86,9 +85,11 @@ from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, encode_supervisions, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -588,7 +589,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] 
= None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -787,7 +788,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -833,7 +834,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1128,7 +1129,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1228,7 +1229,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py index fa7144f0f..3af3ada2c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py @@ -624,7 +624,9 @@ def main(): H = None bpe_model = None HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + torch.load( + f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False + ) ) assert HLG.requires_grad is False @@ -663,7 +665,9 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + d = torch.load( + params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False + ) G = k2.Fsa.from_dict(d) if params.decoding_method == "whole-lattice-rescoring": diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py index e2f08abc6..233f00236 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py @@ -808,7 +808,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py index e497787d3..025b146b9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py @@ -786,7 +786,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + 
torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py index 80604ef4a..70d9841bf 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py @@ -347,7 +347,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) if not hasattr(HLG, "lm_scores"): # For whole-lattice-rescoring and attention-decoder @@ -358,7 +360,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = k2.Fsa.from_dict( + torch.load(params.G, map_location="cpu", weights_only=False) + ) G = G.to(device) if params.method == "whole-lattice-rescoring": # Add epsilon self-loops to G as we will compose diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py index 0582b289f..bf0faf9f1 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py @@ -22,7 +22,7 @@ import torch import torch.nn as nn from encoder_interface import EncoderInterface -from icefall.utils import add_sos, make_pad_mask +from icefall.utils import add_sos, make_pad_mask, torch_autocast class Transducer(nn.Module): @@ -178,7 +178,7 @@ class Transducer(nn.Module): am = self.simple_am_proj(encoder_out_fr) lm = self.simple_lm_proj(decoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -213,7 +213,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). 
logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py index a82f3562b..9ceec5f5a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py @@ -247,7 +247,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py index b98756a54..431760f9a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py @@ -286,7 +286,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() @@ -362,7 +362,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) if not hasattr(HLG, "lm_scores"): # For whole-lattice-rescoring and attention-decoder @@ -373,7 +375,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = k2.Fsa.from_dict( + torch.load(params.G, map_location="cpu", weights_only=False) + ) G = G.to(device) if params.method == "whole-lattice-rescoring": # Add epsilon self-loops to G as we will compose diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py index c2d877a93..5585d74de 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py @@ -63,7 +63,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -82,9 +81,11 @@ from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, encode_supervisions, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -581,7 +582,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = 
None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -778,7 +779,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -822,7 +823,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1118,7 +1119,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1217,7 +1218,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py index 02029c108..aa2fe8e38 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py @@ -936,7 +936,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py index aa2dd17fb..f98851f50 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py @@ -247,7 +247,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py index 8bd00bbef..4d8a2644d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py @@ -66,7 +66,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -82,7 +81,14 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.err 
import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + create_grad_scaler, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -597,7 +603,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -764,7 +770,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -810,7 +816,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1124,7 +1130,7 @@ def run(rank, world_size, args): # params=params, # ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1224,7 +1230,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py index c7e45564f..640d72b67 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py @@ -44,7 +44,7 @@ from scaling import ( ) from torch import Tensor, nn -from icefall.utils import make_pad_mask, subsequent_chunk_mask +from icefall.utils import make_pad_mask, subsequent_chunk_mask, torch_autocast def stack_states(state_list: List[List[Tensor]]) -> List[Tensor]: @@ -2408,7 +2408,7 @@ class RelPositionMultiheadAttention(nn.Module): bsz = n // num_heads with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): attn_weights = attn_weights.to(torch.float32) attn_output = attn_output.to(torch.float32) attn_weights_entropy = ( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode.py index 35158ced4..61c1a9663 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode.py @@ -768,7 +768,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= 
params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode_gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode_gigaspeech.py index a4f52ad7f..e95bb3357 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode_gigaspeech.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode_gigaspeech.py @@ -788,7 +788,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py index da5e144c9..4b97575e6 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py @@ -70,7 +70,6 @@ from librispeech import LibriSpeech from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -86,7 +85,14 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + create_grad_scaler, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -615,7 +621,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
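The `torch_autocast` context manager used throughout these hunks comes from `icefall.utils`; its definition is not part of this diff. A minimal sketch of such a compatibility shim, assuming it simply dispatches to the non-deprecated `torch.amp.autocast` API when available, could be:

```python
# Hypothetical sketch of the torch_autocast helper; the real definition
# lives in icefall/utils.py and may differ in detail.
import torch


def torch_autocast(device_type: str = "cuda", **kwargs):
    """Return an autocast context without tripping deprecation warnings.

    torch.cuda.amp.autocast(...) is deprecated in recent PyTorch releases
    in favor of torch.amp.autocast(device_type, ...); fall back to the old
    spelling on versions that predate torch.amp.autocast.
    """
    if hasattr(torch, "amp") and hasattr(torch.amp, "autocast"):
        return torch.amp.autocast(device_type=device_type, **kwargs)
    return torch.cuda.amp.autocast(**kwargs)
```

All call sites in this diff pass only `enabled=...` (and, in `zipformer/train.py`, `dtype=...`), which both APIs accept.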
@@ -795,7 +801,7 @@ def train_one_epoch( giga_train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, rng: random.Random, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -866,7 +872,7 @@ def train_one_epoch( libri = is_libri(batch["supervisions"]["cut"][0]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1218,7 +1224,7 @@ def run(rank, world_size, args): # params=params, # ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1320,7 +1326,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py index e07777c9f..3cad83a0b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py @@ -747,7 +747,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/model.py b/egs/librispeech/ASR/pruned_transducer_stateless8/model.py index 39a360796..e06594c27 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/model.py @@ -24,7 +24,7 @@ import torch.nn as nn from encoder_interface import EncoderInterface from scaling import penalize_abs_values_gt -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class Transducer(nn.Module): @@ -172,7 +172,7 @@ class Transducer(nn.Module): # if self.training and random.random() < 0.25: # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -207,7 +207,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). 
logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py index c29b8d8c9..693db2beb 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py @@ -247,7 +247,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py index 646f30ca1..ad14ec9dc 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py @@ -75,7 +75,6 @@ from librispeech import LibriSpeech from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -91,7 +90,14 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + create_grad_scaler, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -608,7 +614,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
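Likewise, `create_grad_scaler` is imported from `icefall.utils` but never defined in this diff. A plausible sketch, assuming it prefers `torch.amp.GradScaler` (available from PyTorch 2.3) and falls back to the deprecated `torch.cuda.amp.GradScaler` otherwise:

```python
# Hypothetical sketch of create_grad_scaler; the actual helper in
# icefall/utils.py may differ.
import torch


def create_grad_scaler(device: str = "cuda", **kwargs):
    """Build a GradScaler through whichever AMP API this torch provides."""
    if hasattr(torch, "amp") and hasattr(torch.amp, "GradScaler"):
        # New-style API, e.g. torch.amp.GradScaler("cuda", init_scale=1.0)
        return torch.amp.GradScaler(device, **kwargs)
    from torch.cuda.amp import GradScaler

    return GradScaler(**kwargs)
```

The call sites in this diff, `create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)`, work unchanged with either branch.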
@@ -790,7 +796,7 @@ def train_one_epoch( giga_train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, rng: random.Random, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -866,7 +872,7 @@ def train_one_epoch( libri = is_libri(batch["supervisions"]["cut"][0]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1219,7 +1225,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1321,7 +1327,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py index 92529e06c..db12ab827 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py @@ -398,7 +398,9 @@ def main(): logging.info(f"device: {device}") - HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) assert HLG.requires_grad is False @@ -428,7 +430,9 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu") + d = torch.load( + params.lm_dir / "G_4_gram.pt", map_location="cpu", weights_only=False + ) G = k2.Fsa.from_dict(d).to(device) if params.method == "whole-lattice-rescoring": diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py index b3dfab64a..4ad7cb016 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py @@ -167,13 +167,15 @@ def main(): subsampling_factor=params.subsampling_factor, ) - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"]) model.to(device) model.eval() logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) if not hasattr(HLG, "lm_scores"): # For whole-lattice-rescoring and attention-decoder @@ -181,7 +183,9 @@ def main(): if params.method == "whole-lattice-rescoring": logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = k2.Fsa.from_dict( + torch.load(params.G, map_location="cpu", weights_only=False) + ) # Add epsilon self-loops to G as we will compose # it with the whole lattice later G = G.to(device) diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py b/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py index cda03b56e..ec700626a 100644 --- 
a/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py @@ -589,7 +589,9 @@ def main(): H = None bpe_model = None HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + torch.load( + f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False + ) ) assert HLG.requires_grad is False @@ -628,7 +630,9 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + d = torch.load( + params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False + ) G = k2.Fsa.from_dict(d) if params.decoding_method == "whole-lattice-rescoring": diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/decode.py b/egs/librispeech/ASR/tiny_transducer_ctc/decode.py index cc4471e2b..1b329e8f3 100644 --- a/egs/librispeech/ASR/tiny_transducer_ctc/decode.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/decode.py @@ -663,7 +663,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py b/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py index 92dea3aa1..4b234a328 100755 --- a/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py @@ -347,7 +347,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) if not hasattr(HLG, "lm_scores"): # For whole-lattice-rescoring and attention-decoder @@ -358,7 +360,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = k2.Fsa.from_dict( + torch.load(params.G, map_location="cpu", weights_only=False) + ) G = G.to(device) if params.method == "whole-lattice-rescoring": # Add epsilon self-loops to G as we will compose diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py b/egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py index 5c6956324..9714aa537 100755 --- a/egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py @@ -249,7 +249,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py b/egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py index 7698ada79..a2ea1dd06 100755 --- a/egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py @@ -286,7 +286,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, 
map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() @@ -365,7 +365,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) if not hasattr(HLG, "lm_scores"): # For whole-lattice-rescoring and attention-decoder @@ -376,7 +378,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = k2.Fsa.from_dict( + torch.load(params.G, map_location="cpu", weights_only=False) + ) G = G.to(device) if params.method == "whole-lattice-rescoring": # Add epsilon self-loops to G as we will compose diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/train.py b/egs/librispeech/ASR/tiny_transducer_ctc/train.py index 1bfd071de..368bd20fa 100644 --- a/egs/librispeech/ASR/tiny_transducer_ctc/train.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/train.py @@ -51,7 +51,6 @@ from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import Transducer from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import AdamW from torch.optim.lr_scheduler import StepLR @@ -72,9 +71,11 @@ from icefall.lexicon import UniqLexicon from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, encode_supervisions, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = torch.optim.lr_scheduler._LRScheduler @@ -550,7 +551,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
@@ -757,7 +758,7 @@ def train_one_epoch( phone_lexicon: UniqLexicon, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -809,7 +810,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1092,7 +1093,7 @@ def run(rank, world_size, args): # params=params, # ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1198,7 +1199,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/transducer/pretrained.py b/egs/librispeech/ASR/transducer/pretrained.py index 4d9bbf4b1..06b1c05b9 100755 --- a/egs/librispeech/ASR/transducer/pretrained.py +++ b/egs/librispeech/ASR/transducer/pretrained.py @@ -222,7 +222,7 @@ def main(): logging.info("Creating model") model = get_transducer_model(params) - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/zipformer/ctc_decode.py b/egs/librispeech/ASR/zipformer/ctc_decode.py index fe9347b95..e17407c5f 100755 --- a/egs/librispeech/ASR/zipformer/ctc_decode.py +++ b/egs/librispeech/ASR/zipformer/ctc_decode.py @@ -947,7 +947,9 @@ def main(): H = None bpe_model = None HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + torch.load( + f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False + ) ) assert HLG.requires_grad is False @@ -987,7 +989,9 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + d = torch.load( + params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False + ) G = k2.Fsa.from_dict(d) if params.decoding_method in [ diff --git a/egs/librispeech/ASR/zipformer/decode.py b/egs/librispeech/ASR/zipformer/decode.py index cbfb3728e..6462d22f8 100755 --- a/egs/librispeech/ASR/zipformer/decode.py +++ b/egs/librispeech/ASR/zipformer/decode.py @@ -1013,7 +1013,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/zipformer/decode_gigaspeech.py b/egs/librispeech/ASR/zipformer/decode_gigaspeech.py index 3cda337c0..a4da83949 100755 --- a/egs/librispeech/ASR/zipformer/decode_gigaspeech.py +++ b/egs/librispeech/ASR/zipformer/decode_gigaspeech.py @@ -1049,7 +1049,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading 
{lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/zipformer/finetune.py b/egs/librispeech/ASR/zipformer/finetune.py index 2c869a57a..94e8b273a 100755 --- a/egs/librispeech/ASR/zipformer/finetune.py +++ b/egs/librispeech/ASR/zipformer/finetune.py @@ -78,7 +78,6 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -95,11 +94,13 @@ from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.utils import ( + create_grad_scaler, AttributeDict, MetricsTracker, get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -765,7 +766,7 @@ def load_model_params( """ logging.info(f"Loading checkpoint from {ckpt}") - checkpoint = torch.load(ckpt, map_location="cpu") + checkpoint = torch.load(ckpt, map_location="cpu", weights_only=False) # if module list is empty, load the whole model from ckpt if not init_modules: @@ -808,7 +809,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
@@ -985,7 +986,7 @@ def train_one_epoch( train_dl: torch.utils.data.DataLoader, valid_dls: torch.utils.data.DataLoader, valid_sets: List[str], - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -1049,7 +1050,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1373,7 +1374,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1474,7 +1475,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py b/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py index fcd07ae34..d1978df52 100755 --- a/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py +++ b/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py @@ -346,7 +346,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) if not hasattr(HLG, "lm_scores"): # For whole-lattice-rescoring and attention-decoder @@ -357,7 +359,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = k2.Fsa.from_dict( + torch.load(params.G, map_location="cpu", weights_only=False) + ) G = G.to(device) if params.method == "whole-lattice-rescoring": # Add epsilon self-loops to G as we will compose diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index f2791e51f..6ef250819 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -25,7 +25,7 @@ from encoder_interface import EncoderInterface from lhotse.dataset import SpecAugment from scaling import ScaledLinear -from icefall.utils import add_sos, make_pad_mask, time_warp +from icefall.utils import add_sos, make_pad_mask, time_warp, torch_autocast class AsrModel(nn.Module): @@ -285,7 +285,7 @@ class AsrModel(nn.Module): # if self.training and random.random() < 0.25: # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -320,7 +320,7 @@ class AsrModel(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). 
logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/zipformer/pretrained.py b/egs/librispeech/ASR/zipformer/pretrained.py index 9f3571b08..65ea7c7f2 100755 --- a/egs/librispeech/ASR/zipformer/pretrained.py +++ b/egs/librispeech/ASR/zipformer/pretrained.py @@ -289,7 +289,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/zipformer/pretrained_ctc.py b/egs/librispeech/ASR/zipformer/pretrained_ctc.py index 4341ef61f..90a6ff5b8 100755 --- a/egs/librispeech/ASR/zipformer/pretrained_ctc.py +++ b/egs/librispeech/ASR/zipformer/pretrained_ctc.py @@ -305,7 +305,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() @@ -389,7 +389,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) if not hasattr(HLG, "lm_scores"): # For whole-lattice-rescoring and attention-decoder @@ -400,7 +402,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = k2.Fsa.from_dict( + torch.load(params.G, map_location="cpu", weights_only=False) + ) G = G.to(device) if params.method == "whole-lattice-rescoring": # Add epsilon self-loops to G as we will compose diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 11375385e..22aa1b1ca 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -26,6 +26,8 @@ import torch.nn as nn from torch import Tensor from torch.cuda.amp import custom_bwd, custom_fwd +from icefall.utils import torch_autocast + def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor: max_value = torch.max(x, y) @@ -308,7 +310,7 @@ class SoftmaxFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor): (ans,) = ctx.saved_tensors - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): ans_grad = ans_grad.to(torch.float32) ans = ans.to(torch.float32) x_grad = ans_grad * ans @@ -761,7 +763,7 @@ class BalancerFunction(torch.autograd.Function): try: with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): x = x.to(torch.float32) x = x.detach() x.requires_grad = True @@ -1016,7 +1018,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function): try: with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): x_detached = x_orig.to(torch.float32).detach() x_detached.requires_grad = True @@ -1355,7 +1357,7 @@ 
class SwooshLFunction(torch.autograd.Function): coeff = -0.08 - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): with torch.enable_grad(): x = x.detach() x.requires_grad = True @@ -1432,7 +1434,7 @@ class SwooshRFunction(torch.autograd.Function): zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): with torch.enable_grad(): x = x.detach() x.requires_grad = True diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index f8864d58b..42ae9b9f2 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -79,7 +79,6 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -98,9 +97,11 @@ from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -829,7 +830,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -1034,7 +1035,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", spec_augment: Optional[SpecAugment] = None, model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, @@ -1101,9 +1102,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast( - enabled=params.use_autocast, dtype=params.dtype - ): + with torch_autocast(enabled=params.use_autocast, dtype=params.dtype): loss, loss_info = compute_loss( params=params, model=model, @@ -1449,7 +1448,7 @@ def run(rank, world_size, args): spec_augment=spec_augment, ) - scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_autocast, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1551,9 +1550,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast( - enabled=params.use_autocast, dtype=params.dtype - ): + with torch_autocast(enabled=params.use_autocast, dtype=params.dtype): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 2a0ae0129..e83a89400 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -47,6 +47,8 @@ from scaling import ( ) from torch import Tensor, nn +from icefall.utils import torch_autocast + class Zipformer2(EncoderInterface): """ @@ -1873,7 +1875,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): (num_heads, batch_size, seq_len, 
diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py
index 2a0ae0129..e83a89400 100644
--- a/egs/librispeech/ASR/zipformer/zipformer.py
+++ b/egs/librispeech/ASR/zipformer/zipformer.py
@@ -47,6 +47,8 @@ from scaling import (
 )
 from torch import Tensor, nn
 
+from icefall.utils import torch_autocast
+
 
 class Zipformer2(EncoderInterface):
     """
@@ -1873,7 +1875,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
             (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape
             with torch.no_grad():
-                with torch.cuda.amp.autocast(enabled=False):
+                with torch_autocast(enabled=False):
                     attn_weights = attn_weights.to(torch.float32)
                     attn_weights_entropy = (
                         -((attn_weights + 1.0e-20).log() * attn_weights)
diff --git a/egs/librispeech/ASR/zipformer_adapter/decode.py b/egs/librispeech/ASR/zipformer_adapter/decode.py
index 91533be8d..e8798aed6 100755
--- a/egs/librispeech/ASR/zipformer_adapter/decode.py
+++ b/egs/librispeech/ASR/zipformer_adapter/decode.py
@@ -1005,7 +1005,7 @@ def main():
             lg_filename = params.lang_dir / "LG.pt"
             logging.info(f"Loading {lg_filename}")
             decoding_graph = k2.Fsa.from_dict(
-                torch.load(lg_filename, map_location=device)
+                torch.load(lg_filename, map_location=device, weights_only=False)
             )
             decoding_graph.scores *= params.ngram_lm_scale
         else:
diff --git a/egs/librispeech/ASR/zipformer_adapter/decode_gigaspeech.py b/egs/librispeech/ASR/zipformer_adapter/decode_gigaspeech.py
index bbc582f50..66c401761 100755
--- a/egs/librispeech/ASR/zipformer_adapter/decode_gigaspeech.py
+++ b/egs/librispeech/ASR/zipformer_adapter/decode_gigaspeech.py
@@ -1050,7 +1050,7 @@ def main():
             lg_filename = params.lang_dir / "LG.pt"
             logging.info(f"Loading {lg_filename}")
             decoding_graph = k2.Fsa.from_dict(
-                torch.load(lg_filename, map_location=device)
+                torch.load(lg_filename, map_location=device, weights_only=False)
             )
             decoding_graph.scores *= params.ngram_lm_scale
         else:
diff --git a/egs/librispeech/ASR/zipformer_adapter/train.py b/egs/librispeech/ASR/zipformer_adapter/train.py
index 3511590da..fcd7272e9 100755
--- a/egs/librispeech/ASR/zipformer_adapter/train.py
+++ b/egs/librispeech/ASR/zipformer_adapter/train.py
@@ -67,7 +67,6 @@ from optim import Eden, ScaledAdam
 from scaling import ScheduledFloat
 from subsampling import Conv2dSubsampling
 from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer2
@@ -86,9 +85,11 @@ from icefall.hooks import register_inf_check_hooks
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
+    create_grad_scaler,
     get_parameter_groups_with_lrs,
     setup_logger,
     str2bool,
+    torch_autocast,
 )
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -762,7 +763,7 @@ def load_model_params(
     """
     logging.info(f"Loading checkpoint from {ckpt}")
 
-    checkpoint = torch.load(ckpt, map_location="cpu")
+    checkpoint = torch.load(ckpt, map_location="cpu", weights_only=False)
 
     # if module list is empty, load the whole model from ckpt
     if not init_modules:
@@ -805,7 +806,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -982,7 +983,7 @@ def train_one_epoch(
     train_dl: torch.utils.data.DataLoader,
     valid_dls: torch.utils.data.DataLoader,
     valid_sets: List[str],
-    scaler: GradScaler,
+    scaler: "GradScaler",
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -1052,7 +1053,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
 
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1397,7 +1398,7 @@ def run(rank, world_size, args):
             params=params,
         )
 
-    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+    scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1498,7 +1499,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
diff --git a/egs/librispeech/ASR/zipformer_adapter/zipformer.py b/egs/librispeech/ASR/zipformer_adapter/zipformer.py
index 8e2dfdd72..8bc163db5 100644
--- a/egs/librispeech/ASR/zipformer_adapter/zipformer.py
+++ b/egs/librispeech/ASR/zipformer_adapter/zipformer.py
@@ -50,6 +50,8 @@ from scaling import (
 )
 from torch import Tensor, nn
 
+from icefall.utils import torch_autocast
+
 
 class Zipformer2(EncoderInterface):
     """
@@ -1916,7 +1918,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
             (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape
             with torch.no_grad():
-                with torch.cuda.amp.autocast(enabled=False):
+                with torch_autocast(enabled=False):
                     attn_weights = attn_weights.to(torch.float32)
                     attn_weights_entropy = (
                         -((attn_weights + 1.0e-20).log() * attn_weights)
diff --git a/egs/librispeech/ASR/zipformer_ctc/decode.py b/egs/librispeech/ASR/zipformer_ctc/decode.py
index 7f605e2c8..b9eed099c 100755
--- a/egs/librispeech/ASR/zipformer_ctc/decode.py
+++ b/egs/librispeech/ASR/zipformer_ctc/decode.py
@@ -679,7 +679,9 @@ def main():
         H = None
         bpe_model = None
         HLG = k2.Fsa.from_dict(
-            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+            torch.load(
+                f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False
+            )
         )
         assert HLG.requires_grad is False
 
@@ -719,7 +721,9 @@ def main():
             torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
         else:
             logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+            d = torch.load(
+                params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False
+            )
             G = k2.Fsa.from_dict(d)
 
         if params.method in [
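Aside, not part of the patch: once `from torch.cuda.amp import GradScaler` is deleted, the name no longer exists at import time, so the annotations above become string forward references, which Python never evaluates at runtime. The patch simply leaves the name unresolved; a type-checker-only import like the one below is one (hypothetical) way to keep mypy happy:

    from typing import TYPE_CHECKING, Optional

    if TYPE_CHECKING:
        # Seen only by type checkers; emits no deprecation warning at runtime.
        from torch.amp import GradScaler

    def save_checkpoint_sketch(scaler: Optional["GradScaler"] = None) -> None:
        # hypothetical helper: persist the scaler state when one is passed
        if scaler is not None:
            print(sorted(scaler.state_dict().keys()))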
diff --git a/egs/librispeech/ASR/zipformer_ctc/train.py b/egs/librispeech/ASR/zipformer_ctc/train.py
index 60112a84e..bd3bfa332 100755
--- a/egs/librispeech/ASR/zipformer_ctc/train.py
+++ b/egs/librispeech/ASR/zipformer_ctc/train.py
@@ -46,7 +46,6 @@ from lhotse.utils import fix_random_seed
 from model import CTCModel
 from optim import Eden, LRScheduler, ScaledAdam
 from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
@@ -65,7 +64,14 @@ from icefall.env import get_env_info
 from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    create_grad_scaler,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, LRScheduler]
@@ -533,7 +539,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -687,7 +693,7 @@ def train_one_epoch(
     graph_compiler: BpeCtcTrainingGraphCompiler,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
-    scaler: GradScaler,
+    scaler: "GradScaler",
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -726,7 +732,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
 
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -987,7 +993,7 @@ def run(rank, world_size, args):
             params=params,
         )
 
-    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+    scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
diff --git a/egs/librispeech/ASR/zipformer_lora/decode_gigaspeech.py b/egs/librispeech/ASR/zipformer_lora/decode_gigaspeech.py
index 4d93a905f..acc814a00 100755
--- a/egs/librispeech/ASR/zipformer_lora/decode_gigaspeech.py
+++ b/egs/librispeech/ASR/zipformer_lora/decode_gigaspeech.py
@@ -1050,7 +1050,7 @@ def main():
             lg_filename = params.lang_dir / "LG.pt"
             logging.info(f"Loading {lg_filename}")
             decoding_graph = k2.Fsa.from_dict(
-                torch.load(lg_filename, map_location=device)
+                torch.load(lg_filename, map_location=device, weights_only=False)
             )
             decoding_graph.scores *= params.ngram_lm_scale
         else:
diff --git a/egs/librispeech/ASR/zipformer_lora/finetune.py b/egs/librispeech/ASR/zipformer_lora/finetune.py
index 3f36f229f..c26a2f5cc 100755
--- a/egs/librispeech/ASR/zipformer_lora/finetune.py
+++ b/egs/librispeech/ASR/zipformer_lora/finetune.py
@@ -78,7 +78,6 @@ from optim import Eden, ScaledAdam
 from scaling import ScheduledFloat
 from subsampling import Conv2dSubsampling
 from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer2
@@ -96,9 +95,11 @@ from icefall.hooks import register_inf_check_hooks
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
+    create_grad_scaler,
     get_parameter_groups_with_lrs,
     setup_logger,
     str2bool,
+    torch_autocast,
 )
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -775,7 +776,7 @@ def load_model_params(
     """
     logging.info(f"Loading checkpoint from {ckpt}")
 
-    checkpoint = torch.load(ckpt, map_location="cpu")
+    checkpoint = torch.load(ckpt, map_location="cpu", weights_only=False)
 
     # if module list is empty, load the whole model from ckpt
     if not init_modules:
@@ -818,7 +819,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -995,7 +996,7 @@ def train_one_epoch(
     train_dl: torch.utils.data.DataLoader,
     valid_dls: torch.utils.data.DataLoader,
     valid_sets: List[str],
-    scaler: GradScaler,
+    scaler: "GradScaler",
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -1065,7 +1066,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
 
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1406,7 +1407,7 @@ def run(rank, world_size, args):
     #         params=params,
     #     )
 
-    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+    scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1507,7 +1508,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
diff --git a/egs/librispeech/ASR/zipformer_lora/scaling.py b/egs/librispeech/ASR/zipformer_lora/scaling.py
index 8d7aa8027..1347570df 100644
--- a/egs/librispeech/ASR/zipformer_lora/scaling.py
+++ b/egs/librispeech/ASR/zipformer_lora/scaling.py
@@ -27,6 +27,8 @@ import torch.nn.functional as F
 from torch import Tensor
 from torch.cuda.amp import custom_bwd, custom_fwd
 
+from icefall.utils import torch_autocast
+
 
 def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor:
     max_value = torch.max(x, y)
@@ -307,7 +309,7 @@ class SoftmaxFunction(torch.autograd.Function):
     @staticmethod
     def backward(ctx, ans_grad: Tensor):
         (ans,) = ctx.saved_tensors
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             ans_grad = ans_grad.to(torch.float32)
             ans = ans.to(torch.float32)
             x_grad = ans_grad * ans
@@ -863,7 +865,7 @@ class BalancerFunction(torch.autograd.Function):
 
         try:
             with torch.enable_grad():
-                with torch.cuda.amp.autocast(enabled=False):
+                with torch_autocast(enabled=False):
                     x = x.to(torch.float32)
                     x = x.detach()
                     x.requires_grad = True
@@ -1118,7 +1120,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
 
         try:
             with torch.enable_grad():
-                with torch.cuda.amp.autocast(enabled=False):
+                with torch_autocast(enabled=False):
                     x_detached = x_orig.to(torch.float32).detach()
                     x_detached.requires_grad = True
 
@@ -1457,7 +1459,7 @@ class SwooshLFunction(torch.autograd.Function):
 
         coeff = -0.08
 
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             with torch.enable_grad():
                 x = x.detach()
                 x.requires_grad = True
@@ -1534,7 +1536,7 @@ class SwooshRFunction(torch.autograd.Function):
 
         zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
 
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             with torch.enable_grad():
                 x = x.detach()
                 x.requires_grad = True
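The scaling.py hunks above all follow one pattern: custom autograd Functions force their numerically sensitive backward math into float32 with autocast disabled, whatever AMP dtype the surrounding step uses. A simplified sketch of that pattern (not the actual SoftmaxFunction):

    import torch
    from torch import Tensor
    from icefall.utils import torch_autocast

    class Float32BackwardSoftmax(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x: Tensor, dim: int) -> Tensor:
            ans = x.softmax(dim=dim)
            ctx.save_for_backward(ans)
            ctx.dim = dim
            return ans

        @staticmethod
        def backward(ctx, ans_grad: Tensor):
            (ans,) = ctx.saved_tensors
            with torch_autocast(enabled=False):  # keep the math in float32
                ans = ans.to(torch.float32)
                ans_grad = ans_grad.to(torch.float32)
                x_grad = ans_grad * ans
                x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
            return x_grad, None

    # usage: y = Float32BackwardSoftmax.apply(x, -1)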
diff --git a/egs/librispeech/ASR/zipformer_lora/train.py b/egs/librispeech/ASR/zipformer_lora/train.py
index 9ab214e86..2b83d58ef 100755
--- a/egs/librispeech/ASR/zipformer_lora/train.py
+++ b/egs/librispeech/ASR/zipformer_lora/train.py
@@ -76,7 +76,6 @@ from optim import Eden, ScaledAdam
 from scaling import ScheduledFloat
 from subsampling import Conv2dSubsampling
 from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer2
@@ -94,9 +93,11 @@ from icefall.hooks import register_inf_check_hooks
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
+    create_grad_scaler,
     get_parameter_groups_with_lrs,
     setup_logger,
     str2bool,
+    torch_autocast,
 )
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -707,7 +708,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -883,7 +884,7 @@ def train_one_epoch(
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
-    scaler: GradScaler,
+    scaler: "GradScaler",
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -947,7 +948,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
 
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1252,7 +1253,7 @@ def run(rank, world_size, args):
             params=params,
         )
 
-    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+    scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1352,7 +1353,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
diff --git a/egs/librispeech/ASR/zipformer_lora/zipformer.py b/egs/librispeech/ASR/zipformer_lora/zipformer.py
index 43865609a..b84b1c32a 100644
--- a/egs/librispeech/ASR/zipformer_lora/zipformer.py
+++ b/egs/librispeech/ASR/zipformer_lora/zipformer.py
@@ -49,6 +49,8 @@ from scaling import (
 )
 from torch import Tensor, nn
 
+from icefall.utils import torch_autocast
+
 
 class Zipformer2(EncoderInterface):
     """
@@ -1905,7 +1907,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
             (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape
             with torch.no_grad():
-                with torch.cuda.amp.autocast(enabled=False):
+                with torch_autocast(enabled=False):
                     attn_weights = attn_weights.to(torch.float32)
                     attn_weights_entropy = (
                         -((attn_weights + 1.0e-20).log() * attn_weights)
diff --git a/egs/librispeech/ASR/zipformer_mmi/decode.py b/egs/librispeech/ASR/zipformer_mmi/decode.py
index 33c0bf199..bd3ce21f5 100755
--- a/egs/librispeech/ASR/zipformer_mmi/decode.py
+++ b/egs/librispeech/ASR/zipformer_mmi/decode.py
@@ -569,7 +569,9 @@ def main():
     if params.decoding_method == "nbest-rescoring-LG":
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
-        LG = k2.Fsa.from_dict(torch.load(lg_filename, map_location=device))
+        LG = k2.Fsa.from_dict(
+            torch.load(lg_filename, map_location=device, weights_only=False)
+        )
         LG = k2.Fsa.from_fsas([LG]).to(device)
         LG.lm_scores = LG.scores.clone()
@@ -602,7 +604,11 @@ def main():
             torch.save(G.as_dict(), params.lang_dir / f"{order}gram.pt")
         else:
             logging.info(f"Loading pre-compiled {order}gram.pt")
-            d = torch.load(params.lang_dir / f"{order}gram.pt", map_location=device)
+            d = torch.load(
+                params.lang_dir / f"{order}gram.pt",
+                map_location=device,
+                weights_only=False,
+            )
             G = k2.Fsa.from_dict(d)
             G.lm_scores = G.scores.clone()
diff --git a/egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py b/egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py
index 6990c90a0..d5667cafa 100755
--- a/egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py
+++ b/egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py
@@ -308,7 +308,9 @@ def main():
     if method == "nbest-rescoring-LG":
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
-        LG = k2.Fsa.from_dict(torch.load(lg_filename, map_location=device))
+        LG = k2.Fsa.from_dict(
+            torch.load(lg_filename, map_location=device, weights_only=False)
+        )
         LG = k2.Fsa.from_fsas([LG]).to(device)
         LG.lm_scores = LG.scores.clone()
         LM = LG
@@ -317,7 +319,9 @@
         assert order in ("3", "4")
         order = int(order)
         logging.info(f"Loading pre-compiled {order}gram.pt")
-        d = torch.load(params.lang_dir / f"{order}gram.pt", map_location=device)
+        d = torch.load(
+            params.lang_dir / f"{order}gram.pt", map_location=device, weights_only=False
+        )
         G = k2.Fsa.from_dict(d)
         G.lm_scores = G.scores.clone()
         LM = G
diff --git a/egs/librispeech/ASR/zipformer_mmi/pretrained.py b/egs/librispeech/ASR/zipformer_mmi/pretrained.py
index 1e7afc777..ca860b877 100755
--- a/egs/librispeech/ASR/zipformer_mmi/pretrained.py
+++ b/egs/librispeech/ASR/zipformer_mmi/pretrained.py
@@ -269,7 +269,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()
@@ -331,7 +333,9 @@ def main():
     if method == "nbest-rescoring-LG":
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
-        LG = k2.Fsa.from_dict(torch.load(lg_filename, map_location=device))
+        LG = k2.Fsa.from_dict(
+            torch.load(lg_filename, map_location=device, weights_only=False)
+        )
         LG = k2.Fsa.from_fsas([LG]).to(device)
         LG.lm_scores = LG.scores.clone()
         LM = LG
@@ -340,7 +342,9 @@
         assert order in ("3", "4")
         order = int(order)
         logging.info(f"Loading pre-compiled {order}gram.pt")
-        d = torch.load(params.lang_dir / f"{order}gram.pt", map_location=device)
+        d = torch.load(
+            params.lang_dir / f"{order}gram.pt", map_location=device, weights_only=False
+        )
        G = k2.Fsa.from_dict(d)
        G.lm_scores = G.scores.clone()
        LM = G
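The zipformer.py hunks above all wrap the same attention-entropy diagnostic. A sketch of what that diagnostic computes (simplified, not the repo code): the log is taken in float32 with autocast disabled, so half-precision underflow cannot turn the logged entropies into -inf.

    import torch
    from icefall.utils import torch_autocast

    def attn_weights_entropy(attn_weights: torch.Tensor) -> torch.Tensor:
        # attn_weights: (num_heads, batch_size, seq_len, seq_len); rows sum to 1
        with torch.no_grad():
            with torch_autocast(enabled=False):
                p = attn_weights.to(torch.float32)
                # per-head entropy, averaged over batch and query positions
                return -((p + 1.0e-20).log() * p).sum(dim=-1).mean(dim=(1, 2))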
diff --git a/egs/librispeech/ASR/zipformer_mmi/train.py b/egs/librispeech/ASR/zipformer_mmi/train.py
index c1785a328..e0ca0a6a5 100755
--- a/egs/librispeech/ASR/zipformer_mmi/train.py
+++ b/egs/librispeech/ASR/zipformer_mmi/train.py
@@ -64,7 +64,6 @@ from lhotse.utils import fix_random_seed
 from model import CTCModel
 from optim import Eden, ScaledAdam
 from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer
@@ -87,9 +86,11 @@ from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
+    create_grad_scaler,
     encode_supervisions,
     setup_logger,
     str2bool,
+    torch_autocast,
 )
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -514,7 +515,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -696,7 +697,7 @@ def train_one_epoch(
     mmi_graph_compiler: MmiTrainingGraphCompiler,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
-    scaler: GradScaler,
+    scaler: "GradScaler",
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -744,7 +745,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
 
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1037,7 +1038,7 @@ def run(rank, world_size, args):
             params=params,
         )
 
-    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+    scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1138,7 +1139,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
diff --git a/egs/librispeech/SSL/hubert/finetune.py b/egs/librispeech/SSL/hubert/finetune.py
index 17daa3c9d..0080513f3 100644
--- a/egs/librispeech/SSL/hubert/finetune.py
+++ b/egs/librispeech/SSL/hubert/finetune.py
@@ -86,6 +86,7 @@ from icefall.utils import (
     get_parameter_groups_with_lrs,
     setup_logger,
     str2bool,
+    torch_autocast,
 )
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -816,7 +817,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
 
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1207,7 +1208,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
diff --git a/egs/librispeech/SSL/hubert/finetune_ce.py b/egs/librispeech/SSL/hubert/finetune_ce.py
index 2723cc770..1ff2b03c0 100644
--- a/egs/librispeech/SSL/hubert/finetune_ce.py
+++ b/egs/librispeech/SSL/hubert/finetune_ce.py
@@ -81,6 +81,7 @@ from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.hooks import register_inf_check_hooks
 from icefall.utils import (
+    torch_autocast,
     AttributeDict,
     MetricsTracker,
     get_parameter_groups_with_lrs,
@@ -816,7 +817,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
 
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1207,7 +1208,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
diff --git a/egs/librispeech/SSL/hubert/model.py b/egs/librispeech/SSL/hubert/model.py
index 46a968b69..2c2077376 100644
--- a/egs/librispeech/SSL/hubert/model.py
+++ b/egs/librispeech/SSL/hubert/model.py
@@ -24,7 +24,7 @@ import torch
 import torch.nn as nn
 from scaling import ScaledLinear
 
-from icefall.utils import add_sos
+from icefall.utils import add_sos, torch_autocast
 
 
 class AsrModel(nn.Module):
@@ -221,7 +221,7 @@ class AsrModel(nn.Module):
         # if self.training and random.random() < 0.25:
         #     am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
 
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
                 am=am.float(),
@@ -256,7 +256,7 @@ class AsrModel(nn.Module):
         # prior to do_rnnt_pruning (this is an optimization for speed).
         logits = self.joiner(am_pruned, lm_pruned, project_input=False)
 
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
                 logits=logits.float(),
                 symbols=y_padded,
diff --git a/egs/librispeech/SSL/hubert/pretrain.py b/egs/librispeech/SSL/hubert/pretrain.py
index f183d90fd..240cd2c0d 100644
--- a/egs/librispeech/SSL/hubert/pretrain.py
+++ b/egs/librispeech/SSL/hubert/pretrain.py
@@ -80,6 +80,7 @@ from icefall.utils import (
     get_parameter_groups_with_lrs,
     setup_logger,
     str2bool,
+    torch_autocast,
 )
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -644,7 +645,7 @@ def train_one_epoch(
         batch_size = batch["kmeans"].shape[0]
 
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1036,7 +1037,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
diff --git a/egs/librispeech/SSL/hubert/pretrain_ce.py b/egs/librispeech/SSL/hubert/pretrain_ce.py
index 94948695d..12f95c16f 100644
--- a/egs/librispeech/SSL/hubert/pretrain_ce.py
+++ b/egs/librispeech/SSL/hubert/pretrain_ce.py
@@ -80,6 +80,7 @@ from icefall.utils import (
     get_parameter_groups_with_lrs,
     setup_logger,
     str2bool,
+    torch_autocast,
 )
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -644,7 +645,7 @@ def train_one_epoch(
         batch_size = batch["kmeans"].shape[0]
 
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1036,7 +1037,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
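The model.py hunks keep the k2 transducer losses in float32 even when the forward pass runs under autocast. A simplified sketch of that pattern (tensor arguments here are hypothetical; only the autocast/cast idiom mirrors the patch):

    import k2
    import torch
    from icefall.utils import torch_autocast

    def pruned_rnnt_loss_fp32(logits, symbols, ranges, boundary, termination_symbol):
        # Disable autocast and cast back to float32: the k2 loss kernels are
        # numerically sensitive and expect full-precision inputs.
        with torch_autocast(enabled=False):
            return k2.rnnt_loss_pruned(
                logits=logits.float(),
                symbols=symbols,
                ranges=ranges,
                termination_symbol=termination_symbol,
                boundary=boundary,
                reduction="sum",
            )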
diff --git a/egs/librispeech/SSL/zipformer/finetune.py b/egs/librispeech/SSL/zipformer/finetune.py
index c907b41c5..5bebf60f0 100644
--- a/egs/librispeech/SSL/zipformer/finetune.py
+++ b/egs/librispeech/SSL/zipformer/finetune.py
@@ -81,6 +81,7 @@ from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.hooks import register_inf_check_hooks
 from icefall.utils import (
+    torch_autocast,
     AttributeDict,
     MetricsTracker,
     get_parameter_groups_with_lrs,
@@ -1115,7 +1116,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
 
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1504,7 +1505,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
diff --git a/egs/librispeech/SSL/zipformer/model.py b/egs/librispeech/SSL/zipformer/model.py
index 46a968b69..2c2077376 100644
--- a/egs/librispeech/SSL/zipformer/model.py
+++ b/egs/librispeech/SSL/zipformer/model.py
@@ -24,7 +24,7 @@ import torch
 import torch.nn as nn
 from scaling import ScaledLinear
 
-from icefall.utils import add_sos
+from icefall.utils import add_sos, torch_autocast
 
 
 class AsrModel(nn.Module):
@@ -221,7 +221,7 @@ class AsrModel(nn.Module):
         # if self.training and random.random() < 0.25:
         #     am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
 
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
                 am=am.float(),
@@ -256,7 +256,7 @@ class AsrModel(nn.Module):
         # prior to do_rnnt_pruning (this is an optimization for speed).
         logits = self.joiner(am_pruned, lm_pruned, project_input=False)
 
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
                 logits=logits.float(),
                 symbols=y_padded,
diff --git a/egs/librispeech/SSL/zipformer/pretrain.py b/egs/librispeech/SSL/zipformer/pretrain.py
index 937fb382e..d772f56d0 100644
--- a/egs/librispeech/SSL/zipformer/pretrain.py
+++ b/egs/librispeech/SSL/zipformer/pretrain.py
@@ -78,6 +78,7 @@ from icefall.utils import (
     get_parameter_groups_with_lrs,
     setup_logger,
     str2bool,
+    torch_autocast,
 )
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -944,7 +945,7 @@ def train_one_epoch(
         batch_size = batch["kmeans"].shape[0]
 
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1334,7 +1335,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
diff --git a/egs/librispeech/SSL/zipformer/zipformer.py b/egs/librispeech/SSL/zipformer/zipformer.py
index e9eff3357..5071a91a8 100644
--- a/egs/librispeech/SSL/zipformer/zipformer.py
+++ b/egs/librispeech/SSL/zipformer/zipformer.py
@@ -22,6 +22,7 @@ import math
 import random
 import warnings
 from typing import List, Optional, Tuple, Union
+from icefall.utils import torch_autocast
 
 import torch
 from encoder_interface import EncoderInterface
@@ -1849,7 +1850,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
             (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape
             with torch.no_grad():
-                with torch.cuda.amp.autocast(enabled=False):
+                with torch_autocast(enabled=False):
                     attn_weights = attn_weights.to(torch.float32)
                     attn_weights_entropy = (
                         -((attn_weights + 1.0e-20).log() * attn_weights)
diff --git a/egs/librispeech/WSASR/conformer_ctc2/train.py b/egs/librispeech/WSASR/conformer_ctc2/train.py
index 82c68803f..19cce1708 100755
--- a/egs/librispeech/WSASR/conformer_ctc2/train.py
+++ b/egs/librispeech/WSASR/conformer_ctc2/train.py
@@ -84,6 +84,7 @@ from icefall.utils import (
     get_texts,
     setup_logger,
     str2bool,
+    torch_autocast,
 )
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -757,7 +758,7 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
 
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 model=model,
@@ -1076,7 +1077,7 @@ def scan_pessimistic_batches_for_oom(
         # warmup = 0.0 is so that the derivs for the pruned loss stay zero
         # (i.e. are not remembered by the decaying-average in adam), because
         # we want to avoid these params being subject to shrinkage in adam.
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, _ = compute_loss(
                 params=params,
                 model=model,
diff --git a/egs/librispeech/WSASR/conformer_ctc2/train_phone.py b/egs/librispeech/WSASR/conformer_ctc2/train_phone.py
index b276d0587..a32183bf7 100755
--- a/egs/librispeech/WSASR/conformer_ctc2/train_phone.py
+++ b/egs/librispeech/WSASR/conformer_ctc2/train_phone.py
@@ -79,6 +79,7 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.otc_phone_graph_compiler import OtcPhoneTrainingGraphCompiler
 from icefall.utils import (
+    torch_autocast,
     AttributeDict,
     MetricsTracker,
     encode_supervisions_otc,
@@ -758,7 +759,7 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
 
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 model=model,
@@ -1078,7 +1079,7 @@ def scan_pessimistic_batches_for_oom(
         # warmup = 0.0 is so that the derivs for the pruned loss stay zero
         # (i.e. are not remembered by the decaying-average in adam), because
         # we want to avoid these params being subject to shrinkage in adam.
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, _ = compute_loss(
                 params=params,
                 model=model,
diff --git a/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py b/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py
index 296f9a4f4..906025b7f 100755
--- a/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py
+++ b/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py
@@ -73,6 +73,8 @@ def compute_fbank_ljspeech(num_jobs: int):
         f_min=0,
         f_max=8000,
     )
+    if not torch.cuda.is_available():
+        config.device = "cpu"
 
     prefix = "ljspeech"
     suffix = "jsonl.gz"
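The compute_fbank_ljspeech.py hunk above is the usual "prefer CUDA, fall back to CPU" guard, which lets the same script run on GPU boxes and on CPU-only CI runners. A sketch with a hypothetical config object (the real script mutates its feature-extractor config the same way):

    import torch

    class FeatureConfig:
        device: str = "cuda"  # default assumes a GPU machine

    config = FeatureConfig()
    if not torch.cuda.is_available():
        config.device = "cpu"  # mirrors the patched script
    print(f"extracting fbank features on {config.device}")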
diff --git a/egs/yesno/ASR/local/compile_hlg.py b/egs/yesno/ASR/local/compile_hlg.py
index e0a94bf08..3de7136ec 100755
--- a/egs/yesno/ASR/local/compile_hlg.py
+++ b/egs/yesno/ASR/local/compile_hlg.py
@@ -47,7 +47,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
     max_token_id = max(lexicon.tokens)
     logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
     H = k2.ctc_topo(max_token_id)
-    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
+    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt", weights_only=False))
 
     logging.info("Loading G.fst.txt")
     with open("data/lm/G.fst.txt") as f:
diff --git a/egs/yesno/ASR/tdnn/decode.py b/egs/yesno/ASR/tdnn/decode.py
index f520607af..479e195fa 100755
--- a/egs/yesno/ASR/tdnn/decode.py
+++ b/egs/yesno/ASR/tdnn/decode.py
@@ -271,7 +271,9 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
+    HLG = k2.Fsa.from_dict(
+        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu", weights_only=False)
+    )
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
diff --git a/egs/yesno/ASR/tdnn/jit_pretrained.py b/egs/yesno/ASR/tdnn/jit_pretrained.py
index 6c643c263..2a0879045 100755
--- a/egs/yesno/ASR/tdnn/jit_pretrained.py
+++ b/egs/yesno/ASR/tdnn/jit_pretrained.py
@@ -131,7 +131,9 @@ def main():
     model.to(device)
 
     logging.info(f"Loading HLG from {params.HLG}")
-    HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+    HLG = k2.Fsa.from_dict(
+        torch.load(params.HLG, map_location="cpu", weights_only=False)
+    )
     HLG = HLG.to(device)
 
     logging.info("Constructing Fbank computer")
diff --git a/egs/yesno/ASR/tdnn/onnx_pretrained.py b/egs/yesno/ASR/tdnn/onnx_pretrained.py
index 968a9e9a8..e6471d2db 100755
--- a/egs/yesno/ASR/tdnn/onnx_pretrained.py
+++ b/egs/yesno/ASR/tdnn/onnx_pretrained.py
@@ -176,7 +176,9 @@ def main():
     model = OnnxModel(params.nn_model)
 
     logging.info(f"Loading HLG from {args.HLG}")
-    HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+    HLG = k2.Fsa.from_dict(
+        torch.load(params.HLG, map_location="cpu", weights_only=False)
+    )
     HLG = HLG.to(device)
 
     logging.info("Constructing Fbank computer")
diff --git a/egs/yesno/ASR/tdnn/pretrained.py b/egs/yesno/ASR/tdnn/pretrained.py
index bea520998..d4f3ae39f 100755
--- a/egs/yesno/ASR/tdnn/pretrained.py
+++ b/egs/yesno/ASR/tdnn/pretrained.py
@@ -148,13 +148,15 @@ def main():
         num_classes=params.num_classes,
     )
 
-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"])
     model.to(device)
     model.eval()
 
     logging.info(f"Loading HLG from {params.HLG}")
-    HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+    HLG = k2.Fsa.from_dict(
+        torch.load(params.HLG, map_location="cpu", weights_only=False)
+    )
     HLG = HLG.to(device)
 
     logging.info("Constructing Fbank computer")
diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py
index d31ce1301..4ab685684 100644
--- a/icefall/checkpoint.py
+++ b/icefall/checkpoint.py
@@ -27,7 +27,6 @@ import torch
 import torch.nn as nn
 from lhotse.dataset.sampling.base import CutSampler
 from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 
@@ -43,7 +42,7 @@ def save_checkpoint(
     params: Optional[Dict[str, Any]] = None,
     optimizer: Optional[Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     sampler: Optional[CutSampler] = None,
     rank: int = 0,
 ) -> None:
@@ -102,7 +101,7 @@ def load_checkpoint(
     model_avg: Optional[nn.Module] = None,
     optimizer: Optional[Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     sampler: Optional[CutSampler] = None,
     strict: bool = False,
 ) -> Dict[str, Any]:
@@ -110,7 +109,7 @@ def load_checkpoint(
     TODO: document it
     """
     logging.info(f"Loading checkpoint from {filename}")
-    checkpoint = torch.load(filename, map_location="cpu")
+    checkpoint = torch.load(filename, map_location="cpu", weights_only=False)
 
     if next(iter(checkpoint["model"])).startswith("module."):
         logging.info("Loading checkpoint saved by DDP")
@@ -163,7 +162,7 @@ def average_checkpoints(
     """
     n = len(filenames)
 
-    avg = torch.load(filenames[0], map_location=device)["model"]
+    avg = torch.load(filenames[0], map_location=device, weights_only=False)["model"]
 
     # Identify shared parameters. Two parameters are said to be shared
     # if they have the same data_ptr
@@ -178,7 +177,9 @@ def average_checkpoints(
     uniqued_names = list(uniqued.values())
 
     for i in range(1, n):
-        state_dict = torch.load(filenames[i], map_location=device)["model"]
+        state_dict = torch.load(filenames[i], map_location=device, weights_only=False)[
+            "model"
+        ]
         for k in uniqued_names:
             avg[k] += state_dict[k]
 
@@ -199,7 +200,7 @@ def save_checkpoint_with_global_batch_idx(
     params: Optional[Dict[str, Any]] = None,
     optimizer: Optional[Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     sampler: Optional[CutSampler] = None,
     rank: int = 0,
 ):
@@ -421,8 +422,10 @@ def average_checkpoints_with_averaged_model(
     device:
       Move checkpoints to this device before averaging.
     """
-    state_dict_start = torch.load(filename_start, map_location=device)
-    state_dict_end = torch.load(filename_end, map_location=device)
+    state_dict_start = torch.load(
+        filename_start, map_location=device, weights_only=False
+    )
+    state_dict_end = torch.load(filename_end, map_location=device, weights_only=False)
 
     average_period = state_dict_start["average_period"]
diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py
index e5eaba619..d923e8842 100644
--- a/icefall/diagnostics.py
+++ b/icefall/diagnostics.py
@@ -631,7 +631,10 @@ def attach_diagnostics(
             )
 
         module.register_forward_hook(forward_hook)
-        module.register_backward_hook(backward_hook)
+        if hasattr(module, "register_full_backward_hook"):
+            module.register_full_backward_hook(backward_hook)
+        else:
+            module.register_backward_hook(backward_hook)
 
         if type(module).__name__ in [
             "Sigmoid",
@@ -665,7 +668,10 @@ def attach_diagnostics(
                 _model_diagnostic[f"{_name}.scalar"].accumulate_output_grad(_output)
 
             module.register_forward_hook(scalar_forward_hook)
-            module.register_backward_hook(scalar_backward_hook)
+            if hasattr(module, "register_full_backward_hook"):
+                module.register_full_backward_hook(scalar_backward_hook)
+            else:
+                module.register_backward_hook(scalar_backward_hook)
 
     for name, parameter in model.named_parameters():
diff --git a/icefall/hooks.py b/icefall/hooks.py
index 85583acbe..b543190be 100644
--- a/icefall/hooks.py
+++ b/icefall/hooks.py
@@ -77,7 +77,11 @@ def register_inf_check_hooks(model: nn.Module) -> None:
                     logging.warning(f"The sum of {_name}.grad[{i}] is not finite")
 
         module.register_forward_hook(forward_hook)
-        module.register_backward_hook(backward_hook)
+
+        if hasattr(module, "register_full_backward_hook"):
+            module.register_full_backward_hook(backward_hook)
+        else:
+            module.register_backward_hook(backward_hook)
 
     for name, parameter in model.named_parameters():
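The diagnostics.py and hooks.py hunks above switch to register_full_backward_hook, which has been available since torch 1.8 and does not carry the deprecation and correctness caveats of register_backward_hook; the hasattr guard keeps very old torch builds working. A self-contained sketch of the same idiom (helper name and print message are illustrative, not from the repo):

    import torch
    import torch.nn as nn

    def register_grad_logger(module: nn.Module, name: str) -> None:
        def backward_hook(_module, _grad_input, grad_output):
            for i, g in enumerate(grad_output):
                if g is not None and not torch.isfinite(g.sum()):
                    print(f"{name}.grad[{i}] is not finite")

        if hasattr(module, "register_full_backward_hook"):
            module.register_full_backward_hook(backward_hook)
        else:  # pre-1.8 fallback, matching the patch's hasattr guard
            module.register_backward_hook(backward_hook)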
diff --git a/icefall/rnn_lm/train.py b/icefall/rnn_lm/train.py
index 0178b80bf..023afb5a5 100755
--- a/icefall/rnn_lm/train.py
+++ b/icefall/rnn_lm/train.py
@@ -53,7 +53,13 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.checkpoint import save_checkpoint_with_global_batch_idx
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
 
 
 def get_parser():
@@ -401,7 +407,7 @@ def compute_validation_loss(
     for batch_idx, batch in enumerate(valid_dl):
         x, y, sentence_lengths = batch
 
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 model=model,
                 x=x,
@@ -470,7 +476,7 @@ def train_one_epoch(
         params.batch_idx_train += 1
         x, y, sentence_lengths = batch
         batch_size = x.size(0)
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
            loss, loss_info = compute_loss(
                model=model,
                x=x,
diff --git a/icefall/transformer_lm/train.py b/icefall/transformer_lm/train.py
index c36abfcdf..acec95e94 100644
--- a/icefall/transformer_lm/train.py
+++ b/icefall/transformer_lm/train.py
@@ -50,7 +50,13 @@ from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
 
 
 def get_parser():
@@ -341,7 +347,7 @@ def compute_validation_loss(
     for batch_idx, batch in enumerate(valid_dl):
         x, y, sentence_lengths = batch
 
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 model=model,
                 x=x,
@@ -403,7 +409,7 @@ def train_one_epoch(
         params.batch_idx_train += 1
         x, y, sentence_lengths = batch
         batch_size = x.size(0)
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 model=model,
                 x=x,
diff --git a/icefall/utils.py b/icefall/utils.py
index ffb926566..022f83b3b 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -26,6 +26,7 @@ import pathlib
 import random
 import re
 import subprocess
+import warnings
 from collections import defaultdict
 from contextlib import contextmanager
 from dataclasses import dataclass
@@ -42,6 +43,7 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 from lhotse.dataset.signal_transforms import time_warp as time_warp_impl
+from packaging import version
 from pypinyin import lazy_pinyin, pinyin
 from pypinyin.contrib.tone_convert import to_finals, to_finals_tone, to_initials
 from torch.utils.tensorboard import SummaryWriter
@@ -50,6 +52,48 @@ from icefall.checkpoint import average_checkpoints
 
 Pathlike = Union[str, Path]
 
+TORCH_VERSION = version.parse(torch.__version__)
+
+
+def create_grad_scaler(device="cuda", **kwargs):
+    """
+    Creates a GradScaler compatible with both torch < 2.3.0 and >= 2.3.0.
+    Accepts all kwargs like: enabled, init_scale, growth_factor, etc.
+
+    /icefall/egs/librispeech/ASR/./zipformer/train.py:1451: FutureWarning:
+    `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use
+    `torch.amp.GradScaler('cuda', args...)` instead.
+    """
+    if TORCH_VERSION >= version.parse("2.3.0"):
+        from torch.amp import GradScaler
+
+        return GradScaler(device=device, **kwargs)
+    else:
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore", category=FutureWarning)
+            return torch.cuda.amp.GradScaler(**kwargs)
+
+
+@contextmanager
+def torch_autocast(device_type="cuda", **kwargs):
+    """
+    To fix the following warnings:
+    /icefall/egs/librispeech/ASR/zipformer/model.py:323:
+    FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated.
+    Please use `torch.amp.autocast('cuda', args...)` instead.
+      with torch.cuda.amp.autocast(enabled=False):
+    """
+    if TORCH_VERSION >= version.parse("2.3.0"):
+        # Use the new unified API
+        with torch.amp.autocast(device_type=device_type, **kwargs):
+            yield
+    else:
+        # Suppress the deprecation warning and use the old CUDA-specific autocast
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore", category=FutureWarning)
+            with torch.cuda.amp.autocast(**kwargs):
+                yield
+
 
 # Pytorch issue: https://github.com/pytorch/pytorch/issues/47379
 # Fixed: https://github.com/pytorch/pytorch/pull/49853
@@ -1551,6 +1595,7 @@ def optim_step_and_measure_param_change(
     and the L2 norm of the original parameter. It is given by the formula:
 
         .. math::
+            \begin{aligned}
             \delta = \frac{\Vert\theta - \theta_{new}\Vert^2}{\Vert\theta\Vert^2}
             \end{aligned}
 
diff --git a/requirements.txt b/requirements.txt
index d97263142..885bf2fc3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -21,3 +21,4 @@ flake8==5.0.4
 
 # cantonese word segment support
 pycantonese==3.4.0
+packaging
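A quick smoke test (not part of the patch) for the two compatibility helpers added to icefall/utils.py. With autocast disabled and a disabled scaler, nothing touches the GPU, so this should run on any machine covered by the version dispatch:

    import torch
    from icefall.utils import TORCH_VERSION, create_grad_scaler, torch_autocast

    print(f"torch {TORCH_VERSION}")

    scaler = create_grad_scaler(enabled=False)  # no-op scaler on CPU-only boxes
    with torch_autocast(enabled=False):
        x = torch.ones(2, 2) @ torch.ones(2, 2)
    # With scaling disabled, scale() is a pass-through.
    print(scaler.scale(x.sum()))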