From f186e1d42730984b35e4e6013eee4cdc9cd72d7b Mon Sep 17 00:00:00 2001
From: k2-fsa
Date: Mon, 30 Jun 2025 22:07:36 +0800
Subject: [PATCH] Fix weights_only=False

torch.load() defaults to weights_only=True since PyTorch 2.6. The
checkpoints, HLG/LG/G graphs, and quantizer state saved by these recipes
contain objects that the weights-only unpickler rejects, so pass
weights_only=False explicitly at every torch.load() call site; the files
are produced by the recipes themselves and are trusted.

Also:
- generate_build_matrix.py: use version_gt() instead of version_ge() so
  that the minimum torch version itself is kept in the matrix.
- aishell.yml / librispeech.yml: simplify the job condition and drop the
  --min-torch-version filter.
- icefall/diagnostics.py, icefall/hooks.py: prefer
  register_full_backward_hook() over the deprecated
  register_backward_hook() when available.
- icefall/transformer_lm/train.py: replace torch.cuda.amp.autocast with
  the torch_autocast helper from icefall.utils.
---
 .github/scripts/docker/generate_build_matrix.py      |  2 +-
 .github/workflows/aishell.yml                        |  2 +-
 .github/workflows/librispeech.yml                    |  3 ++-
 egs/librispeech/ASR/conformer_ctc/decode.py          |  8 ++++++--
 egs/librispeech/ASR/conformer_ctc/pretrained.py      | 10 +++++++---
 egs/librispeech/ASR/conformer_ctc2/decode.py         |  8 ++++++--
 egs/librispeech/ASR/conformer_ctc3/decode.py         |  8 ++++++--
 egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py |  8 ++++++--
 egs/librispeech/ASR/conformer_ctc3/pretrained.py     | 10 +++++++---
 egs/librispeech/ASR/conformer_mmi/decode.py          |  8 ++++++--
 egs/librispeech/ASR/local/compile_hlg.py             |  4 ++--
 egs/librispeech/ASR/local/compile_lg.py              |  4 ++--
 .../ASR/lstm_transducer_stateless/decode.py          |  2 +-
 .../ASR/lstm_transducer_stateless/pretrained.py      |  2 +-
 .../ASR/lstm_transducer_stateless2/decode.py         |  2 +-
 .../ASR/lstm_transducer_stateless2/pretrained.py     |  2 +-
 .../ASR/lstm_transducer_stateless3/decode.py         |  2 +-
 .../ASR/lstm_transducer_stateless3/pretrained.py     |  2 +-
 .../ASR/pruned_transducer_stateless/decode.py        |  2 +-
 .../ASR/pruned_transducer_stateless2/decode.py       |  2 +-
 .../ASR/pruned_transducer_stateless2/pretrained.py   |  2 +-
 .../ASR/pruned_transducer_stateless3/decode.py       |  4 ++--
 .../ASR/pruned_transducer_stateless3/pretrained.py   |  2 +-
 .../ASR/pruned_transducer_stateless4/decode.py       |  2 +-
 .../ASR/pruned_transducer_stateless5/decode.py       |  2 +-
 .../ASR/pruned_transducer_stateless5/pretrained.py   |  2 +-
 .../ASR/pruned_transducer_stateless6/vq_utils.py     |  4 +++-
 .../ASR/pruned_transducer_stateless7/compute_ali.py  |  2 +-
 .../ASR/pruned_transducer_stateless7/decode.py       |  2 +-
 .../decode_gigaspeech.py                             |  2 +-
 .../ASR/pruned_transducer_stateless7/finetune.py     |  2 +-
 .../ASR/pruned_transducer_stateless7/pretrained.py   |  2 +-
 .../pretrained_ctc.py                                |  8 ++++++--
 .../ctc_decode.py                                    |  8 ++++++--
 .../ctc_guide_decode_bs.py                           |  2 +-
 .../pruned_transducer_stateless7_ctc_bs/decode.py    |  2 +-
 .../jit_pretrained_ctc.py                            |  8 ++++++--
 .../pretrained.py                                    |  2 +-
 .../pretrained_ctc.py                                | 10 +++++++---
 .../decode.py                                        |  2 +-
 .../decode_gigaspeech.py                             |  2 +-
 .../ASR/pruned_transducer_stateless8/decode.py       |  2 +-
 .../ASR/pruned_transducer_stateless8/pretrained.py   |  2 +-
 egs/librispeech/ASR/tdnn_lstm_ctc/decode.py          |  8 ++++++--
 egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py      | 10 +++++++---
 .../ASR/tiny_transducer_ctc/ctc_decode.py            |  8 ++++++--
 egs/librispeech/ASR/tiny_transducer_ctc/decode.py    |  2 +-
 .../ASR/tiny_transducer_ctc/jit_pretrained_ctc.py    |  8 ++++++--
 .../ASR/tiny_transducer_ctc/pretrained.py            |  2 +-
 .../ASR/tiny_transducer_ctc/pretrained_ctc.py        | 10 +++++++---
 egs/librispeech/ASR/transducer/pretrained.py         |  2 +-
 egs/librispeech/ASR/zipformer_adapter/decode.py      |  2 +-
 .../ASR/zipformer_adapter/decode_gigaspeech.py       |  2 +-
 egs/librispeech/ASR/zipformer_adapter/train.py       |  2 +-
 .../ASR/zipformer_lora/decode_gigaspeech.py          |  2 +-
 egs/librispeech/ASR/zipformer_lora/finetune.py       |  2 +-
 egs/librispeech/ASR/zipformer_mmi/decode.py          | 10 ++++++++--
 egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py  |  8 ++++++--
 egs/librispeech/ASR/zipformer_mmi/pretrained.py      | 10 +++++++---
 icefall/diagnostics.py                               | 10 ++++++++--
 icefall/hooks.py                                     |  6 +++++-
 icefall/transformer_lm/train.py                      | 12 +++++++++---
 62 files changed, 190 insertions(+), 93 deletions(-)

diff --git a/.github/scripts/docker/generate_build_matrix.py b/.github/scripts/docker/generate_build_matrix.py
index 395e7ef37..7f36e278d 100755
--- a/.github/scripts/docker/generate_build_matrix.py
+++ b/.github/scripts/docker/generate_build_matrix.py
@@ -91,7 +91,7 @@ def get_matrix(min_torch_version, specified_torch_version, specified_python_vers
     matrix = []
     for p in python_version:
         for t in torch_version:
-            if min_torch_version and version_ge(min_torch_version, t):
+            if min_torch_version and version_gt(min_torch_version, t):
                 continue

             # torchaudio <= 1.13.x supports only python <= 3.10
diff --git a/.github/workflows/aishell.yml b/.github/workflows/aishell.yml
index 57224040b..4572c0c7f 100644
--- a/.github/workflows/aishell.yml
+++ b/.github/workflows/aishell.yml
@@ -17,7 +17,7 @@ concurrency:

 jobs:
   generate_build_matrix:
-    if: (github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa') && (github.event.label.name == 'ready' || github.event_name == 'push' || github.event_name == 'aishell')
+    if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa'
     # see https://github.com/pytorch/pytorch/pull/50633
     runs-on: ubuntu-latest
diff --git a/.github/workflows/librispeech.yml b/.github/workflows/librispeech.yml
index 77b2a4bd4..4b8021254 100644
--- a/.github/workflows/librispeech.yml
+++ b/.github/workflows/librispeech.yml
@@ -30,7 +30,8 @@ jobs:
        run: |
          # outputting for debugging purposes
          python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10"
-         MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10" --min-torch-version "2.6.0")
+         MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10")
+         # MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10" --min-torch-version "2.6.0")
          echo "::set-output name=matrix::${MATRIX}"

   librispeech:
     needs: generate_build_matrix
diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py
index 7e0bf5b7b..fc866f83b 100755
--- a/egs/librispeech/ASR/conformer_ctc/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc/decode.py
@@ -667,7 +667,9 @@ def main():
         H = None
         bpe_model = None
         HLG = k2.Fsa.from_dict(
-            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+            torch.load(
+                f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False
+            )
         )
         assert HLG.requires_grad is False

@@ -707,7 +709,9 @@ def main():
             torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
         else:
             logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+            d = torch.load(
+                params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False
+            )
             G = k2.Fsa.from_dict(d)

         if params.method in [
diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py
index 38b60fcb9..5b3a021ad 100755
--- a/egs/librispeech/ASR/conformer_ctc/pretrained.py
+++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py
@@ -271,7 +271,7 @@ def main():
         use_feat_batchnorm=params.use_feat_batchnorm,
     )

-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()
@@ -351,7 +351,9 @@ def main():
         "attention-decoder",
     ]:
         logging.info(f"Loading HLG from {params.HLG}")
-        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        HLG = k2.Fsa.from_dict(
+            torch.load(params.HLG, map_location="cpu", weights_only=False)
+        )
         HLG = HLG.to(device)
         if not hasattr(HLG, "lm_scores"):
             # For whole-lattice-rescoring and attention-decoder
@@ -362,7 +364,9 @@ def main():
         "attention-decoder",
     ]:
         logging.info(f"Loading G from {params.G}")
-        G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+        G = k2.Fsa.from_dict(
+            torch.load(params.G, map_location="cpu", weights_only=False)
+        )
         # Add epsilon self-loops to G as we will compose
         # it with the whole lattice later
         G = G.to(device)
diff --git a/egs/librispeech/ASR/conformer_ctc2/decode.py b/egs/librispeech/ASR/conformer_ctc2/decode.py
index 0b271a51c..349e8f02d 100755
--- a/egs/librispeech/ASR/conformer_ctc2/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc2/decode.py
@@ -774,7 +774,9 @@ def main():
         H = None
         bpe_model = None
         HLG = k2.Fsa.from_dict(
-            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+            torch.load(
+                f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False
+            )
         )
         assert HLG.requires_grad is False

@@ -814,7 +816,9 @@ def main():
             torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
         else:
             logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+            d = torch.load(
+                params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False
+            )
             G = k2.Fsa.from_dict(d)

         if params.method in [
diff --git a/egs/librispeech/ASR/conformer_ctc3/decode.py b/egs/librispeech/ASR/conformer_ctc3/decode.py
index e6327bb5e..cf58fd18d 100755
--- a/egs/librispeech/ASR/conformer_ctc3/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc3/decode.py
@@ -868,7 +868,9 @@ def main():
         H = None
         bpe_model = None
         HLG = k2.Fsa.from_dict(
-            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+            torch.load(
+                f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False
+            )
         )
         assert HLG.requires_grad is False

@@ -907,7 +909,9 @@ def main():
             torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
         else:
             logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+            d = torch.load(
+                params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False
+            )
             G = k2.Fsa.from_dict(d)

         if params.decoding_method == "whole-lattice-rescoring":
diff --git a/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py b/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py
index 19b26361e..f8e3fa43b 100755
--- a/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py
+++ b/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py
@@ -334,7 +334,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading HLG from {params.HLG}")
-        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        HLG = k2.Fsa.from_dict(
+            torch.load(params.HLG, map_location="cpu", weights_only=False)
+        )
         HLG = HLG.to(device)
         if not hasattr(HLG, "lm_scores"):
             # For whole-lattice-rescoring and attention-decoder
@@ -345,7 +347,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading G from {params.G}")
-        G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+        G = k2.Fsa.from_dict(
+            torch.load(params.G, map_location="cpu", weights_only=False)
+        )
         G = G.to(device)
         if params.method == "whole-lattice-rescoring":
             # Add epsilon self-loops to G as we will compose
diff --git a/egs/librispeech/ASR/conformer_ctc3/pretrained.py b/egs/librispeech/ASR/conformer_ctc3/pretrained.py
index a0cdfcf03..e528b2cb8 100755
--- a/egs/librispeech/ASR/conformer_ctc3/pretrained.py
+++ b/egs/librispeech/ASR/conformer_ctc3/pretrained.py
@@ -290,7 +290,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()
@@ -386,7 +386,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading HLG from {params.HLG}")
-        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        HLG = k2.Fsa.from_dict(
+            torch.load(params.HLG, map_location="cpu", weights_only=False)
+        )
         HLG = HLG.to(device)
         if not hasattr(HLG, "lm_scores"):
             # For whole-lattice-rescoring and attention-decoder
@@ -397,7 +399,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading G from {params.G}")
-        G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+        G = k2.Fsa.from_dict(
+            torch.load(params.G, map_location="cpu", weights_only=False)
+        )
         G = G.to(device)
         if params.method == "whole-lattice-rescoring":
             # Add epsilon self-loops to G as we will compose
diff --git a/egs/librispeech/ASR/conformer_mmi/decode.py b/egs/librispeech/ASR/conformer_mmi/decode.py
index 74f6e73fa..01fcf0685 100755
--- a/egs/librispeech/ASR/conformer_mmi/decode.py
+++ b/egs/librispeech/ASR/conformer_mmi/decode.py
@@ -574,7 +574,9 @@ def main():
         H = None
         bpe_model = None
         HLG = k2.Fsa.from_dict(
-            torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
+            torch.load(
+                f"{params.lang_dir}/HLG.pt", map_location="cpu", weights_only=False
+            )
         )
         HLG = HLG.to(device)
         assert HLG.requires_grad is False
@@ -609,7 +611,9 @@ def main():
             torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
         else:
             logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
+            d = torch.load(
+                params.lm_dir / "G_4_gram.pt", map_location="cpu", weights_only=False
+            )
             G = k2.Fsa.from_dict(d).to(device)

         if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py
index d19d50ae6..ec39d5b36 100755
--- a/egs/librispeech/ASR/local/compile_hlg.py
+++ b/egs/librispeech/ASR/local/compile_hlg.py
@@ -72,11 +72,11 @@ def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
     max_token_id = max(lexicon.tokens)
     logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
     H = k2.ctc_topo(max_token_id)
-    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
+    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt", weights_only=False))

     if Path(f"data/lm/{lm}.pt").is_file():
         logging.info(f"Loading pre-compiled {lm}")
-        d = torch.load(f"data/lm/{lm}.pt")
+        d = torch.load(f"data/lm/{lm}.pt", weights_only=False)
         G = k2.Fsa.from_dict(d)
     else:
         logging.info(f"Loading {lm}.fst.txt")
diff --git a/egs/librispeech/ASR/local/compile_lg.py b/egs/librispeech/ASR/local/compile_lg.py
index 709b14070..bd25cfa29 100755
--- a/egs/librispeech/ASR/local/compile_lg.py
+++ b/egs/librispeech/ASR/local/compile_lg.py
@@ -66,11 +66,11 @@ def compile_LG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
       An FSA representing LG.
""" lexicon = Lexicon(lang_dir) - L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) + L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt", weights_only=False)) if Path(f"data/lm/{lm}.pt").is_file(): logging.info(f"Loading pre-compiled {lm}") - d = torch.load(f"data/lm/{lm}.pt") + d = torch.load(f"data/lm/{lm}.pt", weights_only=False) G = k2.Fsa.from_dict(d) else: logging.info(f"Loading {lm}.fst.txt") diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py index 856c9d945..8c75eb871 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py @@ -750,7 +750,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py index 42c3a5d7f..f29d1d9db 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py @@ -238,7 +238,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py index 1a724830b..cfbbb334c 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py @@ -935,7 +935,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py index dcff088e2..888f9931e 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py @@ -241,7 +241,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py index a2b4f9e1a..e25b79e2e 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py @@ -815,7 +815,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) 
             decoding_graph.scores *= params.ngram_lm_scale
         else:
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py
index e39637bd8..619e783b0 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py
@@ -239,7 +239,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
index 3c4500087..6d1da7440 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
@@ -741,7 +741,7 @@ def main():
             lg_filename = params.lang_dir / "LG.pt"
             logging.info(f"Loading {lg_filename}")
             decoding_graph = k2.Fsa.from_dict(
-                torch.load(lg_filename, map_location=device)
+                torch.load(lg_filename, map_location=device, weights_only=False)
             )
             decoding_graph.scores *= params.ngram_lm_scale
         else:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py
index c57514193..5a4a74ebb 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py
@@ -754,7 +754,7 @@ def main():
             lg_filename = params.lang_dir / "LG.pt"
             logging.info(f"Loading {lg_filename}")
             decoding_graph = k2.Fsa.from_dict(
-                torch.load(lg_filename, map_location=device)
+                torch.load(lg_filename, map_location=device, weights_only=False)
             )
             decoding_graph.scores *= params.ngram_lm_scale
         else:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py
index 6923f4d40..e6ddcab25 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py
@@ -265,7 +265,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
index 7c62bfa58..18a3792b0 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
@@ -921,7 +921,7 @@ def load_ngram_LM(

     if pt_file.is_file():
         logging.info(f"Loading pre-compiled {pt_file}")
-        d = torch.load(pt_file, map_location=device)
+        d = torch.load(pt_file, map_location=device, weights_only=False)
         G = k2.Fsa.from_dict(d)
         G = k2.add_epsilon_self_loops(G)
         G = k2.arc_sort(G)
@@ -1101,7 +1101,7 @@ def main():
             lg_filename = params.lang_dir / "LG.pt"
             logging.info(f"Loading {lg_filename}")
             decoding_graph = k2.Fsa.from_dict(
-                torch.load(lg_filename, map_location=device)
+                torch.load(lg_filename, map_location=device, weights_only=False)
             )
             decoding_graph.scores *= params.ngram_lm_scale
         elif params.decoding_method in [
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py
index 05e6a6fba..19143fb5d 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py
@@ -274,7 +274,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
index 5195a4ef6..925c01c7b 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
@@ -913,7 +913,7 @@ def main():
             lg_filename = params.lang_dir / "LG.pt"
             logging.info(f"Loading {lg_filename}")
             decoding_graph = k2.Fsa.from_dict(
-                torch.load(lg_filename, map_location=device)
+                torch.load(lg_filename, map_location=device, weights_only=False)
             )
             decoding_graph.scores *= params.ngram_lm_scale
         else:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py
index 7a3e63218..404d7a3d3 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py
@@ -972,7 +972,7 @@ def main():
             lg_filename = params.lang_dir / "LG.pt"
             logging.info(f"Loading {lg_filename}")
             decoding_graph = k2.Fsa.from_dict(
-                torch.load(lg_filename, map_location=device)
+                torch.load(lg_filename, map_location=device, weights_only=False)
             )
             decoding_graph.scores *= params.ngram_lm_scale
         else:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py
index a9ce75a7b..9e2669379 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py
@@ -238,7 +238,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py
index 3bca7db2c..4f3fbaa81 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py
@@ -348,7 +348,9 @@ class CodebookIndexExtractor:
             num_codebooks=self.params.num_codebooks,
             codebook_size=256,
         )
-        quantizer.load_state_dict(torch.load(self.quantizer_file_path))
+        quantizer.load_state_dict(
+            torch.load(self.quantizer_file_path, weights_only=False)
+        )
         quantizer.to(self.params.device)
         return quantizer

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/compute_ali.py b/egs/librispeech/ASR/pruned_transducer_stateless7/compute_ali.py
index 27ef0a244..949a497ce 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/compute_ali.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/compute_ali.py
@@ -289,7 +289,7 @@ def main():
     logging.info("About to create model")
     model = get_transducer_model(params)

-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py
index eb8841cc4..048de7bb9 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py
@@ -910,7 +910,7 @@ def main():
             lg_filename = params.lang_dir / "LG.pt"
             logging.info(f"Loading {lg_filename}")
             decoding_graph = k2.Fsa.from_dict(
-                torch.load(lg_filename, map_location=device)
+                torch.load(lg_filename, map_location=device, weights_only=False)
             )
             decoding_graph.scores *= params.ngram_lm_scale
         else:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py
index 7095c3cc8..da1bf17fc 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py
@@ -813,7 +813,7 @@ def main():
             lg_filename = params.lang_dir / "LG.pt"
             logging.info(f"Loading {lg_filename}")
             decoding_graph = k2.Fsa.from_dict(
-                torch.load(lg_filename, map_location=device)
+                torch.load(lg_filename, map_location=device, weights_only=False)
             )
             decoding_graph.scores *= params.ngram_lm_scale
         else:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py
index a530c74ae..072aa274c 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py
@@ -636,7 +636,7 @@ def load_model_params(
     """
     logging.info(f"Loading checkpoint from {ckpt}")

-    checkpoint = torch.load(ckpt, map_location="cpu")
+    checkpoint = torch.load(ckpt, map_location="cpu", weights_only=False)

     # if module list is empty, load the whole model from ckpt
     if not init_modules:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py
index 4bf11ac24..fabda3aaa 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py
@@ -247,7 +247,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py
index c5b1f2558..32242c94e 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py
@@ -365,7 +365,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading HLG from {params.HLG}")
-        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        HLG = k2.Fsa.from_dict(
+            torch.load(params.HLG, map_location="cpu", weights_only=False)
+        )
         HLG = HLG.to(device)
         if not hasattr(HLG, "lm_scores"):
             # For whole-lattice-rescoring and attention-decoder
@@ -376,7 +378,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading G from {params.G}")
-        G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+        G = k2.Fsa.from_dict(
+            torch.load(params.G, map_location="cpu", weights_only=False)
+        )
         G = G.to(device)
         if params.method == "whole-lattice-rescoring":
             # Add epsilon self-loops to G as we will compose
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py
index fa7144f0f..3af3ada2c 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py
@@ -624,7 +624,9 @@ def main():
         H = None
         bpe_model = None
         HLG = k2.Fsa.from_dict(
-            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+            torch.load(
+                f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False
+            )
         )
         assert HLG.requires_grad is False

@@ -663,7 +665,9 @@ def main():
             torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
         else:
             logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+            d = torch.load(
+                params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False
+            )
             G = k2.Fsa.from_dict(d)

         if params.decoding_method == "whole-lattice-rescoring":
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py
index e2f08abc6..233f00236 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py
@@ -808,7 +808,7 @@ def main():
             lg_filename = params.lang_dir / "LG.pt"
             logging.info(f"Loading {lg_filename}")
             decoding_graph = k2.Fsa.from_dict(
-                torch.load(lg_filename, map_location=device)
+                torch.load(lg_filename, map_location=device, weights_only=False)
             )
             decoding_graph.scores *= params.ngram_lm_scale
         else:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py
index e497787d3..025b146b9 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py
@@ -786,7 +786,7 @@ def main():
             lg_filename = params.lang_dir / "LG.pt"
             logging.info(f"Loading {lg_filename}")
             decoding_graph = k2.Fsa.from_dict(
-                torch.load(lg_filename, map_location=device)
+                torch.load(lg_filename, map_location=device, weights_only=False)
             )
             decoding_graph.scores *= params.ngram_lm_scale
         else:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py
index 80604ef4a..70d9841bf 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py
@@ -347,7 +347,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading HLG from {params.HLG}")
-        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        HLG = k2.Fsa.from_dict(
+            torch.load(params.HLG, map_location="cpu", weights_only=False)
+        )
         HLG = HLG.to(device)
         if not hasattr(HLG, "lm_scores"):
             # For whole-lattice-rescoring and attention-decoder
@@ -358,7 +360,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading G from {params.G}")
-        G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+        G = k2.Fsa.from_dict(
+            torch.load(params.G, map_location="cpu", weights_only=False)
+        )
         G = G.to(device)
         if params.method == "whole-lattice-rescoring":
             # Add epsilon self-loops to G as we will compose
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py
index a82f3562b..9ceec5f5a 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py
@@ -247,7 +247,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py
index b98756a54..431760f9a 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py
@@ -286,7 +286,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()
@@ -362,7 +362,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading HLG from {params.HLG}")
-        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        HLG = k2.Fsa.from_dict(
+            torch.load(params.HLG, map_location="cpu", weights_only=False)
+        )
         HLG = HLG.to(device)
         if not hasattr(HLG, "lm_scores"):
             # For whole-lattice-rescoring and attention-decoder
@@ -373,7 +375,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading G from {params.G}")
-        G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+        G = k2.Fsa.from_dict(
+            torch.load(params.G, map_location="cpu", weights_only=False)
+        )
         G = G.to(device)
         if params.method == "whole-lattice-rescoring":
             # Add epsilon self-loops to G as we will compose
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode.py
index 35158ced4..61c1a9663 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode.py
@@ -768,7 +768,7 @@ def main():
             lg_filename = params.lang_dir / "LG.pt"
             logging.info(f"Loading {lg_filename}")
             decoding_graph = k2.Fsa.from_dict(
-                torch.load(lg_filename, map_location=device)
+                torch.load(lg_filename, map_location=device, weights_only=False)
             )
             decoding_graph.scores *= params.ngram_lm_scale
         else:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode_gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode_gigaspeech.py
index a4f52ad7f..e95bb3357 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode_gigaspeech.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode_gigaspeech.py
@@ -788,7 +788,7 @@ def main():
             lg_filename = params.lang_dir / "LG.pt"
             logging.info(f"Loading {lg_filename}")
             decoding_graph = k2.Fsa.from_dict(
-                torch.load(lg_filename, map_location=device)
+                torch.load(lg_filename, map_location=device, weights_only=False)
             )
             decoding_graph.scores *= params.ngram_lm_scale
         else:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py
index e07777c9f..3cad83a0b 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py
@@ -747,7 +747,7 @@ def main():
             lg_filename = params.lang_dir / "LG.pt"
             logging.info(f"Loading {lg_filename}")
             decoding_graph = k2.Fsa.from_dict(
-                torch.load(lg_filename, map_location=device)
+                torch.load(lg_filename, map_location=device, weights_only=False)
             )
             decoding_graph.scores *= params.ngram_lm_scale
         else:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py
index c29b8d8c9..693db2beb 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py
@@ -247,7 +247,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
index 92529e06c..db12ab827 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
@@ -398,7 +398,9 @@ def main():
     logging.info(f"device: {device}")

-    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
+    HLG = k2.Fsa.from_dict(
+        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu", weights_only=False)
+    )
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
@@ -428,7 +430,9 @@ def main():
             torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
         else:
             logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
+            d = torch.load(
+                params.lm_dir / "G_4_gram.pt", map_location="cpu", weights_only=False
+            )
             G = k2.Fsa.from_dict(d).to(device)

         if params.method == "whole-lattice-rescoring":
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
index b3dfab64a..4ad7cb016 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
@@ -167,13 +167,15 @@ def main():
         subsampling_factor=params.subsampling_factor,
     )

-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"])
     model.to(device)
     model.eval()

     logging.info(f"Loading HLG from {params.HLG}")
-    HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+    HLG = k2.Fsa.from_dict(
+        torch.load(params.HLG, map_location="cpu", weights_only=False)
+    )
     HLG = HLG.to(device)
     if not hasattr(HLG, "lm_scores"):
         # For whole-lattice-rescoring and attention-decoder
@@ -181,7 +183,9 @@ def main():

     if params.method == "whole-lattice-rescoring":
         logging.info(f"Loading G from {params.G}")
-        G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+        G = k2.Fsa.from_dict(
+            torch.load(params.G, map_location="cpu", weights_only=False)
+        )
         # Add epsilon self-loops to G as we will compose
         # it with the whole lattice later
         G = G.to(device)
diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py b/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py
index cda03b56e..ec700626a 100644
--- a/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py
+++ b/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py
@@ -589,7 +589,9 @@ def main():
         H = None
         bpe_model = None
         HLG = k2.Fsa.from_dict(
-            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+            torch.load(
+                f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False
+            )
         )
         assert HLG.requires_grad is False

@@ -628,7 +630,9 @@ def main():
             torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
         else:
             logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+            d = torch.load(
+                params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False
+            )
             G = k2.Fsa.from_dict(d)

         if params.decoding_method == "whole-lattice-rescoring":
diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/decode.py b/egs/librispeech/ASR/tiny_transducer_ctc/decode.py
index cc4471e2b..1b329e8f3 100644
--- a/egs/librispeech/ASR/tiny_transducer_ctc/decode.py
+++ b/egs/librispeech/ASR/tiny_transducer_ctc/decode.py
@@ -663,7 +663,7 @@ def main():
             lg_filename = params.lang_dir / "LG.pt"
             logging.info(f"Loading {lg_filename}")
             decoding_graph = k2.Fsa.from_dict(
-                torch.load(lg_filename, map_location=device)
+                torch.load(lg_filename, map_location=device, weights_only=False)
             )
             decoding_graph.scores *= params.ngram_lm_scale
         else:
diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py b/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py
index 92dea3aa1..4b234a328 100755
--- a/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py
+++ b/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py
@@ -347,7 +347,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading HLG from {params.HLG}")
-        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        HLG = k2.Fsa.from_dict(
+            torch.load(params.HLG, map_location="cpu", weights_only=False)
+        )
         HLG = HLG.to(device)
         if not hasattr(HLG, "lm_scores"):
             # For whole-lattice-rescoring and attention-decoder
@@ -358,7 +360,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading G from {params.G}")
-        G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+        G = k2.Fsa.from_dict(
+            torch.load(params.G, map_location="cpu", weights_only=False)
+        )
         G = G.to(device)
         if params.method == "whole-lattice-rescoring":
             # Add epsilon self-loops to G as we will compose
diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py b/egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py
index 5c6956324..9714aa537 100755
--- a/egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py
+++ b/egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py
@@ -249,7 +249,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()
diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py b/egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py
index 7698ada79..a2ea1dd06 100755
--- a/egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py
+++ b/egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py
@@ -286,7 +286,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()
@@ -365,7 +365,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading HLG from {params.HLG}")
-        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        HLG = k2.Fsa.from_dict(
+            torch.load(params.HLG, map_location="cpu", weights_only=False)
+        )
         HLG = HLG.to(device)
         if not hasattr(HLG, "lm_scores"):
             # For whole-lattice-rescoring and attention-decoder
@@ -376,7 +378,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading G from {params.G}")
-        G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+        G = k2.Fsa.from_dict(
+            torch.load(params.G, map_location="cpu", weights_only=False)
+        )
         G = G.to(device)
         if params.method == "whole-lattice-rescoring":
             # Add epsilon self-loops to G as we will compose
diff --git a/egs/librispeech/ASR/transducer/pretrained.py b/egs/librispeech/ASR/transducer/pretrained.py
index 4d9bbf4b1..06b1c05b9 100755
--- a/egs/librispeech/ASR/transducer/pretrained.py
+++ b/egs/librispeech/ASR/transducer/pretrained.py
@@ -222,7 +222,7 @@ def main():
     logging.info("Creating model")
     model = get_transducer_model(params)

-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()
diff --git a/egs/librispeech/ASR/zipformer_adapter/decode.py b/egs/librispeech/ASR/zipformer_adapter/decode.py
index 91533be8d..e8798aed6 100755
--- a/egs/librispeech/ASR/zipformer_adapter/decode.py
+++ b/egs/librispeech/ASR/zipformer_adapter/decode.py
@@ -1005,7 +1005,7 @@ def main():
             lg_filename = params.lang_dir / "LG.pt"
             logging.info(f"Loading {lg_filename}")
             decoding_graph = k2.Fsa.from_dict(
-                torch.load(lg_filename, map_location=device)
+                torch.load(lg_filename, map_location=device, weights_only=False)
             )
             decoding_graph.scores *= params.ngram_lm_scale
         else:
diff --git a/egs/librispeech/ASR/zipformer_adapter/decode_gigaspeech.py b/egs/librispeech/ASR/zipformer_adapter/decode_gigaspeech.py
index bbc582f50..66c401761 100755
--- a/egs/librispeech/ASR/zipformer_adapter/decode_gigaspeech.py
+++ b/egs/librispeech/ASR/zipformer_adapter/decode_gigaspeech.py
@@ -1050,7 +1050,7 @@ def main():
             lg_filename = params.lang_dir / "LG.pt"
             logging.info(f"Loading {lg_filename}")
             decoding_graph = k2.Fsa.from_dict(
-                torch.load(lg_filename, map_location=device)
+                torch.load(lg_filename, map_location=device, weights_only=False)
             )
             decoding_graph.scores *= params.ngram_lm_scale
         else:
diff --git a/egs/librispeech/ASR/zipformer_adapter/train.py b/egs/librispeech/ASR/zipformer_adapter/train.py
index d744d59d2..99c852844 100755
--- a/egs/librispeech/ASR/zipformer_adapter/train.py
+++ b/egs/librispeech/ASR/zipformer_adapter/train.py
@@ -763,7 +763,7 @@ def load_model_params(
     """
     logging.info(f"Loading checkpoint from {ckpt}")

-    checkpoint = torch.load(ckpt, map_location="cpu")
+    checkpoint = torch.load(ckpt, map_location="cpu", weights_only=False)

     # if module list is empty, load the whole model from ckpt
     if not init_modules:
diff --git a/egs/librispeech/ASR/zipformer_lora/decode_gigaspeech.py b/egs/librispeech/ASR/zipformer_lora/decode_gigaspeech.py
index 4d93a905f..acc814a00 100755
--- a/egs/librispeech/ASR/zipformer_lora/decode_gigaspeech.py
+++ b/egs/librispeech/ASR/zipformer_lora/decode_gigaspeech.py
@@ -1050,7 +1050,7 @@ def main():
             lg_filename = params.lang_dir / "LG.pt"
             logging.info(f"Loading {lg_filename}")
             decoding_graph = k2.Fsa.from_dict(
-                torch.load(lg_filename, map_location=device)
+                torch.load(lg_filename, map_location=device, weights_only=False)
             )
             decoding_graph.scores *= params.ngram_lm_scale
         else:
diff --git a/egs/librispeech/ASR/zipformer_lora/finetune.py b/egs/librispeech/ASR/zipformer_lora/finetune.py
index ca9002928..ea6e2877b 100755
--- a/egs/librispeech/ASR/zipformer_lora/finetune.py
+++ b/egs/librispeech/ASR/zipformer_lora/finetune.py
@@ -776,7 +776,7 @@ def load_model_params(
     """
     logging.info(f"Loading checkpoint from {ckpt}")

-    checkpoint = torch.load(ckpt, map_location="cpu")
+    checkpoint = torch.load(ckpt, map_location="cpu", weights_only=False)

     # if module list is empty, load the whole model from ckpt
     if not init_modules:
diff --git a/egs/librispeech/ASR/zipformer_mmi/decode.py b/egs/librispeech/ASR/zipformer_mmi/decode.py
index 33c0bf199..bd3ce21f5 100755
--- a/egs/librispeech/ASR/zipformer_mmi/decode.py
+++ b/egs/librispeech/ASR/zipformer_mmi/decode.py
@@ -569,7 +569,9 @@ def main():
     if params.decoding_method == "nbest-rescoring-LG":
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
-        LG = k2.Fsa.from_dict(torch.load(lg_filename, map_location=device))
+        LG = k2.Fsa.from_dict(
+            torch.load(lg_filename, map_location=device, weights_only=False)
+        )
         LG = k2.Fsa.from_fsas([LG]).to(device)
         LG.lm_scores = LG.scores.clone()

@@ -602,7 +604,11 @@ def main():
             torch.save(G.as_dict(), params.lang_dir / f"{order}gram.pt")
         else:
             logging.info(f"Loading pre-compiled {order}gram.pt")
-            d = torch.load(params.lang_dir / f"{order}gram.pt", map_location=device)
+            d = torch.load(
+                params.lang_dir / f"{order}gram.pt",
+                map_location=device,
+                weights_only=False,
+            )
             G = k2.Fsa.from_dict(d)
             G.lm_scores = G.scores.clone()

diff --git a/egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py b/egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py
index 6990c90a0..d5667cafa 100755
--- a/egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py
+++ b/egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py
@@ -308,7 +308,9 @@ def main():
     if method == "nbest-rescoring-LG":
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
-        LG = k2.Fsa.from_dict(torch.load(lg_filename, map_location=device))
+        LG = k2.Fsa.from_dict(
+            torch.load(lg_filename, map_location=device, weights_only=False)
+        )
         LG = k2.Fsa.from_fsas([LG]).to(device)
         LG.lm_scores = LG.scores.clone()
         LM = LG
@@ -317,7 +319,9 @@ def main():
         assert order in ("3", "4")
         order = int(order)
         logging.info(f"Loading pre-compiled {order}gram.pt")
-        d = torch.load(params.lang_dir / f"{order}gram.pt", map_location=device)
+        d = torch.load(
+            params.lang_dir / f"{order}gram.pt", map_location=device, weights_only=False
+        )
         G = k2.Fsa.from_dict(d)
         G.lm_scores = G.scores.clone()
         LM = G
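Every torch.load() hunk above follows the same recipe. A minimal sketch of
the failure mode being fixed, assuming a hypothetical checkpoint file
pretrained.pt saved by one of these recipes (the path is illustrative):

    import torch

    # PyTorch >= 2.6 defaults to weights_only=True, which restricts
    # unpickling to tensors and a small allowlist of types. Checkpoints
    # and serialized k2 graphs can contain objects outside that
    # allowlist, so loading them with the new default raises an
    # UnpicklingError.
    try:
        checkpoint = torch.load("pretrained.pt", map_location="cpu")
    except Exception:
        # The file was produced by the recipe itself and is trusted,
        # so opt back in to full unpickling, as this patch does.
        checkpoint = torch.load(
            "pretrained.pt", map_location="cpu", weights_only=False
        )

On PyTorch >= 2.4 an alternative is to allowlist the offending classes via
torch.serialization.add_safe_globals(), but passing weights_only=False at
each trusted call site is the smaller change.
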
diff --git a/egs/librispeech/ASR/zipformer_mmi/pretrained.py b/egs/librispeech/ASR/zipformer_mmi/pretrained.py
index 1e7afc777..ca860b877 100755
--- a/egs/librispeech/ASR/zipformer_mmi/pretrained.py
+++ b/egs/librispeech/ASR/zipformer_mmi/pretrained.py
@@ -269,7 +269,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()
@@ -331,7 +331,9 @@ def main():
     if method == "nbest-rescoring-LG":
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
-        LG = k2.Fsa.from_dict(torch.load(lg_filename, map_location=device))
+        LG = k2.Fsa.from_dict(
+            torch.load(lg_filename, map_location=device, weights_only=False)
+        )
         LG = k2.Fsa.from_fsas([LG]).to(device)
         LG.lm_scores = LG.scores.clone()
         LM = LG
@@ -340,7 +342,9 @@ def main():
         assert order in ("3", "4")
         order = int(order)
         logging.info(f"Loading pre-compiled {order}gram.pt")
-        d = torch.load(params.lang_dir / f"{order}gram.pt", map_location=device)
+        d = torch.load(
+            params.lang_dir / f"{order}gram.pt", map_location=device, weights_only=False
+        )
         G = k2.Fsa.from_dict(d)
         G.lm_scores = G.scores.clone()
         LM = G
diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py
index e5eaba619..d923e8842 100644
--- a/icefall/diagnostics.py
+++ b/icefall/diagnostics.py
@@ -631,7 +631,10 @@ def attach_diagnostics(
         )

         module.register_forward_hook(forward_hook)
-        module.register_backward_hook(backward_hook)
+        if hasattr(module, "register_full_backward_hook"):
+            module.register_full_backward_hook(backward_hook)
+        else:
+            module.register_backward_hook(backward_hook)

         if type(module).__name__ in [
             "Sigmoid",
@@ -665,7 +668,10 @@ def attach_diagnostics(
                 _model_diagnostic[f"{_name}.scalar"].accumulate_output_grad(_output)

             module.register_forward_hook(scalar_forward_hook)
-            module.register_backward_hook(scalar_backward_hook)
+            if hasattr(module, "register_full_backward_hook"):
+                module.register_full_backward_hook(scalar_backward_hook)
+            else:
+                module.register_backward_hook(scalar_backward_hook)

     for name, parameter in model.named_parameters():
diff --git a/icefall/hooks.py b/icefall/hooks.py
index 85583acbe..b543190be 100644
--- a/icefall/hooks.py
+++ b/icefall/hooks.py
@@ -77,7 +77,11 @@ def register_inf_check_hooks(model: nn.Module) -> None:
                     logging.warning(f"The sum of {_name}.grad[{i}] is not finite")

         module.register_forward_hook(forward_hook)
-        module.register_backward_hook(backward_hook)
+
+        if hasattr(module, "register_full_backward_hook"):
+            module.register_full_backward_hook(backward_hook)
+        else:
+            module.register_backward_hook(backward_hook)

     for name, parameter in model.named_parameters():
diff --git a/icefall/transformer_lm/train.py b/icefall/transformer_lm/train.py
index c36abfcdf..acec95e94 100644
--- a/icefall/transformer_lm/train.py
+++ b/icefall/transformer_lm/train.py
@@ -50,7 +50,13 @@ from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)


 def get_parser():
@@ -341,7 +347,7 @@ def compute_validation_loss(
     for batch_idx, batch in enumerate(valid_dl):
         x, y, sentence_lengths = batch

-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 model=model,
                 x=x,
@@ -403,7 +409,7 @@ def train_one_epoch(
         params.batch_idx_train += 1
         x, y, sentence_lengths = batch
         batch_size = x.size(0)
-        with torch_autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 model=model,
                 x=x,
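
The hook changes in icefall/diagnostics.py and icefall/hooks.py share one
pattern: register_backward_hook() is deprecated and can report incorrect
gradients for modules with multiple inputs or outputs, so prefer
register_full_backward_hook() (available since PyTorch 1.8) and fall back
only on very old torch versions. A self-contained sketch of the pattern;
the function and message here are illustrative, not icefall's own:

    import torch
    import torch.nn as nn

    def attach_nan_check(module: nn.Module) -> None:
        def backward_hook(mod, grad_input, grad_output):
            # Warn if any output gradient is non-finite.
            for i, g in enumerate(grad_output):
                if g is not None and not torch.isfinite(g).all():
                    print(f"non-finite grad_output[{i}] in {type(mod).__name__}")

        if hasattr(module, "register_full_backward_hook"):
            module.register_full_backward_hook(backward_hook)
        else:  # very old torch without the full-hook API
            module.register_backward_hook(backward_hook)

    layer = nn.Linear(4, 4)
    attach_nan_check(layer)
    layer(torch.randn(2, 4)).sum().backward()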
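
The autocast hunks replace torch.cuda.amp.autocast, which is deprecated in
favor of torch.amp.autocast("cuda", ...). The patch imports torch_autocast
from icefall.utils; a minimal stand-in with the same call shape could look
like the sketch below (an assumption for illustration, not the helper's
actual implementation):

    from contextlib import nullcontext

    import torch

    def torch_autocast(enabled: bool = True, **kwargs):
        # Use the non-deprecated torch.amp.autocast on CUDA machines and
        # fall back to a no-op context so CPU-only runs do not error out.
        if torch.cuda.is_available():
            return torch.amp.autocast("cuda", enabled=enabled, **kwargs)
        return nullcontext()

    with torch_autocast(enabled=True):
        pass  # forward pass / loss computation would go here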