add weights_only=False to torch.load (#1984)

Teo Wen Shen 2025-07-10 16:27:08 +09:00 committed by GitHub
parent 89728dd4f8
commit da87e7fc99
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
141 changed files with 205 additions and 205 deletions
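Background for this change: recent PyTorch releases (2.6+) default `torch.load` to `weights_only=True`, which refuses to unpickle arbitrary Python objects. The checkpoints and graphs used by these recipes (full training checkpoints, k2 FSA dictionaries such as `HLG.pt`, `LG.pt`, and `L.pt`) contain such objects, so every call site now passes `weights_only=False` explicitly. The snippet below is a minimal sketch of the pattern applied throughout the diff; the file paths are placeholders, not files from this repository.

```python
import torch
import k2

# Full training checkpoints store more than tensors (e.g. optimizer state),
# so they need weights_only=False under the newer torch.load default.
checkpoint = torch.load("exp/epoch-30.pt", map_location="cpu", weights_only=False)
state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint

# k2 graphs are saved as dicts and restored with k2.Fsa.from_dict after
# loading the raw dict; they likewise require weights_only=False.
HLG = k2.Fsa.from_dict(
    torch.load("data/lang/HLG.pt", map_location="cpu", weights_only=False)
)
```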

View File

@@ -41,7 +41,7 @@ To give you an idea of what ``tdnn/exp/pretrained.pt`` contains, we can use the
 .. code-block:: python3
 >>> import torch
->>> m = torch.load("tdnn/exp/pretrained.pt")
+>>> m = torch.load("tdnn/exp/pretrained.pt", weights_only=False)
 >>> list(m.keys())
 ['model']
 >>> list(m["model"].keys())

View File

@@ -28,7 +28,7 @@ consisting of words and tokens (i.e., phones) and does the following:
 4. Generate L.pt, in k2 format. It can be loaded by
-d = torch.load("L.pt")
+d = torch.load("L.pt", weights_only=False)
 lexicon = k2.Fsa.from_dict(d)
 5. Generate L_disambig.pt, in k2 format.

View File

@@ -224,7 +224,7 @@ def main():
 logging.info("Creating model")
 model = get_transducer_model(params)
-checkpoint = torch.load(args.checkpoint, map_location="cpu")
+checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
 model.load_state_dict(checkpoint["model"], strict=False)
 model.to(device)
 model.eval()

View File

@@ -503,7 +503,7 @@ def main():
 else:
 H = None
 HLG = k2.Fsa.from_dict(
-torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False)
 )
 assert HLG.requires_grad is False

View File

@@ -249,7 +249,7 @@ def main():
 use_feat_batchnorm=params.use_feat_batchnorm,
 )
-checkpoint = torch.load(args.checkpoint, map_location="cpu")
+checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
 model.load_state_dict(checkpoint["model"], strict=False)
 model.to(device)
 model.eval()
@@ -315,7 +315,7 @@ def main():
 hyps = [[token_sym_table[i] for i in ids] for ids in token_ids]
 elif params.method in ["1best", "attention-decoder"]:
 logging.info(f"Loading HLG from {params.HLG}")
-HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu", weights_only=False))
 HLG = HLG.to(device)
 if not hasattr(HLG, "lm_scores"):
 # For whole-lattice-rescoring and attention-decoder

View File

@@ -516,7 +516,7 @@ def main():
 else:
 H = None
 HLG = k2.Fsa.from_dict(
-torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False)
 )
 assert HLG.requires_grad is False

View File

@@ -28,7 +28,7 @@ consisting of words and tokens (i.e., phones) and does the following:
 4. Generate L.pt, in k2 format. It can be loaded by
-d = torch.load("L.pt")
+d = torch.load("L.pt", weights_only=False)
 lexicon = k2.Fsa.from_dict(d)
 5. Generate L_disambig.pt, in k2 format.

View File

@@ -227,7 +227,7 @@ def main():
 logging.info("About to create model")
 model = get_transducer_model(params)
-checkpoint = torch.load(args.checkpoint, map_location="cpu")
+checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
 model.load_state_dict(checkpoint["model"], strict=False)
 model.to(device)
 model.eval()

View File

@@ -228,7 +228,7 @@ def main():
 logging.info("About to create model")
 model = get_transducer_model(params)
-checkpoint = torch.load(args.checkpoint, map_location="cpu")
+checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
 model.load_state_dict(checkpoint["model"], strict=False)
 model.to(device)
 model.eval()

View File

@@ -773,7 +773,7 @@ def main():
 lg_filename = params.lang_dir / "LG.pt"
 logging.info(f"Loading {lg_filename}")
 decoding_graph = k2.Fsa.from_dict(
-torch.load(lg_filename, map_location=device)
+torch.load(lg_filename, map_location=device, weights_only=False)
 )
 decoding_graph.scores *= params.ngram_lm_scale
 else:

View File

@@ -237,7 +237,7 @@ def main():
 num_param = sum([p.numel() for p in model.parameters()])
 logging.info(f"Number of model parameters: {num_param}")
-checkpoint = torch.load(args.checkpoint, map_location="cpu")
+checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
 model.load_state_dict(checkpoint["model"], strict=False)
 model.to(device)
 model.eval()

View File

@@ -337,7 +337,7 @@ def main():
 logging.info(f"device: {device}")
-HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
+HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu", weights_only=False))
 HLG = HLG.to(device)
 assert HLG.requires_grad is False

View File

@@ -139,13 +139,13 @@ def main():
 subsampling_factor=params.subsampling_factor,
 )
-checkpoint = torch.load(args.checkpoint, map_location="cpu")
+checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
 model.load_state_dict(checkpoint["model"])
 model.to(device)
 model.eval()
 logging.info(f"Loading HLG from {params.HLG}")
-HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu", weights_only=False))
 HLG = HLG.to(device)
 if not hasattr(HLG, "lm_scores"):
 # For whole-lattice-rescoring and attention-decoder

View File

@@ -245,7 +245,7 @@ def main():
 logging.info("Creating model")
 model = get_transducer_model(params)
-checkpoint = torch.load(args.checkpoint, map_location="cpu")
+checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
 model.load_state_dict(checkpoint["model"], strict=False)
 model.to(device)
 model.eval()

View File

@@ -225,7 +225,7 @@ def main():
 logging.info("About to create model")
 model = get_transducer_model(params)
-checkpoint = torch.load(args.checkpoint, map_location="cpu")
+checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
 model.load_state_dict(checkpoint["model"], strict=False)
 model.to(device)
 model.eval()

View File

@@ -225,7 +225,7 @@ def main():
 logging.info("About to create model")
 model = get_transducer_model(params)
-checkpoint = torch.load(args.checkpoint, map_location="cpu")
+checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
 model.load_state_dict(checkpoint["model"])
 model.to(device)
 model.eval()

View File

@@ -89,10 +89,10 @@ def average_checkpoints(
 """
 n = len(filenames)
-if "model" in torch.load(filenames[0], map_location=device):
-avg = torch.load(filenames[0], map_location=device)["model"]
+if "model" in torch.load(filenames[0], map_location=device, weights_only=False):
+avg = torch.load(filenames[0], map_location=device, weights_only=False)["model"]
 else:
-avg = torch.load(filenames[0], map_location=device)
+avg = torch.load(filenames[0], map_location=device, weights_only=False)
 # Identify shared parameters. Two parameters are said to be shared
 # if they have the same data_ptr
@@ -107,10 +107,10 @@ def average_checkpoints(
 uniqued_names = list(uniqued.values())
 for i in range(1, n):
-if "model" in torch.load(filenames[i], map_location=device):
-state_dict = torch.load(filenames[i], map_location=device)["model"]
+if "model" in torch.load(filenames[i], map_location=device, weights_only=False):
+state_dict = torch.load(filenames[i], map_location=device, weights_only=False)["model"]
 else:
-state_dict = torch.load(filenames[i], map_location=device)
+state_dict = torch.load(filenames[i], map_location=device, weights_only=False)
 for k in uniqued_names:
 avg[k] += state_dict[k]
@@ -440,7 +440,7 @@ def main():
 start = params.epoch - params.avg
 assert start >= 1, start
 checkpoint = torch.load(
-f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
+f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu", weights_only=False
 )
 if "model" not in checkpoint:
 # deepspeed converted checkpoint only contains model state_dict
@@ -469,7 +469,7 @@ def main():
 torch.save(model.state_dict(), filename)
 else:
 checkpoint = torch.load(
-f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
+f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu", weights_only=False
 )
 if "model" not in checkpoint:
 model.load_state_dict(checkpoint, strict=True)

View File

@@ -761,7 +761,7 @@ def main():
 lg_filename = params.lang_dir / "LG.pt"
 logging.info(f"Loading {lg_filename}")
 decoding_graph = k2.Fsa.from_dict(
-torch.load(lg_filename, map_location=device)
+torch.load(lg_filename, map_location=device, weights_only=False)
 )
 decoding_graph.scores *= params.ngram_lm_scale
 else:

View File

@@ -783,7 +783,7 @@ def main():
 lg_filename = params.lang_dir / "LG.pt"
 logging.info(f"Loading {lg_filename}")
 decoding_graph = k2.Fsa.from_dict(
-torch.load(lg_filename, map_location=device)
+torch.load(lg_filename, map_location=device, weights_only=False)
 )
 decoding_graph.scores *= params.ngram_lm_scale
 else:

View File

@@ -298,7 +298,7 @@ def main():
 num_param = sum([p.numel() for p in model.parameters()])
 logging.info(f"Number of model parameters: {num_param}")
-checkpoint = torch.load(args.checkpoint, map_location="cpu")
+checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
 model.load_state_dict(checkpoint["model"], strict=False)
 model.to(device)
 model.eval()

View File

@@ -728,7 +728,7 @@ def main():
 lg_filename = params.lang_dir / "LG.pt"
 logging.info(f"Loading {lg_filename}")
 decoding_graph = k2.Fsa.from_dict(
-torch.load(lg_filename, map_location=device)
+torch.load(lg_filename, map_location=device, weights_only=False)
 )
 decoding_graph.scores *= params.ngram_lm_scale
 else:

View File

@@ -226,7 +226,7 @@ def main():
 num_param = sum([p.numel() for p in model.parameters()])
 logging.info(f"Number of model parameters: {num_param}")
-checkpoint = torch.load(args.checkpoint, map_location="cpu")
+checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
 model.load_state_dict(checkpoint["model"], strict=False)
 model.to(device)
 model.eval()

View File

@@ -28,7 +28,7 @@ consisting of words and tokens (i.e., phones) and does the following:
 4. Generate L.pt, in k2 format. It can be loaded by
-d = torch.load("L.pt")
+d = torch.load("L.pt", weights_only=False)
 lexicon = k2.Fsa.from_dict(d)
 5. Generate L_disambig.pt, in k2 format.

View File

@@ -238,7 +238,7 @@ def main():
 num_param = sum([p.numel() for p in model.parameters()])
 logging.info(f"Number of model parameters: {num_param}")
-checkpoint = torch.load(args.checkpoint, map_location="cpu")
+checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
 model.load_state_dict(checkpoint["model"], strict=False)
 model.to(device)
 model.eval()

View File

@@ -28,7 +28,7 @@ consisting of words and tokens (i.e., phones) and does the following:
 4. Generate L.pt, in k2 format. It can be loaded by
-d = torch.load("L.pt")
+d = torch.load("L.pt", weights_only=False)
 lexicon = k2.Fsa.from_dict(d)
 5. Generate L_disambig.pt, in k2 format.

View File

@@ -224,7 +224,7 @@ def main():
 logging.info("Creating model")
 model = get_transducer_model(params)
-checkpoint = torch.load(args.checkpoint, map_location="cpu")
+checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
 model.load_state_dict(checkpoint["model"], strict=False)
 model.to(device)
 model.eval()

View File

@@ -672,7 +672,7 @@ def main():
 lg_filename = params.lang_dir / "LG.pt"
 logging.info(f"Loading {lg_filename}")
 decoding_graph = k2.Fsa.from_dict(
-torch.load(lg_filename, map_location=device)
+torch.load(lg_filename, map_location=device, weights_only=False)
 )
 decoding_graph.scores *= params.ngram_lm_scale
 else:

View File

@@ -1263,7 +1263,7 @@ def run(rank, world_size, args):
 logging.info(
 f"Initializing model with checkpoint from {params.model_init_ckpt}"
 )
-init_ckpt = torch.load(params.model_init_ckpt, map_location=device)
+init_ckpt = torch.load(params.model_init_ckpt, map_location=device, weights_only=False)
 model.load_state_dict(init_ckpt["model"], strict=False)
 if world_size > 1:

View File

@@ -1254,7 +1254,7 @@ def run(rank, world_size, args):
 logging.info(
 f"Initializing model with checkpoint from {params.model_init_ckpt}"
 )
-init_ckpt = torch.load(params.model_init_ckpt, map_location=device)
+init_ckpt = torch.load(params.model_init_ckpt, map_location=device, weights_only=False)
 model.load_state_dict(init_ckpt["model"], strict=False)
 if world_size > 1:

View File

@@ -141,7 +141,7 @@ def main():
 num_param = sum([p.numel() for p in model.parameters()])
 logging.info(f"Number of model parameters: {num_param}")
-checkpoint = torch.load(args.checkpoint, map_location="cpu")
+checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
 model.load_state_dict(checkpoint["model"], strict=False)
 model.to(device)
 model.eval()

View File

@@ -115,7 +115,7 @@ def load_vocoder(checkpoint_path: Path) -> nn.Module:
 hifigan = HiFiGAN(h).to("cpu")
 hifigan.load_state_dict(
-torch.load(checkpoint_path, map_location="cpu")["generator"]
+torch.load(checkpoint_path, map_location="cpu", weights_only=False)["generator"]
 )
 _ = hifigan.eval()
 hifigan.remove_weight_norm()

View File

@@ -73,11 +73,11 @@ def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
 max_token_id = max(lexicon.tokens)
 logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
 H = k2.ctc_topo(max_token_id)
-L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
+L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt", weights_only=False))
 if Path(f"{lang_dir}/lm/{lm}.pt").is_file():
 logging.info(f"Loading pre-compiled {lm}")
-d = torch.load(f"{lang_dir}/lm/{lm}.pt")
+d = torch.load(f"{lang_dir}/lm/{lm}.pt", weights_only=False)
 G = k2.Fsa.from_dict(d)
 else:
 logging.info(f"Loading {lm}.fst.txt")

View File

@@ -68,11 +68,11 @@ def compile_LG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
 An FSA representing LG.
 """
 lexicon = Lexicon(lang_dir)
-L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
+L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt", weights_only=False))
 if Path(f"{lang_dir}/lm/{lm}.pt").is_file():
 logging.info(f"Loading pre-compiled {lm}")
-d = torch.load(f"{lang_dir}/lm/{lm}.pt")
+d = torch.load(f"{lang_dir}/lm/{lm}.pt", weights_only=False)
 G = k2.Fsa.from_dict(d)
 else:
 logging.info(f"Loading {lm}.fst.txt")

View File

@@ -910,7 +910,7 @@ def main():
 lg_filename = params.lang_dir / "LG.pt"
 logging.info(f"Loading {lg_filename}")
 decoding_graph = k2.Fsa.from_dict(
-torch.load(lg_filename, map_location=device)
+torch.load(lg_filename, map_location=device, weights_only=False)
 )
 decoding_graph.scores *= params.ngram_lm_scale
 else:

View File

@@ -247,7 +247,7 @@ def main():
 num_param = sum([p.numel() for p in model.parameters()])
 logging.info(f"Number of model parameters: {num_param}")
-checkpoint = torch.load(args.checkpoint, map_location="cpu")
+checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
 model.load_state_dict(checkpoint["model"], strict=False)
 model.to(device)
 model.eval()

View File

@@ -767,7 +767,7 @@ def main():
 lg_filename = params.lang_dir / "LG.pt"
 logging.info(f"Loading {lg_filename}")
 decoding_graph = k2.Fsa.from_dict(
-torch.load(lg_filename, map_location=device)
+torch.load(lg_filename, map_location=device, weights_only=False)
 )
 decoding_graph.scores *= params.ngram_lm_scale
 else:

View File

@@ -627,7 +627,7 @@ def load_model_params(
 """
 logging.info(f"Loading checkpoint from {ckpt}")
-checkpoint = torch.load(ckpt, map_location="cpu")
+checkpoint = torch.load(ckpt, map_location="cpu", weights_only=False)
 # if module list is empty, load the whole model from ckpt
 if not init_modules:

View File

@@ -25,7 +25,7 @@ Usage:
 --exp-dir ./pruned_transducer_stateless7/exp
 It will generate a file `epoch-28-avg-15-use-averaged-model.pt` in the given `exp_dir`.
-You can later load it by `torch.load("epoch-28-avg-15-use-averaged-model.pt")`.
+You can later load it by `torch.load("epoch-28-avg-15-use-averaged-model.pt", weights_only=False)`.
 (2) use the averaged model with checkpoint exp_dir/checkpoint-iter.pt
 ./pruned_transducer_stateless7/generate_model_from_checkpoint.py \
@@ -35,7 +35,7 @@ You can later load it by `torch.load("epoch-28-avg-15-use-averaged-model.pt")`.
 --exp-dir ./pruned_transducer_stateless7/exp
 It will generate a file `iter-22000-avg-5-use-averaged-model.pt` in the given `exp_dir`.
-You can later load it by `torch.load("iter-22000-avg-5-use-averaged-model.pt")`.
+You can later load it by `torch.load("iter-22000-avg-5-use-averaged-model.pt", weights_only=False)`.
 (3) use the original model with checkpoint exp_dir/epoch-xxx.pt
 ./pruned_transducer_stateless7/generate_model_from_checkpoint.py \
@@ -45,7 +45,7 @@ You can later load it by `torch.load("iter-22000-avg-5-use-averaged-model.pt")`.
 --exp-dir ./pruned_transducer_stateless7/exp
 It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`.
-You can later load it by `torch.load("epoch-28-avg-15.pt")`.
+You can later load it by `torch.load("epoch-28-avg-15.pt", weights_only=False)`.
 (4) use the original model with checkpoint exp_dir/checkpoint-iter.pt
 ./pruned_transducer_stateless7/generate_model_from_checkpoint.py \
@@ -55,7 +55,7 @@ You can later load it by `torch.load("epoch-28-avg-15.pt")`.
 --exp-dir ./pruned_transducer_stateless7/exp
 It will generate a file `iter-22000-avg-5.pt` in the given `exp_dir`.
-You can later load it by `torch.load("iter-22000-avg-5.pt")`.
+You can later load it by `torch.load("iter-22000-avg-5.pt", weights_only=False)`.
 """

View File

@@ -987,7 +987,7 @@ def main():
 lg_filename = params.lang_dir / "LG.pt"
 logging.info(f"Loading {lg_filename}")
 decoding_graph = k2.Fsa.from_dict(
-torch.load(lg_filename, map_location=device)
+torch.load(lg_filename, map_location=device, weights_only=False)
 )
 decoding_graph.scores *= params.ngram_lm_scale
 else:

View File

@@ -756,7 +756,7 @@ def main():
 lg_filename = params.lang_dir / "LG.pt"
 logging.info(f"Loading {lg_filename}")
 decoding_graph = k2.Fsa.from_dict(
-torch.load(lg_filename, map_location=device)
+torch.load(lg_filename, map_location=device, weights_only=False)
 )
 decoding_graph.scores *= params.ngram_lm_scale
 else:

View File

@@ -791,7 +791,7 @@ def main():
 if params.decoding_graph:
 decoding_graph = k2.Fsa.from_dict(
-torch.load(params.decoding_graph, map_location=device)
+torch.load(params.decoding_graph, map_location=device, weights_only=False)
 )
 elif "fast_beam_search" in params.decoding_method:
 if params.decoding_method == "fast_beam_search_nbest_LG":
@@ -800,7 +800,7 @@ def main():
 lg_filename = params.lang_dir / "LG.pt"
 logging.info(f"Loading {lg_filename}")
 decoding_graph = k2.Fsa.from_dict(
-torch.load(lg_filename, map_location=device)
+torch.load(lg_filename, map_location=device, weights_only=False)
 )
 decoding_graph.scores *= params.ngram_lm_scale
 else:

View File

@@ -239,7 +239,7 @@ def main():
 num_param = sum([p.numel() for p in model.parameters()])
 logging.info(f"Number of model parameters: {num_param}")
-checkpoint = torch.load(args.checkpoint, map_location="cpu")
+checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
 model.load_state_dict(checkpoint["model"], strict=False)
 model.to(device)
 model.eval()

View File

@@ -561,7 +561,7 @@ def main():
 decoding_graph = None
 if params.decoding_graph:
 decoding_graph = k2.Fsa.from_dict(
-torch.load(params.decoding_graph, map_location=device)
+torch.load(params.decoding_graph, map_location=device, weights_only=False)
 )
 elif params.decoding_method == "fast_beam_search":
 decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)

View File

@@ -47,7 +47,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
 max_token_id = max(lexicon.tokens)
 logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
 H = k2.ctc_topo(max_token_id)
-L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
+L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt", weights_only=False))
 logging.info("Loading G.fst.txt")
 with open(lang_dir / "G.fst.txt") as f:

View File

@@ -14,7 +14,7 @@ consisting of words and tokens (i.e., phones) and does the following:
 4. Generate L.pt, in k2 format. It can be loaded by
-d = torch.load("L.pt")
+d = torch.load("L.pt", weights_only=False)
 lexicon = k2.Fsa.from_dict(d)
 5. Generate L_disambig.pt, in k2 format.

View File

@@ -589,7 +589,7 @@ def main():
 H = None
 bpe_model = None
 HLG = k2.Fsa.from_dict(
-torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False)
 )
 assert HLG.requires_grad is False
@@ -628,7 +628,7 @@ def main():
 torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
 else:
 logging.info("Loading pre-compiled G_4_gram.pt")
-d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False)
 G = k2.Fsa.from_dict(d)
 if params.method in ["whole-lattice-rescoring", "attention-decoder"]:

View File

@@ -668,7 +668,7 @@ def main():
 H = None
 bpe_model = None
 HLG = k2.Fsa.from_dict(
-torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False)
 )
 assert HLG.requires_grad is False
@@ -707,7 +707,7 @@ def main():
 torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
 else:
 logging.info("Loading pre-compiled G_4_gram.pt")
-d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False)
 G = k2.Fsa.from_dict(d)
 if params.decoding_method == "whole-lattice-rescoring":

View File

@@ -1000,7 +1000,7 @@ def main():
 lg_filename = params.lang_dir / "LG.pt"
 logging.info(f"Loading {lg_filename}")
 decoding_graph = k2.Fsa.from_dict(
-torch.load(lg_filename, map_location=device)
+torch.load(lg_filename, map_location=device, weights_only=False)
 )
 decoding_graph.scores *= params.ngram_lm_scale
 else:

View File

@@ -1001,7 +1001,7 @@ def main():
 lg_filename = params.lang_dir / "LG.pt"
 logging.info(f"Loading {lg_filename}")
 decoding_graph = k2.Fsa.from_dict(
-torch.load(lg_filename, map_location=device)
+torch.load(lg_filename, map_location=device, weights_only=False)
 )
 decoding_graph.scores *= params.ngram_lm_scale
 else:

View File

@@ -183,7 +183,7 @@ def load_model_params(
 """
 logging.info(f"Loading checkpoint from {ckpt}")
-checkpoint = torch.load(ckpt, map_location="cpu")
+checkpoint = torch.load(ckpt, map_location="cpu", weights_only=False)
 # if module list is empty, load the whole model from ckpt
 if not init_modules:

View File

@@ -938,7 +938,7 @@ def main():
 lg_filename = params.lang_dir / "LG.pt"
 logging.info(f"Loading {lg_filename}")
 decoding_graph = k2.Fsa.from_dict(
-torch.load(lg_filename, map_location=device)
+torch.load(lg_filename, map_location=device, weights_only=False)
 )
 decoding_graph.scores *= params.ngram_lm_scale
 else:

View File

@@ -666,7 +666,7 @@ def main():
 H = None
 bpe_model = None
 HLG = k2.Fsa.from_dict(
-torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False)
 )
 assert HLG.requires_grad is False
@@ -705,7 +705,7 @@ def main():
 torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
 else:
 logging.info("Loading pre-compiled G_4_gram.pt")
-d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False)
 G = k2.Fsa.from_dict(d)
 if params.decoding_method == "whole-lattice-rescoring":

View File

@@ -989,7 +989,7 @@ def main():
 lg_filename = params.lang_dir / "LG.pt"
 logging.info(f"Loading {lg_filename}")
 decoding_graph = k2.Fsa.from_dict(
-torch.load(lg_filename, map_location=device)
+torch.load(lg_filename, map_location=device, weights_only=False)
 )
 decoding_graph.scores *= params.ngram_lm_scale
 else:

View File

@@ -177,7 +177,7 @@ def main():
 num_param = sum([p.numel() for p in model.parameters()])
 logging.info(f"Number of model parameters: {num_param}")
-checkpoint = torch.load(args.checkpoint, map_location="cpu")
+checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
 model.load_state_dict(checkpoint["model"], strict=False)
 model.to(device)
 model.eval()

View File

@@ -1286,7 +1286,7 @@ def run(rank, world_size, args):
 logging.info(
 f"Initializing model with checkpoint from {params.model_init_ckpt}"
 )
-init_ckpt = torch.load(params.model_init_ckpt, map_location=device)
+init_ckpt = torch.load(params.model_init_ckpt, map_location=device, weights_only=False)
 model.load_state_dict(init_ckpt["model"], strict=False)
 if world_size > 1:

View File

@@ -1175,7 +1175,7 @@ def run(rank, world_size, args):
 logging.info(
 f"Initializing model with checkpoint from {params.model_init_ckpt}"
 )
-init_ckpt = torch.load(params.model_init_ckpt, map_location=device)
+init_ckpt = torch.load(params.model_init_ckpt, map_location=device, weights_only=False)
 model.load_state_dict(init_ckpt["model"], strict=True)
 if world_size > 1:

View File

@@ -252,7 +252,7 @@ def main():
 num_param = sum([p.numel() for p in model.parameters()])
 logging.info(f"Number of model parameters: {num_param}")
-checkpoint = torch.load(args.checkpoint, map_location="cpu")
+checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
 model.load_state_dict(checkpoint["model"], strict=False)
 model.to(device)
 model.eval()

View File

@@ -960,7 +960,7 @@ def main():
 lg_filename = params.lang_dir / "LG.pt"
 logging.info(f"Loading {lg_filename}")
 decoding_graph = k2.Fsa.from_dict(
-torch.load(lg_filename, map_location=device)
+torch.load(lg_filename, map_location=device, weights_only=False)
 )
 decoding_graph.scores *= params.ngram_lm_scale
 else:

View File

@@ -750,7 +750,7 @@ def _to_int_tuple(s: str):
 def get_encoder_model(params: AttributeDict) -> nn.Module:
 if hasattr(params, "pretrained_dir"):
 logging.info(f"Loading {params.pretrained_dir}")
-pretrained = torch.load(params.pretrained_dir)
+pretrained = torch.load(params.pretrained_dir, weights_only=False)
 encoder = HubertModel(params)
 encoder.load_state_dict(pretrained["model"])
 else:

View File

@@ -28,7 +28,7 @@ consisting of words and tokens (i.e., phones) and does the following:
 4. Generate L.pt, in k2 format. It can be loaded by
-d = torch.load("L.pt")
+d = torch.load("L.pt", weights_only=False)
 lexicon = k2.Fsa.from_dict(d)
 5. Generate L_disambig.pt, in k2 format.

View File

@@ -264,7 +264,7 @@ def main():
 num_param = sum([p.numel() for p in model.parameters()])
 logging.info(f"Number of model parameters: {num_param}")
-checkpoint = torch.load(args.checkpoint, map_location="cpu")
+checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
 model.load_state_dict(checkpoint["model"], strict=False)
 model.to(device)
 model.eval()

View File

@@ -234,7 +234,7 @@ def main():
 logging.info("Creating model")
 model = get_transducer_model(params)
-checkpoint = torch.load(args.checkpoint, map_location="cpu")
+checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
 model.load_state_dict(checkpoint["model"], strict=False)
 model.to(device)
 model.eval()

View File

@@ -234,7 +234,7 @@ def main():
 logging.info("Creating model")
 model = get_transducer_model(params)
-checkpoint = torch.load(args.checkpoint, map_location="cpu")
+checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
 model.load_state_dict(checkpoint["model"], strict=False)
 model.to(device)
 model.eval()

View File

@@ -234,7 +234,7 @@ def main():
 logging.info("Creating model")
 model = get_transducer_model(params)
-checkpoint = torch.load(args.checkpoint, map_location="cpu")
+checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
 model.load_state_dict(checkpoint["model"], strict=False)
 model.to(device)
 model.eval()

View File

@@ -962,7 +962,7 @@ def main():
 lg_filename = params.lang_dir / "LG.pt"
 logging.info(f"Loading {lg_filename}")
 decoding_graph = k2.Fsa.from_dict(
-torch.load(lg_filename, map_location=device)
+torch.load(lg_filename, map_location=device, weights_only=False)
 )
 decoding_graph.scores *= params.ngram_lm_scale
 else:

View File

@@ -962,7 +962,7 @@ def main():
 lg_filename = params.lang_dir / "LG.pt"
 logging.info(f"Loading {lg_filename}")
 decoding_graph = k2.Fsa.from_dict(
-torch.load(lg_filename, map_location=device)
+torch.load(lg_filename, map_location=device, weights_only=False)
 )
 decoding_graph.scores *= params.ngram_lm_scale
 else:

View File

@@ -451,7 +451,7 @@ def _to_int_tuple(s: str):
 def get_encoder_model(params: AttributeDict) -> nn.Module:
 if hasattr(params, "pretrained_dir"):
 logging.info(f"Loading {params.pretrained_dir}")
-pretrained = torch.load(params.pretrained_dir)
+pretrained = torch.load(params.pretrained_dir, weights_only=False)
 encoder = HubertModel(params)
 encoder.load_state_dict(pretrained["model"])
 else:

View File

@@ -451,7 +451,7 @@ def _to_int_tuple(s: str):
 def get_encoder_model(params: AttributeDict) -> nn.Module:
 if hasattr(params, "pretrained_dir"):
 logging.info(f"Loading {params.pretrained_dir}")
-pretrained = torch.load(params.pretrained_dir)
+pretrained = torch.load(params.pretrained_dir, weights_only=False)
 encoder = HubertModel(params)
 encoder.load_state_dict(pretrained["model"])
 else:

View File

@@ -12,7 +12,7 @@ args = parser.parse_args()
 src = args.src
 tgt = args.tgt
-old_checkpoint = torch.load(src)
+old_checkpoint = torch.load(src, weights_only=False)
 new_checkpoint = OrderedDict()
 new_checkpoint["model"] = old_checkpoint["model"]
 torch.save(new_checkpoint, tgt)

View File

@@ -28,7 +28,7 @@ consisting of words and tokens (i.e., phones) and does the following:
 4. Generate L.pt, in k2 format. It can be loaded by
-d = torch.load("L.pt")
+d = torch.load("L.pt", weights_only=False)
 lexicon = k2.Fsa.from_dict(d)
 5. Generate L_disambig.pt, in k2 format.

View File

@@ -960,7 +960,7 @@ def main():
 lg_filename = params.lang_dir / "LG.pt"
 logging.info(f"Loading {lg_filename}")
 decoding_graph = k2.Fsa.from_dict(
-torch.load(lg_filename, map_location=device)
+torch.load(lg_filename, map_location=device, weights_only=False)
 )
 decoding_graph.scores *= params.ngram_lm_scale
 else:

View File

@@ -750,7 +750,7 @@ def _to_int_tuple(s: str):
 def get_encoder_model(params: AttributeDict) -> nn.Module:
 if hasattr(params, "pretrained_dir"):
 logging.info(f"Loading {params.pretrained_dir}")
-pretrained = torch.load(params.pretrained_dir)
+pretrained = torch.load(params.pretrained_dir, weights_only=False)
 encoder = HubertModel(params)
 encoder.load_state_dict(pretrained["model"])
 else:

View File

@@ -578,7 +578,7 @@ def main():
 H = None
 bpe_model = None
 HLG = k2.Fsa.from_dict(
-torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False)
 )
 assert HLG.requires_grad is False

View File

@@ -457,7 +457,7 @@ def main():
 params.num_classes = num_classes
-HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
+HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu", weights_only=False))
 HLG = HLG.to(device)
 assert HLG.requires_grad is False

View File

@@ -78,11 +78,11 @@ def compile_HLG(lm_dir: str, lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
 max_token_id = max(lexicon.tokens)
 logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
 H = k2.ctc_topo(max_token_id)
-L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
+L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt", weights_only=False))
 if Path(f"{lm_dir}/{lm}.pt").is_file():
 logging.info(f"Loading pre-compiled {lm}")
-d = torch.load(f"{lm_dir}/{lm}.pt")
+d = torch.load(f"{lm_dir}/{lm}.pt", weights_only=False)
 G = k2.Fsa.from_dict(d)
 else:
 logging.info(f"Loading {lm}.fst.txt")

View File

@@ -28,7 +28,7 @@ consisting of words and tokens (i.e., phones) and does the following:
 4. Generate L.pt, in k2 format. It can be loaded by
-d = torch.load("L.pt")
+d = torch.load("L.pt", weights_only=False)
 lexicon = k2.Fsa.from_dict(d)
 5. Generate L_disambig.pt, in k2 format.

View File

@@ -29,7 +29,7 @@ consisting of words and tokens (i.e., phones) and does the following:
 4. Generate L.pt, in k2 format. It can be loaded by
-d = torch.load("L.pt")
+d = torch.load("L.pt", weights_only=False)
 lexicon = k2.Fsa.from_dict(d)
 5. Generate L_disambig.pt, in k2 format.

View File

@@ -802,7 +802,7 @@ def main():
 H = None
 bpe_model = None
 HLG = k2.Fsa.from_dict(
-torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False)
 )
 assert HLG.requires_grad is False
@@ -842,7 +842,7 @@ def main():
 torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
 else:
 logging.info("Loading pre-compiled G_4_gram.pt")
-d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False)
 G = k2.Fsa.from_dict(d)
 if params.decoding_method in [

View File

@@ -1014,7 +1014,7 @@ def main():
 lg_filename = params.lang_dir / "LG.pt"
 logging.info(f"Loading {lg_filename}")
 decoding_graph = k2.Fsa.from_dict(
-torch.load(lg_filename, map_location=device)
+torch.load(lg_filename, map_location=device, weights_only=False)
 )
 decoding_graph.scores *= params.ngram_lm_scale
 else:

View File

@@ -41,7 +41,7 @@ def get_padding(kernel_size, dilation=1):
 def load_checkpoint(filepath, device):
 assert os.path.isfile(filepath)
 print(f"Loading '{filepath}'")
-checkpoint_dict = torch.load(filepath, map_location=device)
+checkpoint_dict = torch.load(filepath, map_location=device, weights_only=False)
 print("Complete.")
 return checkpoint_dict

View File

@@ -103,7 +103,7 @@ def load_vocoder(checkpoint_path: Path) -> nn.Module:
 hifigan = HiFiGAN(h).to("cpu")
 hifigan.load_state_dict(
-torch.load(checkpoint_path, map_location="cpu")["generator"]
+torch.load(checkpoint_path, map_location="cpu", weights_only=False)["generator"]
 )
 _ = hifigan.eval()
 hifigan.remove_weight_norm()

View File

@@ -756,7 +756,7 @@ def main():
 lg_filename = params.lang_dir / "LG.pt"
 logging.info(f"Loading {lg_filename}")
 decoding_graph = k2.Fsa.from_dict(
-torch.load(lg_filename, map_location=device)
+torch.load(lg_filename, map_location=device, weights_only=False)
 )
 decoding_graph.scores *= params.ngram_lm_scale
 else:

View File

@@ -575,7 +575,7 @@ def main():
 H = None
 bpe_model = None
 HLG = k2.Fsa.from_dict(
-torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False)
 )
 assert HLG.requires_grad is False
@@ -614,7 +614,7 @@ def main():
 torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
 else:
 logging.info("Loading pre-compiled G_4_gram.pt")
-d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False)
 G = k2.Fsa.from_dict(d)
 if params.method in ["whole-lattice-rescoring", "attention-decoder"]:

View File

@@ -275,7 +275,7 @@ def main():
 use_feat_batchnorm=params.use_feat_batchnorm,
 )
-checkpoint = torch.load(args.checkpoint, map_location="cpu")
+checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
 model.load_state_dict(checkpoint["model"], strict=False)
 model.to(device)
 model.eval()
@@ -347,7 +347,7 @@ def main():
 "attention-decoder",
 ]:
 logging.info(f"Loading HLG from {params.HLG}")
-HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu", weights_only=False))
 HLG = HLG.to(device)
 if not hasattr(HLG, "lm_scores"):
 # For whole-lattice-rescoring and attention-decoder
@@ -358,7 +358,7 @@ def main():
 "attention-decoder",
 ]:
 logging.info(f"Loading G from {params.G}")
-G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu", weights_only=False))
 # Add epsilon self-loops to G as we will compose
 # it with the whole lattice later
 G = G.to(device)

View File

@@ -236,7 +236,7 @@ def main():
 num_param = sum([p.numel() for p in model.parameters()])
 logging.info(f"Number of model parameters: {num_param}")
-checkpoint = torch.load(args.checkpoint, map_location="cpu")
+checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
 model.load_state_dict(checkpoint["model"], strict=False)
 model.to(device)
 model.eval()

View File

@@ -733,7 +733,7 @@ def main():
 lg_filename = params.lang_dir / "LG.pt"
 logging.info(f"Loading {lg_filename}")
 decoding_graph = k2.Fsa.from_dict(
-torch.load(lg_filename, map_location=device)
+torch.load(lg_filename, map_location=device, weights_only=False)
 )
 decoding_graph.scores *= params.ngram_lm_scale
 else:

View File

@@ -90,10 +90,10 @@ def average_checkpoints(
 """
 n = len(filenames)
-if "model" in torch.load(filenames[0], map_location=device):
-avg = torch.load(filenames[0], map_location=device)["model"]
+if "model" in torch.load(filenames[0], map_location=device, weights_only=False):
+avg = torch.load(filenames[0], map_location=device, weights_only=False)["model"]
 else:
-avg = torch.load(filenames[0], map_location=device)
+avg = torch.load(filenames[0], map_location=device, weights_only=False)
 # Identify shared parameters. Two parameters are said to be shared
 # if they have the same data_ptr
@@ -108,10 +108,10 @@ def average_checkpoints(
 uniqued_names = list(uniqued.values())
 for i in range(1, n):
-if "model" in torch.load(filenames[i], map_location=device):
-state_dict = torch.load(filenames[i], map_location=device)["model"]
+if "model" in torch.load(filenames[i], map_location=device, weights_only=False):
+state_dict = torch.load(filenames[i], map_location=device, weights_only=False)["model"]
 else:
-state_dict = torch.load(filenames[i], map_location=device)
+state_dict = torch.load(filenames[i], map_location=device, weights_only=False)
 for k in uniqued_names:
 avg[k] += state_dict[k]
@@ -484,7 +484,7 @@ def main():
 start = params.epoch - params.avg
 assert start >= 1, start
 checkpoint = torch.load(
-f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
+f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu", weights_only=False
 )
 if "model" not in checkpoint:
 # deepspeed converted checkpoint only contains model state_dict
@@ -513,7 +513,7 @@ def main():
 torch.save(model.state_dict(), filename)
 else:
 checkpoint = torch.load(
-f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
+f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu", weights_only=False
 )
 if "model" not in checkpoint:
 model.load_state_dict(checkpoint, strict=True)

View File

@@ -809,7 +809,7 @@ def run(rank, world_size, args):
 del model.alignment_heads
 if params.pretrained_model_path:
-checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
+checkpoint = torch.load(params.pretrained_model_path, map_location="cpu", weights_only=False)
 if "model" not in checkpoint:
 model.load_state_dict(checkpoint, strict=True)
 else:

View File

@@ -784,7 +784,7 @@ def main():
 lg_filename = params.lang_dir / "LG.pt"
 logging.info(f"Loading {lg_filename}")
 decoding_graph = k2.Fsa.from_dict(
-torch.load(lg_filename, map_location=device)
+torch.load(lg_filename, map_location=device, weights_only=False)
 )
 decoding_graph.scores *= params.ngram_lm_scale
 else:

View File

@@ -24,7 +24,7 @@ Usage:
 --exp-dir ./zipformer/exp
 It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`.
-You can later load it by `torch.load("epoch-28-avg-15.pt")`.
+You can later load it by `torch.load("epoch-28-avg-15.pt", weights_only=False)`.
 (2) use the checkpoint exp_dir/checkpoint-iter.pt
 ./zipformer/generate_averaged_model.py \
@@ -33,7 +33,7 @@ You can later load it by `torch.load("epoch-28-avg-15.pt")`.
 --exp-dir ./zipformer/exp
 It will generate a file `iter-22000-avg-5.pt` in the given `exp_dir`.
-You can later load it by `torch.load("iter-22000-avg-5.pt")`.
+You can later load it by `torch.load("iter-22000-avg-5.pt", weights_only=False)`.
 """

View File

@ -291,7 +291,7 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
-checkpoint = torch.load(args.checkpoint, map_location="cpu")
+checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
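The pretrained-model entry points all follow the same load → load_state_dict → eval sequence. A self-contained sketch with a toy module standing in for the real ASR model (file name and module are made up):

```python
# Toy stand-in for the recipe's model; only the loading pattern is the point.
import torch
import torch.nn as nn

model = nn.Linear(80, 500)
torch.save({"model": model.state_dict()}, "pretrained.pt")  # fake checkpoint

checkpoint = torch.load("pretrained.pt", map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.eval()
```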

View File

@ -792,7 +792,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
-torch.load(lg_filename, map_location=device)
+torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:

View File

@ -24,7 +24,7 @@ Usage:
--exp-dir ./zipformer/exp
It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`.
-You can later load it by `torch.load("epoch-28-avg-15.pt")`.
+You can later load it by `torch.load("epoch-28-avg-15.pt", weights_only=False)`.
(2) use the checkpoint exp_dir/checkpoint-iter.pt
./zipformer/generate_averaged_model.py \
@ -33,7 +33,7 @@ You can later load it by `torch.load("epoch-28-avg-15.pt")`.
--exp-dir ./zipformer/exp
It will generate a file `iter-22000-avg-5.pt` in the given `exp_dir`.
-You can later load it by `torch.load("iter-22000-avg-5.pt")`.
+You can later load it by `torch.load("iter-22000-avg-5.pt", weights_only=False)`.
"""

View File

@ -294,7 +294,7 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
-checkpoint = torch.load(args.checkpoint, map_location="cpu")
+checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()

View File

@ -64,7 +64,7 @@ def main():
if out_lm_data.is_file():
logging.warning(f"{out_lm_data} exists - skipping")
return
-data = torch.load(in_lm_data)
+data = torch.load(in_lm_data, weights_only=False)
words2bpe = data["words"]
sentences = data["sentences"]
sentence_lengths = data["sentence_lengths"]
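The LM training data file is a pickled dict whose values are not all plain tensors (an assumption based on the keys read above), which is exactly the case the tensors-only loader rejects. A sketch of round-tripping such a file, with simplified stand-in values:

```python
# Sketch: save and reload an LM-data-style dict; keys mirror the code above,
# but the values here are simplified stand-ins, not the real ragged structures.
import torch

data = {
    "words": ["HELLO", "WORLD"],       # stand-in for the real word-to-BPE mapping
    "sentences": [[0, 1]],
    "sentence_lengths": torch.tensor([2]),
}
torch.save(data, "lm_data.pt")

data = torch.load("lm_data.pt", weights_only=False)
words2bpe = data["words"]
sentences = data["sentences"]
sentence_lengths = data["sentence_lengths"]
```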

View File

@ -37,7 +37,7 @@ def main():
sp = spm.SentencePieceProcessor()
sp.load(str(bpe_model))
-data = torch.load(lm_training_data)
+data = torch.load(lm_training_data, weights_only=False)
words2bpe = data["words"]
sentences = data["sentences"]

View File

@ -1008,7 +1008,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
-torch.load(lg_filename, map_location=device)
+torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:

View File

@ -95,10 +95,10 @@ def average_checkpoints(
"""
n = len(filenames)
-if "model" in torch.load(filenames[0], map_location=device):
+if "model" in torch.load(filenames[0], map_location=device, weights_only=False):
-avg = torch.load(filenames[0], map_location=device)["model"]
+avg = torch.load(filenames[0], map_location=device, weights_only=False)["model"]
else:
-avg = torch.load(filenames[0], map_location=device)
+avg = torch.load(filenames[0], map_location=device, weights_only=False)
# Identify shared parameters. Two parameters are said to be shared
# if they have the same data_ptr
@ -113,10 +113,10 @@ def average_checkpoints(
uniqued_names = list(uniqued.values())
for i in range(1, n):
-if "model" in torch.load(filenames[i], map_location=device):
+if "model" in torch.load(filenames[i], map_location=device, weights_only=False):
-state_dict = torch.load(filenames[i], map_location=device)["model"]
+state_dict = torch.load(filenames[i], map_location=device, weights_only=False)["model"]
else:
-state_dict = torch.load(filenames[i], map_location=device)
+state_dict = torch.load(filenames[i], map_location=device, weights_only=False)
for k in uniqued_names:
avg[k] += state_dict[k]
@ -548,7 +548,7 @@ def main():
# torch.save(avg_checkpoint, filename)
else:
checkpoint = torch.load(
-f"{params.exp_dir}/epoch-{params.epoch}/pytorch_model.bin",
+f"{params.exp_dir}/epoch-{params.epoch}/pytorch_model.bin", weights_only=False,
map_location="cpu",
)
model.load_state_dict(checkpoint, strict=False)
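The deepspeed-converted pytorch_model.bin above appears to hold a bare state_dict, while the regular epoch-*.pt checkpoints typically wrap extra Python objects; passing weights_only=False keeps one code path working for both. A sketch of that distinction, with invented file names:

```python
# Sketch: a tensors-only file vs. a wrapped checkpoint; both load with weights_only=False.
import torch
import torch.nn as nn

model = nn.Linear(8, 4)
torch.save(model.state_dict(), "pytorch_model.bin")                      # bare state_dict
torch.save({"model": model.state_dict(), "epoch": 999}, "epoch-999.pt")  # wrapped checkpoint

checkpoint = torch.load("epoch-999.pt", map_location="cpu", weights_only=False)
state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint
model.load_state_dict(state_dict, strict=False)
```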

View File

@ -652,7 +652,7 @@ def run(rank, world_size, args):
)
if params.pretrained_model_path:
-checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
+checkpoint = torch.load(params.pretrained_model_path, map_location="cpu", weights_only=False)
missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
num_param = sum([p.numel() for p in model.parameters()])
@ -704,7 +704,7 @@ def run(rank, world_size, args):
sampler_state_dict = None
if params.sampler_state_dict_path:
-sampler_state_dict = torch.load(params.sampler_state_dict_path)
+sampler_state_dict = torch.load(params.sampler_state_dict_path, weights_only=False)
sampler_state_dict["max_duration"] = params.max_duration
train_dl = data_module.train_dataloaders(
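A sampler state dict is an ordinary pickled Python dict, exactly the kind of object weights_only=True refuses, which is why this load also needs the flag. A sketch with an invented path and simplified contents:

```python
# Sketch: restore a sampler state dict and override its max_duration, as above.
import torch

state = {"max_duration": 200.0, "epoch": 10}  # simplified stand-in for a real sampler state
torch.save(state, "sampler.pt")

sampler_state_dict = torch.load("sampler.pt", weights_only=False)
sampler_state_dict["max_duration"] = 600.0  # e.g. params.max_duration
```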

View File

@ -91,10 +91,10 @@ def average_checkpoints(
"""
n = len(filenames)
-if "model" in torch.load(filenames[0], map_location=device):
+if "model" in torch.load(filenames[0], map_location=device, weights_only=False):
-avg = torch.load(filenames[0], map_location=device)["model"]
+avg = torch.load(filenames[0], map_location=device, weights_only=False)["model"]
else:
-avg = torch.load(filenames[0], map_location=device)
+avg = torch.load(filenames[0], map_location=device, weights_only=False)
# Identify shared parameters. Two parameters are said to be shared
# if they have the same data_ptr
@ -109,10 +109,10 @@ def average_checkpoints(
uniqued_names = list(uniqued.values())
for i in range(1, n):
-if "model" in torch.load(filenames[i], map_location=device):
+if "model" in torch.load(filenames[i], map_location=device, weights_only=False):
-state_dict = torch.load(filenames[i], map_location=device)["model"]
+state_dict = torch.load(filenames[i], map_location=device, weights_only=False)["model"]
else:
-state_dict = torch.load(filenames[i], map_location=device)
+state_dict = torch.load(filenames[i], map_location=device, weights_only=False)
for k in uniqued_names:
avg[k] += state_dict[k]
@ -447,7 +447,7 @@ def main():
start = params.epoch - params.avg
assert start >= 1, start
checkpoint = torch.load(
-f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
+f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu", weights_only=False
)
if "model" not in checkpoint:
# deepspeed converted checkpoint only contains model state_dict
@ -476,7 +476,7 @@ def main():
torch.save(model.state_dict(), filename)
else:
checkpoint = torch.load(
-f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
+f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu", weights_only=False
)
if "model" not in checkpoint:
model.load_state_dict(checkpoint, strict=True)

Some files were not shown because too many files have changed in this diff.