Fix weights_only=False

k2-fsa 2025-06-30 22:07:36 +08:00
parent a53c323750
commit f186e1d427
62 changed files with 190 additions and 93 deletions

View File

@@ -91,7 +91,7 @@ def get_matrix(min_torch_version, specified_torch_version, specified_python_version):
     matrix = []
     for p in python_version:
         for t in torch_version:
-            if min_torch_version and version_ge(min_torch_version, t):
+            if min_torch_version and version_gt(min_torch_version, t):
                 continue
             # torchaudio <= 1.13.x supports only python <= 3.10
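This one-character change flips the filter from "skip torch versions <= min" to "skip torch versions < min", so a torch release exactly equal to --min-torch-version is now kept in the build matrix. A minimal sketch of what such comparison helpers could look like, assuming they wrap packaging.version (the helper names match the hunk; the real script may implement them differently):

```python
# Hypothetical sketch of the two version-comparison helpers.
from packaging import version


def version_ge(a: str, b: str) -> bool:
    """True if version a >= version b."""
    return version.parse(a) >= version.parse(b)


def version_gt(a: str, b: str) -> bool:
    """True if version a > version b."""
    return version.parse(a) > version.parse(b)


# With the old check, version_ge("2.6.0", "2.6.0") is True, so torch 2.6.0
# itself was skipped; version_gt keeps it in the matrix.
```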

View File

@@ -17,7 +17,7 @@ concurrency:
 jobs:
   generate_build_matrix:
-    if: (github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa') && (github.event.label.name == 'ready' || github.event_name == 'push' || github.event_name == 'aishell')
+    if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa'
     # see https://github.com/pytorch/pytorch/pull/50633
     runs-on: ubuntu-latest

View File

@@ -30,7 +30,8 @@ jobs:
       run: |
         # outputting for debugging purposes
        python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10"
-        MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10" --min-torch-version "2.6.0")
+        MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10")
+        # MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10" --min-torch-version "2.6.0")
         echo "::set-output name=matrix::${MATRIX}"
   librispeech:
     needs: generate_build_matrix

View File

@@ -667,7 +667,9 @@ def main():
         H = None
         bpe_model = None
         HLG = k2.Fsa.from_dict(
-            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+            torch.load(
+                f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False
+            )
         )
         assert HLG.requires_grad is False
@@ -707,7 +709,9 @@ def main():
             torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
         else:
             logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+            d = torch.load(
+                params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False
+            )
             G = k2.Fsa.from_dict(d)
         if params.method in [
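Background for this and the many similar hunks below: PyTorch 2.6 changed the default of torch.load from weights_only=False to weights_only=True, which refuses to unpickle arbitrary Python objects. Icefall checkpoints and k2 FSA dictionaries contain such non-tensor entries, so every call site has to opt out explicitly. A minimal sketch of the pattern, with a placeholder path:

```python
import k2
import torch

device = torch.device("cpu")

# On torch >= 2.6 the default weights_only=True rejects the non-tensor
# entries stored in an FSA dict; weights_only=False restores the old
# behavior. Only use it for files you trust, since full unpickling can
# execute arbitrary code.
HLG = k2.Fsa.from_dict(
    torch.load("data/lang_bpe_500/HLG.pt", map_location=device, weights_only=False)
)
```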

View File

@@ -271,7 +271,7 @@ def main():
         use_feat_batchnorm=params.use_feat_batchnorm,
     )
-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()
@@ -351,7 +351,9 @@ def main():
         "attention-decoder",
     ]:
         logging.info(f"Loading HLG from {params.HLG}")
-        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        HLG = k2.Fsa.from_dict(
+            torch.load(params.HLG, map_location="cpu", weights_only=False)
+        )
         HLG = HLG.to(device)
         if not hasattr(HLG, "lm_scores"):
             # For whole-lattice-rescoring and attention-decoder
@@ -362,7 +364,9 @@ def main():
         "attention-decoder",
     ]:
         logging.info(f"Loading G from {params.G}")
-        G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+        G = k2.Fsa.from_dict(
+            torch.load(params.G, map_location="cpu", weights_only=False)
+        )
         # Add epsilon self-loops to G as we will compose
         # it with the whole lattice later
         G = G.to(device)

View File

@@ -774,7 +774,9 @@ def main():
         H = None
         bpe_model = None
         HLG = k2.Fsa.from_dict(
-            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+            torch.load(
+                f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False
+            )
         )
         assert HLG.requires_grad is False
@@ -814,7 +816,9 @@ def main():
             torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
         else:
             logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+            d = torch.load(
+                params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False
+            )
             G = k2.Fsa.from_dict(d)
         if params.method in [

View File

@@ -868,7 +868,9 @@ def main():
         H = None
         bpe_model = None
         HLG = k2.Fsa.from_dict(
-            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+            torch.load(
+                f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False
+            )
         )
         assert HLG.requires_grad is False
@@ -907,7 +909,9 @@ def main():
             torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
         else:
             logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+            d = torch.load(
+                params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False
+            )
             G = k2.Fsa.from_dict(d)
         if params.decoding_method == "whole-lattice-rescoring":

View File

@@ -334,7 +334,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading HLG from {params.HLG}")
-        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        HLG = k2.Fsa.from_dict(
+            torch.load(params.HLG, map_location="cpu", weights_only=False)
+        )
         HLG = HLG.to(device)
         if not hasattr(HLG, "lm_scores"):
             # For whole-lattice-rescoring and attention-decoder
@@ -345,7 +347,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading G from {params.G}")
-        G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+        G = k2.Fsa.from_dict(
+            torch.load(params.G, map_location="cpu", weights_only=False)
+        )
         G = G.to(device)
         if params.method == "whole-lattice-rescoring":
             # Add epsilon self-loops to G as we will compose

View File

@@ -290,7 +290,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()
@@ -386,7 +386,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading HLG from {params.HLG}")
-        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        HLG = k2.Fsa.from_dict(
+            torch.load(params.HLG, map_location="cpu", weights_only=False)
+        )
         HLG = HLG.to(device)
         if not hasattr(HLG, "lm_scores"):
             # For whole-lattice-rescoring and attention-decoder
@@ -397,7 +399,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading G from {params.G}")
-        G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+        G = k2.Fsa.from_dict(
+            torch.load(params.G, map_location="cpu", weights_only=False)
+        )
         G = G.to(device)
         if params.method == "whole-lattice-rescoring":
             # Add epsilon self-loops to G as we will compose

View File

@@ -574,7 +574,9 @@ def main():
         H = None
         bpe_model = None
         HLG = k2.Fsa.from_dict(
-            torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
+            torch.load(
+                f"{params.lang_dir}/HLG.pt", map_location="cpu", weights_only=False
+            )
         )
         HLG = HLG.to(device)
         assert HLG.requires_grad is False
@@ -609,7 +611,9 @@ def main():
             torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
         else:
             logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
+            d = torch.load(
+                params.lm_dir / "G_4_gram.pt", map_location="cpu", weights_only=False
+            )
             G = k2.Fsa.from_dict(d).to(device)
         if params.method in ["whole-lattice-rescoring", "attention-decoder"]:

View File

@@ -72,11 +72,11 @@ def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
     max_token_id = max(lexicon.tokens)
     logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
     H = k2.ctc_topo(max_token_id)
-    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
+    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt", weights_only=False))
     if Path(f"data/lm/{lm}.pt").is_file():
         logging.info(f"Loading pre-compiled {lm}")
-        d = torch.load(f"data/lm/{lm}.pt")
+        d = torch.load(f"data/lm/{lm}.pt", weights_only=False)
         G = k2.Fsa.from_dict(d)
     else:
         logging.info(f"Loading {lm}.fst.txt")

View File

@@ -66,11 +66,11 @@ def compile_LG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
       An FSA representing LG.
     """
     lexicon = Lexicon(lang_dir)
-    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
+    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt", weights_only=False))
     if Path(f"data/lm/{lm}.pt").is_file():
         logging.info(f"Loading pre-compiled {lm}")
-        d = torch.load(f"data/lm/{lm}.pt")
+        d = torch.load(f"data/lm/{lm}.pt", weights_only=False)
         G = k2.Fsa.from_dict(d)
     else:
         logging.info(f"Loading {lm}.fst.txt")

View File

@@ -750,7 +750,7 @@ def main():
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
         decoding_graph = k2.Fsa.from_dict(
-            torch.load(lg_filename, map_location=device)
+            torch.load(lg_filename, map_location=device, weights_only=False)
         )
         decoding_graph.scores *= params.ngram_lm_scale
     else:

View File

@@ -238,7 +238,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()

View File

@@ -935,7 +935,7 @@ def main():
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
         decoding_graph = k2.Fsa.from_dict(
-            torch.load(lg_filename, map_location=device)
+            torch.load(lg_filename, map_location=device, weights_only=False)
         )
         decoding_graph.scores *= params.ngram_lm_scale
     else:

View File

@@ -241,7 +241,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()

View File

@@ -815,7 +815,7 @@ def main():
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
         decoding_graph = k2.Fsa.from_dict(
-            torch.load(lg_filename, map_location=device)
+            torch.load(lg_filename, map_location=device, weights_only=False)
         )
         decoding_graph.scores *= params.ngram_lm_scale
     else:

View File

@@ -239,7 +239,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()

View File

@@ -741,7 +741,7 @@ def main():
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
         decoding_graph = k2.Fsa.from_dict(
-            torch.load(lg_filename, map_location=device)
+            torch.load(lg_filename, map_location=device, weights_only=False)
         )
         decoding_graph.scores *= params.ngram_lm_scale
     else:

View File

@@ -754,7 +754,7 @@ def main():
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
         decoding_graph = k2.Fsa.from_dict(
-            torch.load(lg_filename, map_location=device)
+            torch.load(lg_filename, map_location=device, weights_only=False)
         )
         decoding_graph.scores *= params.ngram_lm_scale
     else:

View File

@@ -265,7 +265,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()

View File

@@ -921,7 +921,7 @@ def load_ngram_LM(
     if pt_file.is_file():
         logging.info(f"Loading pre-compiled {pt_file}")
-        d = torch.load(pt_file, map_location=device)
+        d = torch.load(pt_file, map_location=device, weights_only=False)
         G = k2.Fsa.from_dict(d)
         G = k2.add_epsilon_self_loops(G)
         G = k2.arc_sort(G)
@@ -1101,7 +1101,7 @@ def main():
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
         decoding_graph = k2.Fsa.from_dict(
-            torch.load(lg_filename, map_location=device)
+            torch.load(lg_filename, map_location=device, weights_only=False)
        )
         decoding_graph.scores *= params.ngram_lm_scale
     elif params.decoding_method in [

View File

@@ -274,7 +274,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()

View File

@@ -913,7 +913,7 @@ def main():
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
         decoding_graph = k2.Fsa.from_dict(
-            torch.load(lg_filename, map_location=device)
+            torch.load(lg_filename, map_location=device, weights_only=False)
         )
         decoding_graph.scores *= params.ngram_lm_scale
     else:

View File

@@ -972,7 +972,7 @@ def main():
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
         decoding_graph = k2.Fsa.from_dict(
-            torch.load(lg_filename, map_location=device)
+            torch.load(lg_filename, map_location=device, weights_only=False)
         )
         decoding_graph.scores *= params.ngram_lm_scale
     else:

View File

@@ -238,7 +238,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()

View File

@@ -348,7 +348,9 @@ class CodebookIndexExtractor:
             num_codebooks=self.params.num_codebooks,
             codebook_size=256,
         )
-        quantizer.load_state_dict(torch.load(self.quantizer_file_path))
+        quantizer.load_state_dict(
+            torch.load(self.quantizer_file_path, weights_only=False)
+        )
         quantizer.to(self.params.device)
         return quantizer

View File

@@ -289,7 +289,7 @@ def main():
     logging.info("About to create model")
     model = get_transducer_model(params)
-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()

View File

@@ -910,7 +910,7 @@ def main():
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
         decoding_graph = k2.Fsa.from_dict(
-            torch.load(lg_filename, map_location=device)
+            torch.load(lg_filename, map_location=device, weights_only=False)
         )
         decoding_graph.scores *= params.ngram_lm_scale
     else:

View File

@@ -813,7 +813,7 @@ def main():
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
         decoding_graph = k2.Fsa.from_dict(
-            torch.load(lg_filename, map_location=device)
+            torch.load(lg_filename, map_location=device, weights_only=False)
         )
         decoding_graph.scores *= params.ngram_lm_scale
     else:

View File

@@ -636,7 +636,7 @@ def load_model_params(
     """
     logging.info(f"Loading checkpoint from {ckpt}")
-    checkpoint = torch.load(ckpt, map_location="cpu")
+    checkpoint = torch.load(ckpt, map_location="cpu", weights_only=False)
     # if module list is empty, load the whole model from ckpt
     if not init_modules:

View File

@@ -247,7 +247,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()

View File

@@ -365,7 +365,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading HLG from {params.HLG}")
-        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        HLG = k2.Fsa.from_dict(
+            torch.load(params.HLG, map_location="cpu", weights_only=False)
+        )
         HLG = HLG.to(device)
         if not hasattr(HLG, "lm_scores"):
             # For whole-lattice-rescoring and attention-decoder
@@ -376,7 +378,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading G from {params.G}")
-        G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+        G = k2.Fsa.from_dict(
+            torch.load(params.G, map_location="cpu", weights_only=False)
+        )
         G = G.to(device)
         if params.method == "whole-lattice-rescoring":
             # Add epsilon self-loops to G as we will compose

View File

@@ -624,7 +624,9 @@ def main():
         H = None
         bpe_model = None
         HLG = k2.Fsa.from_dict(
-            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+            torch.load(
+                f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False
+            )
         )
         assert HLG.requires_grad is False
@@ -663,7 +665,9 @@ def main():
             torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
         else:
             logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+            d = torch.load(
+                params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False
+            )
             G = k2.Fsa.from_dict(d)
         if params.decoding_method == "whole-lattice-rescoring":

View File

@@ -808,7 +808,7 @@ def main():
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
         decoding_graph = k2.Fsa.from_dict(
-            torch.load(lg_filename, map_location=device)
+            torch.load(lg_filename, map_location=device, weights_only=False)
         )
         decoding_graph.scores *= params.ngram_lm_scale
     else:

View File

@@ -786,7 +786,7 @@ def main():
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
         decoding_graph = k2.Fsa.from_dict(
-            torch.load(lg_filename, map_location=device)
+            torch.load(lg_filename, map_location=device, weights_only=False)
         )
         decoding_graph.scores *= params.ngram_lm_scale
     else:

View File

@@ -347,7 +347,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading HLG from {params.HLG}")
-        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        HLG = k2.Fsa.from_dict(
+            torch.load(params.HLG, map_location="cpu", weights_only=False)
+        )
         HLG = HLG.to(device)
         if not hasattr(HLG, "lm_scores"):
             # For whole-lattice-rescoring and attention-decoder
@@ -358,7 +360,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading G from {params.G}")
-        G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+        G = k2.Fsa.from_dict(
+            torch.load(params.G, map_location="cpu", weights_only=False)
+        )
         G = G.to(device)
         if params.method == "whole-lattice-rescoring":
             # Add epsilon self-loops to G as we will compose

View File

@@ -247,7 +247,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()

View File

@@ -286,7 +286,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()
@@ -362,7 +362,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading HLG from {params.HLG}")
-        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        HLG = k2.Fsa.from_dict(
+            torch.load(params.HLG, map_location="cpu", weights_only=False)
+        )
         HLG = HLG.to(device)
         if not hasattr(HLG, "lm_scores"):
             # For whole-lattice-rescoring and attention-decoder
@@ -373,7 +375,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading G from {params.G}")
-        G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+        G = k2.Fsa.from_dict(
+            torch.load(params.G, map_location="cpu", weights_only=False)
+        )
         G = G.to(device)
         if params.method == "whole-lattice-rescoring":
             # Add epsilon self-loops to G as we will compose

View File

@@ -768,7 +768,7 @@ def main():
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
         decoding_graph = k2.Fsa.from_dict(
-            torch.load(lg_filename, map_location=device)
+            torch.load(lg_filename, map_location=device, weights_only=False)
         )
         decoding_graph.scores *= params.ngram_lm_scale
     else:

View File

@@ -788,7 +788,7 @@ def main():
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
         decoding_graph = k2.Fsa.from_dict(
-            torch.load(lg_filename, map_location=device)
+            torch.load(lg_filename, map_location=device, weights_only=False)
         )
         decoding_graph.scores *= params.ngram_lm_scale
     else:

View File

@@ -747,7 +747,7 @@ def main():
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
         decoding_graph = k2.Fsa.from_dict(
-            torch.load(lg_filename, map_location=device)
+            torch.load(lg_filename, map_location=device, weights_only=False)
         )
         decoding_graph.scores *= params.ngram_lm_scale
     else:

View File

@@ -247,7 +247,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()

View File

@@ -398,7 +398,9 @@ def main():
     logging.info(f"device: {device}")
-    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
+    HLG = k2.Fsa.from_dict(
+        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu", weights_only=False)
+    )
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
@@ -428,7 +430,9 @@ def main():
             torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
         else:
             logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
+            d = torch.load(
+                params.lm_dir / "G_4_gram.pt", map_location="cpu", weights_only=False
+            )
             G = k2.Fsa.from_dict(d).to(device)
         if params.method == "whole-lattice-rescoring":

View File

@@ -167,13 +167,15 @@ def main():
         subsampling_factor=params.subsampling_factor,
     )
-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"])
     model.to(device)
     model.eval()
     logging.info(f"Loading HLG from {params.HLG}")
-    HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+    HLG = k2.Fsa.from_dict(
+        torch.load(params.HLG, map_location="cpu", weights_only=False)
+    )
     HLG = HLG.to(device)
     if not hasattr(HLG, "lm_scores"):
         # For whole-lattice-rescoring and attention-decoder
@@ -181,7 +183,9 @@ def main():
     if params.method == "whole-lattice-rescoring":
         logging.info(f"Loading G from {params.G}")
-        G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+        G = k2.Fsa.from_dict(
+            torch.load(params.G, map_location="cpu", weights_only=False)
+        )
         # Add epsilon self-loops to G as we will compose
         # it with the whole lattice later
         G = G.to(device)

View File

@@ -589,7 +589,9 @@ def main():
         H = None
         bpe_model = None
         HLG = k2.Fsa.from_dict(
-            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+            torch.load(
+                f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False
+            )
         )
         assert HLG.requires_grad is False
@@ -628,7 +630,9 @@ def main():
             torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
         else:
             logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+            d = torch.load(
+                params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False
+            )
             G = k2.Fsa.from_dict(d)
         if params.decoding_method == "whole-lattice-rescoring":

View File

@@ -663,7 +663,7 @@ def main():
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
         decoding_graph = k2.Fsa.from_dict(
-            torch.load(lg_filename, map_location=device)
+            torch.load(lg_filename, map_location=device, weights_only=False)
         )
         decoding_graph.scores *= params.ngram_lm_scale
     else:

View File

@@ -347,7 +347,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading HLG from {params.HLG}")
-        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        HLG = k2.Fsa.from_dict(
+            torch.load(params.HLG, map_location="cpu", weights_only=False)
+        )
         HLG = HLG.to(device)
         if not hasattr(HLG, "lm_scores"):
             # For whole-lattice-rescoring and attention-decoder
@@ -358,7 +360,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading G from {params.G}")
-        G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+        G = k2.Fsa.from_dict(
+            torch.load(params.G, map_location="cpu", weights_only=False)
+        )
         G = G.to(device)
         if params.method == "whole-lattice-rescoring":
             # Add epsilon self-loops to G as we will compose

View File

@@ -249,7 +249,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()

View File

@@ -286,7 +286,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()
@@ -365,7 +365,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading HLG from {params.HLG}")
-        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        HLG = k2.Fsa.from_dict(
+            torch.load(params.HLG, map_location="cpu", weights_only=False)
+        )
         HLG = HLG.to(device)
         if not hasattr(HLG, "lm_scores"):
             # For whole-lattice-rescoring and attention-decoder
@@ -376,7 +378,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading G from {params.G}")
-        G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+        G = k2.Fsa.from_dict(
+            torch.load(params.G, map_location="cpu", weights_only=False)
+        )
         G = G.to(device)
         if params.method == "whole-lattice-rescoring":
             # Add epsilon self-loops to G as we will compose

View File

@@ -222,7 +222,7 @@ def main():
     logging.info("Creating model")
     model = get_transducer_model(params)
-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()

View File

@@ -1005,7 +1005,7 @@ def main():
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
         decoding_graph = k2.Fsa.from_dict(
-            torch.load(lg_filename, map_location=device)
+            torch.load(lg_filename, map_location=device, weights_only=False)
         )
         decoding_graph.scores *= params.ngram_lm_scale
     else:

View File

@@ -1050,7 +1050,7 @@ def main():
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
         decoding_graph = k2.Fsa.from_dict(
-            torch.load(lg_filename, map_location=device)
+            torch.load(lg_filename, map_location=device, weights_only=False)
         )
         decoding_graph.scores *= params.ngram_lm_scale
     else:

View File

@@ -763,7 +763,7 @@ def load_model_params(
     """
     logging.info(f"Loading checkpoint from {ckpt}")
-    checkpoint = torch.load(ckpt, map_location="cpu")
+    checkpoint = torch.load(ckpt, map_location="cpu", weights_only=False)
     # if module list is empty, load the whole model from ckpt
     if not init_modules:

View File

@@ -1050,7 +1050,7 @@ def main():
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
         decoding_graph = k2.Fsa.from_dict(
-            torch.load(lg_filename, map_location=device)
+            torch.load(lg_filename, map_location=device, weights_only=False)
         )
         decoding_graph.scores *= params.ngram_lm_scale
     else:

View File

@@ -776,7 +776,7 @@ def load_model_params(
     """
    logging.info(f"Loading checkpoint from {ckpt}")
-    checkpoint = torch.load(ckpt, map_location="cpu")
+    checkpoint = torch.load(ckpt, map_location="cpu", weights_only=False)
     # if module list is empty, load the whole model from ckpt
     if not init_modules:

View File

@@ -569,7 +569,9 @@ def main():
     if params.decoding_method == "nbest-rescoring-LG":
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
-        LG = k2.Fsa.from_dict(torch.load(lg_filename, map_location=device))
+        LG = k2.Fsa.from_dict(
+            torch.load(lg_filename, map_location=device, weights_only=False)
+        )
         LG = k2.Fsa.from_fsas([LG]).to(device)
         LG.lm_scores = LG.scores.clone()
@@ -602,7 +604,11 @@ def main():
             torch.save(G.as_dict(), params.lang_dir / f"{order}gram.pt")
         else:
             logging.info(f"Loading pre-compiled {order}gram.pt")
-            d = torch.load(params.lang_dir / f"{order}gram.pt", map_location=device)
+            d = torch.load(
+                params.lang_dir / f"{order}gram.pt",
+                map_location=device,
+                weights_only=False,
+            )
             G = k2.Fsa.from_dict(d)
             G.lm_scores = G.scores.clone()

View File

@@ -308,7 +308,9 @@ def main():
     if method == "nbest-rescoring-LG":
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
-        LG = k2.Fsa.from_dict(torch.load(lg_filename, map_location=device))
+        LG = k2.Fsa.from_dict(
+            torch.load(lg_filename, map_location=device, weights_only=False)
+        )
         LG = k2.Fsa.from_fsas([LG]).to(device)
         LG.lm_scores = LG.scores.clone()
         LM = LG
@@ -317,7 +319,9 @@ def main():
         assert order in ("3", "4")
         order = int(order)
         logging.info(f"Loading pre-compiled {order}gram.pt")
-        d = torch.load(params.lang_dir / f"{order}gram.pt", map_location=device)
+        d = torch.load(
+            params.lang_dir / f"{order}gram.pt", map_location=device, weights_only=False
+        )
         G = k2.Fsa.from_dict(d)
         G.lm_scores = G.scores.clone()
         LM = G

View File

@@ -269,7 +269,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()
@@ -331,7 +331,9 @@ def main():
     if method == "nbest-rescoring-LG":
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
-        LG = k2.Fsa.from_dict(torch.load(lg_filename, map_location=device))
+        LG = k2.Fsa.from_dict(
+            torch.load(lg_filename, map_location=device, weights_only=False)
+        )
         LG = k2.Fsa.from_fsas([LG]).to(device)
         LG.lm_scores = LG.scores.clone()
         LM = LG
@@ -340,7 +342,9 @@ def main():
         assert order in ("3", "4")
         order = int(order)
         logging.info(f"Loading pre-compiled {order}gram.pt")
-        d = torch.load(params.lang_dir / f"{order}gram.pt", map_location=device)
+        d = torch.load(
+            params.lang_dir / f"{order}gram.pt", map_location=device, weights_only=False
+        )
         G = k2.Fsa.from_dict(d)
         G.lm_scores = G.scores.clone()
         LM = G

View File

@@ -631,7 +631,10 @@ def attach_diagnostics(
             )
         module.register_forward_hook(forward_hook)
-        module.register_backward_hook(backward_hook)
+        if hasattr(module, "register_full_backward_hook"):
+            module.register_full_backward_hook(backward_hook)
+        else:
+            module.register_backward_hook(backward_hook)
         if type(module).__name__ in [
             "Sigmoid",
@@ -665,7 +668,10 @@ def attach_diagnostics(
            _model_diagnostic[f"{_name}.scalar"].accumulate_output_grad(_output)
         module.register_forward_hook(scalar_forward_hook)
-        module.register_backward_hook(scalar_backward_hook)
+        if hasattr(module, "register_full_backward_hook"):
+            module.register_full_backward_hook(scalar_backward_hook)
+        else:
+            module.register_backward_hook(scalar_backward_hook)
     for name, parameter in model.named_parameters():

View File

@@ -77,7 +77,11 @@ def register_inf_check_hooks(model: nn.Module) -> None:
                     logging.warning(f"The sum of {_name}.grad[{i}] is not finite")
         module.register_forward_hook(forward_hook)
-        module.register_backward_hook(backward_hook)
+        if hasattr(module, "register_full_backward_hook"):
+            module.register_full_backward_hook(backward_hook)
+        else:
+            module.register_backward_hook(backward_hook)
     for name, parameter in model.named_parameters():
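These hook hunks are a separate compatibility fix in the same commit: register_backward_hook is deprecated and can report incorrect gradients for modules with multiple inputs, while register_full_backward_hook (available since torch 1.8) is the supported replacement. The commit feature-detects it so old torch versions still work. A self-contained sketch of the pattern; the module and hook names here are illustrative:

```python
import torch
import torch.nn as nn


def backward_hook(module, grad_input, grad_output):
    # Inspect gradients flowing out of the module during backward().
    print(f"{module.__class__.__name__}: {len(grad_output)} output grad(s)")


module = nn.Linear(4, 4)

# Prefer the full backward hook when the running torch provides it;
# fall back to the deprecated API on older releases.
if hasattr(module, "register_full_backward_hook"):
    module.register_full_backward_hook(backward_hook)
else:
    module.register_backward_hook(backward_hook)

module(torch.randn(2, 4, requires_grad=True)).sum().backward()
```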

View File

@@ -50,7 +50,13 @@ from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
 def get_parser():
@@ -341,7 +347,7 @@ def compute_validation_loss(
     for batch_idx, batch in enumerate(valid_dl):
         x, y, sentence_lengths = batch
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 model=model,
                 x=x,
@@ -403,7 +409,7 @@ def train_one_epoch(
         params.batch_idx_train += 1
         x, y, sentence_lengths = batch
         batch_size = x.size(0)
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 model=model,
                 x=x,
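The last two hunks route autocast through a torch_autocast helper imported from icefall.utils: torch.cuda.amp.autocast(...) is deprecated in recent PyTorch in favor of torch.amp.autocast("cuda", ...), and a thin wrapper lets call sites keep one spelling across torch versions. A hypothetical sketch of such a wrapper; the actual helper in icefall.utils may differ in details:

```python
import torch


def torch_autocast(enabled: bool = True, **kwargs):
    """Return an autocast context manager that works across torch versions.

    torch.cuda.amp.autocast(...) emits a deprecation warning on newer
    torch; the replacement spelling is torch.amp.autocast("cuda", ...).
    """
    if hasattr(torch, "amp") and hasattr(torch.amp, "autocast"):
        return torch.amp.autocast(device_type="cuda", enabled=enabled, **kwargs)
    # Older torch without torch.amp.autocast.
    return torch.cuda.amp.autocast(enabled=enabled, **kwargs)
```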