add weights_only=False to torch.load (#1984)

Teo Wen Shen authored 2025-07-10 16:27:08 +09:00, committed by GitHub
parent 89728dd4f8
commit da87e7fc99
141 changed files with 205 additions and 205 deletions
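
Background: starting with PyTorch 2.6, `torch.load` defaults to `weights_only=True`, which only unpickles plain tensors, primitive containers, and a small allowlist of types. Icefall checkpoints and k2 graph files such as `L.pt`, `HLG.pt`, and `LG.pt` can carry additional Python objects (for example k2 ragged tensors or sampler state), so they fail to load under the new default; passing `weights_only=False` restores the previous behaviour. Below is a minimal sketch of the pattern applied throughout this diff, reusing the example path from the documentation hunk further down; it is illustrative only, and full unpickling should only be used on files from trusted sources.

```python
import torch

# Example path taken from the documentation snippet in this diff; any
# icefall checkpoint or k2 graph (.pt) file is loaded the same way.
ckpt_path = "tdnn/exp/pretrained.pt"

# On PyTorch >= 2.6 the default weights_only=True rejects pickled objects
# outside its allowlist, so the full-unpickling behaviour is requested
# explicitly. Only do this for trusted files, since unrestricted
# unpickling can execute arbitrary code.
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
print(list(checkpoint.keys()))  # typically ['model'] for a pretrained.pt
```

The `weights_only` keyword has been accepted since PyTorch 1.13, so passing it explicitly stays compatible with older releases; a stricter alternative on recent PyTorch is to allowlist the required classes via `torch.serialization.add_safe_globals` rather than disabling the check.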

View File

@ -41,7 +41,7 @@ To give you an idea of what ``tdnn/exp/pretrained.pt`` contains, we can use the
.. code-block:: python3
>>> import torch
>>> m = torch.load("tdnn/exp/pretrained.pt")
>>> m = torch.load("tdnn/exp/pretrained.pt", weights_only=False)
>>> list(m.keys())
['model']
>>> list(m["model"].keys())

View File

@ -28,7 +28,7 @@ consisting of words and tokens (i.e., phones) and does the following:
4. Generate L.pt, in k2 format. It can be loaded by
d = torch.load("L.pt")
d = torch.load("L.pt", weights_only=False)
lexicon = k2.Fsa.from_dict(d)
5. Generate L_disambig.pt, in k2 format.

View File

@ -224,7 +224,7 @@ def main():
logging.info("Creating model")
model = get_transducer_model(params)
checkpoint = torch.load(args.checkpoint, map_location="cpu")
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()

View File

@ -503,7 +503,7 @@ def main():
else:
H = None
HLG = k2.Fsa.from_dict(
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False)
)
assert HLG.requires_grad is False

View File

@ -249,7 +249,7 @@ def main():
use_feat_batchnorm=params.use_feat_batchnorm,
)
checkpoint = torch.load(args.checkpoint, map_location="cpu")
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
@ -315,7 +315,7 @@ def main():
hyps = [[token_sym_table[i] for i in ids] for ids in token_ids]
elif params.method in ["1best", "attention-decoder"]:
logging.info(f"Loading HLG from {params.HLG}")
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu", weights_only=False))
HLG = HLG.to(device)
if not hasattr(HLG, "lm_scores"):
# For whole-lattice-rescoring and attention-decoder

View File

@ -516,7 +516,7 @@ def main():
else:
H = None
HLG = k2.Fsa.from_dict(
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False)
)
assert HLG.requires_grad is False

View File

@ -28,7 +28,7 @@ consisting of words and tokens (i.e., phones) and does the following:
4. Generate L.pt, in k2 format. It can be loaded by
d = torch.load("L.pt")
d = torch.load("L.pt", weights_only=False)
lexicon = k2.Fsa.from_dict(d)
5. Generate L_disambig.pt, in k2 format.

View File

@ -227,7 +227,7 @@ def main():
logging.info("About to create model")
model = get_transducer_model(params)
checkpoint = torch.load(args.checkpoint, map_location="cpu")
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()

View File

@ -228,7 +228,7 @@ def main():
logging.info("About to create model")
model = get_transducer_model(params)
checkpoint = torch.load(args.checkpoint, map_location="cpu")
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()

View File

@ -773,7 +773,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:

View File

@ -237,7 +237,7 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoint = torch.load(args.checkpoint, map_location="cpu")
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()

View File

@ -337,7 +337,7 @@ def main():
logging.info(f"device: {device}")
HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu", weights_only=False))
HLG = HLG.to(device)
assert HLG.requires_grad is False

View File

@ -139,13 +139,13 @@ def main():
subsampling_factor=params.subsampling_factor,
)
checkpoint = torch.load(args.checkpoint, map_location="cpu")
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"])
model.to(device)
model.eval()
logging.info(f"Loading HLG from {params.HLG}")
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu", weights_only=False))
HLG = HLG.to(device)
if not hasattr(HLG, "lm_scores"):
# For whole-lattice-rescoring and attention-decoder

View File

@ -245,7 +245,7 @@ def main():
logging.info("Creating model")
model = get_transducer_model(params)
checkpoint = torch.load(args.checkpoint, map_location="cpu")
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()

View File

@ -225,7 +225,7 @@ def main():
logging.info("About to create model")
model = get_transducer_model(params)
checkpoint = torch.load(args.checkpoint, map_location="cpu")
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()

View File

@ -225,7 +225,7 @@ def main():
logging.info("About to create model")
model = get_transducer_model(params)
checkpoint = torch.load(args.checkpoint, map_location="cpu")
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"])
model.to(device)
model.eval()

View File

@ -89,10 +89,10 @@ def average_checkpoints(
"""
n = len(filenames)
if "model" in torch.load(filenames[0], map_location=device):
avg = torch.load(filenames[0], map_location=device)["model"]
if "model" in torch.load(filenames[0], map_location=device, weights_only=False):
avg = torch.load(filenames[0], map_location=device, weights_only=False)["model"]
else:
avg = torch.load(filenames[0], map_location=device)
avg = torch.load(filenames[0], map_location=device, weights_only=False)
# Identify shared parameters. Two parameters are said to be shared
# if they have the same data_ptr
@ -107,10 +107,10 @@ def average_checkpoints(
uniqued_names = list(uniqued.values())
for i in range(1, n):
if "model" in torch.load(filenames[i], map_location=device):
state_dict = torch.load(filenames[i], map_location=device)["model"]
if "model" in torch.load(filenames[i], map_location=device, weights_only=False):
state_dict = torch.load(filenames[i], map_location=device, weights_only=False)["model"]
else:
state_dict = torch.load(filenames[i], map_location=device)
state_dict = torch.load(filenames[i], map_location=device, weights_only=False)
for k in uniqued_names:
avg[k] += state_dict[k]
@ -440,7 +440,7 @@ def main():
start = params.epoch - params.avg
assert start >= 1, start
checkpoint = torch.load(
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu", weights_only=False
)
if "model" not in checkpoint:
# deepspeed converted checkpoint only contains model state_dict
@ -469,7 +469,7 @@ def main():
torch.save(model.state_dict(), filename)
else:
checkpoint = torch.load(
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu", weights_only=False
)
if "model" not in checkpoint:
model.load_state_dict(checkpoint, strict=True)

View File

@ -761,7 +761,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:

View File

@ -783,7 +783,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:

View File

@ -298,7 +298,7 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoint = torch.load(args.checkpoint, map_location="cpu")
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()

View File

@ -728,7 +728,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:

View File

@ -226,7 +226,7 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoint = torch.load(args.checkpoint, map_location="cpu")
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()

View File

@ -28,7 +28,7 @@ consisting of words and tokens (i.e., phones) and does the following:
4. Generate L.pt, in k2 format. It can be loaded by
d = torch.load("L.pt")
d = torch.load("L.pt", weights_only=False)
lexicon = k2.Fsa.from_dict(d)
5. Generate L_disambig.pt, in k2 format.

View File

@ -238,7 +238,7 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoint = torch.load(args.checkpoint, map_location="cpu")
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()

View File

@ -28,7 +28,7 @@ consisting of words and tokens (i.e., phones) and does the following:
4. Generate L.pt, in k2 format. It can be loaded by
d = torch.load("L.pt")
d = torch.load("L.pt", weights_only=False)
lexicon = k2.Fsa.from_dict(d)
5. Generate L_disambig.pt, in k2 format.

View File

@ -224,7 +224,7 @@ def main():
logging.info("Creating model")
model = get_transducer_model(params)
checkpoint = torch.load(args.checkpoint, map_location="cpu")
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()

View File

@ -672,7 +672,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:

View File

@ -1263,7 +1263,7 @@ def run(rank, world_size, args):
logging.info(
f"Initializing model with checkpoint from {params.model_init_ckpt}"
)
init_ckpt = torch.load(params.model_init_ckpt, map_location=device)
init_ckpt = torch.load(params.model_init_ckpt, map_location=device, weights_only=False)
model.load_state_dict(init_ckpt["model"], strict=False)
if world_size > 1:

View File

@ -1254,7 +1254,7 @@ def run(rank, world_size, args):
logging.info(
f"Initializing model with checkpoint from {params.model_init_ckpt}"
)
init_ckpt = torch.load(params.model_init_ckpt, map_location=device)
init_ckpt = torch.load(params.model_init_ckpt, map_location=device, weights_only=False)
model.load_state_dict(init_ckpt["model"], strict=False)
if world_size > 1:

View File

@ -141,7 +141,7 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoint = torch.load(args.checkpoint, map_location="cpu")
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()

View File

@ -115,7 +115,7 @@ def load_vocoder(checkpoint_path: Path) -> nn.Module:
hifigan = HiFiGAN(h).to("cpu")
hifigan.load_state_dict(
torch.load(checkpoint_path, map_location="cpu")["generator"]
torch.load(checkpoint_path, map_location="cpu", weights_only=False)["generator"]
)
_ = hifigan.eval()
hifigan.remove_weight_norm()

View File

@ -73,11 +73,11 @@ def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
max_token_id = max(lexicon.tokens)
logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
H = k2.ctc_topo(max_token_id)
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt", weights_only=False))
if Path(f"{lang_dir}/lm/{lm}.pt").is_file():
logging.info(f"Loading pre-compiled {lm}")
d = torch.load(f"{lang_dir}/lm/{lm}.pt")
d = torch.load(f"{lang_dir}/lm/{lm}.pt", weights_only=False)
G = k2.Fsa.from_dict(d)
else:
logging.info(f"Loading {lm}.fst.txt")

View File

@ -68,11 +68,11 @@ def compile_LG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
An FSA representing LG.
"""
lexicon = Lexicon(lang_dir)
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt", weights_only=False))
if Path(f"{lang_dir}/lm/{lm}.pt").is_file():
logging.info(f"Loading pre-compiled {lm}")
d = torch.load(f"{lang_dir}/lm/{lm}.pt")
d = torch.load(f"{lang_dir}/lm/{lm}.pt", weights_only=False)
G = k2.Fsa.from_dict(d)
else:
logging.info(f"Loading {lm}.fst.txt")

View File

@ -910,7 +910,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:

View File

@ -247,7 +247,7 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoint = torch.load(args.checkpoint, map_location="cpu")
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()

View File

@ -767,7 +767,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:

View File

@ -627,7 +627,7 @@ def load_model_params(
"""
logging.info(f"Loading checkpoint from {ckpt}")
checkpoint = torch.load(ckpt, map_location="cpu")
checkpoint = torch.load(ckpt, map_location="cpu", weights_only=False)
# if module list is empty, load the whole model from ckpt
if not init_modules:

View File

@ -25,7 +25,7 @@ Usage:
--exp-dir ./pruned_transducer_stateless7/exp
It will generate a file `epoch-28-avg-15-use-averaged-model.pt` in the given `exp_dir`.
You can later load it by `torch.load("epoch-28-avg-15-use-averaged-model.pt")`.
You can later load it by `torch.load("epoch-28-avg-15-use-averaged-model.pt", weights_only=False)`.
(2) use the averaged model with checkpoint exp_dir/checkpoint-iter.pt
./pruned_transducer_stateless7/generate_model_from_checkpoint.py \
@ -35,7 +35,7 @@ You can later load it by `torch.load("epoch-28-avg-15-use-averaged-model.pt")`.
--exp-dir ./pruned_transducer_stateless7/exp
It will generate a file `iter-22000-avg-5-use-averaged-model.pt` in the given `exp_dir`.
You can later load it by `torch.load("iter-22000-avg-5-use-averaged-model.pt")`.
You can later load it by `torch.load("iter-22000-avg-5-use-averaged-model.pt", weights_only=False)`.
(3) use the original model with checkpoint exp_dir/epoch-xxx.pt
./pruned_transducer_stateless7/generate_model_from_checkpoint.py \
@ -45,7 +45,7 @@ You can later load it by `torch.load("iter-22000-avg-5-use-averaged-model.pt")`.
--exp-dir ./pruned_transducer_stateless7/exp
It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`.
You can later load it by `torch.load("epoch-28-avg-15.pt")`.
You can later load it by `torch.load("epoch-28-avg-15.pt", weights_only=False)`.
(4) use the original model with checkpoint exp_dir/checkpoint-iter.pt
./pruned_transducer_stateless7/generate_model_from_checkpoint.py \
@ -55,7 +55,7 @@ You can later load it by `torch.load("epoch-28-avg-15.pt")`.
--exp-dir ./pruned_transducer_stateless7/exp
It will generate a file `iter-22000-avg-5.pt` in the given `exp_dir`.
You can later load it by `torch.load("iter-22000-avg-5.pt")`.
You can later load it by `torch.load("iter-22000-avg-5.pt", weights_only=False)`.
"""

View File

@ -987,7 +987,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:

View File

@ -756,7 +756,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:

View File

@ -791,7 +791,7 @@ def main():
if params.decoding_graph:
decoding_graph = k2.Fsa.from_dict(
torch.load(params.decoding_graph, map_location=device)
torch.load(params.decoding_graph, map_location=device, weights_only=False)
)
elif "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG":
@ -800,7 +800,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:

View File

@ -239,7 +239,7 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoint = torch.load(args.checkpoint, map_location="cpu")
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()

View File

@ -561,7 +561,7 @@ def main():
decoding_graph = None
if params.decoding_graph:
decoding_graph = k2.Fsa.from_dict(
torch.load(params.decoding_graph, map_location=device)
torch.load(params.decoding_graph, map_location=device, weights_only=False)
)
elif params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)

View File

@ -47,7 +47,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
max_token_id = max(lexicon.tokens)
logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
H = k2.ctc_topo(max_token_id)
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt", weights_only=False))
logging.info("Loading G.fst.txt")
with open(lang_dir / "G.fst.txt") as f:

View File

@ -14,7 +14,7 @@ consisting of words and tokens (i.e., phones) and does the following:
4. Generate L.pt, in k2 format. It can be loaded by
d = torch.load("L.pt")
d = torch.load("L.pt", weights_only=False)
lexicon = k2.Fsa.from_dict(d)
5. Generate L_disambig.pt, in k2 format.

View File

@ -589,7 +589,7 @@ def main():
H = None
bpe_model = None
HLG = k2.Fsa.from_dict(
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False)
)
assert HLG.requires_grad is False
@ -628,7 +628,7 @@ def main():
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
else:
logging.info("Loading pre-compiled G_4_gram.pt")
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False)
G = k2.Fsa.from_dict(d)
if params.method in ["whole-lattice-rescoring", "attention-decoder"]:

View File

@ -668,7 +668,7 @@ def main():
H = None
bpe_model = None
HLG = k2.Fsa.from_dict(
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False)
)
assert HLG.requires_grad is False
@ -707,7 +707,7 @@ def main():
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
else:
logging.info("Loading pre-compiled G_4_gram.pt")
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False)
G = k2.Fsa.from_dict(d)
if params.decoding_method == "whole-lattice-rescoring":

View File

@ -1000,7 +1000,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:

View File

@ -1001,7 +1001,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:

View File

@ -183,7 +183,7 @@ def load_model_params(
"""
logging.info(f"Loading checkpoint from {ckpt}")
checkpoint = torch.load(ckpt, map_location="cpu")
checkpoint = torch.load(ckpt, map_location="cpu", weights_only=False)
# if module list is empty, load the whole model from ckpt
if not init_modules:

View File

@ -938,7 +938,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:

View File

@ -666,7 +666,7 @@ def main():
H = None
bpe_model = None
HLG = k2.Fsa.from_dict(
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False)
)
assert HLG.requires_grad is False
@ -705,7 +705,7 @@ def main():
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
else:
logging.info("Loading pre-compiled G_4_gram.pt")
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False)
G = k2.Fsa.from_dict(d)
if params.decoding_method == "whole-lattice-rescoring":

View File

@ -989,7 +989,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:

View File

@ -177,7 +177,7 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoint = torch.load(args.checkpoint, map_location="cpu")
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()

View File

@ -1286,7 +1286,7 @@ def run(rank, world_size, args):
logging.info(
f"Initializing model with checkpoint from {params.model_init_ckpt}"
)
init_ckpt = torch.load(params.model_init_ckpt, map_location=device)
init_ckpt = torch.load(params.model_init_ckpt, map_location=device, weights_only=False)
model.load_state_dict(init_ckpt["model"], strict=False)
if world_size > 1:

View File

@ -1175,7 +1175,7 @@ def run(rank, world_size, args):
logging.info(
f"Initializing model with checkpoint from {params.model_init_ckpt}"
)
init_ckpt = torch.load(params.model_init_ckpt, map_location=device)
init_ckpt = torch.load(params.model_init_ckpt, map_location=device, weights_only=False)
model.load_state_dict(init_ckpt["model"], strict=True)
if world_size > 1:

View File

@ -252,7 +252,7 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoint = torch.load(args.checkpoint, map_location="cpu")
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()

View File

@ -960,7 +960,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:

View File

@ -750,7 +750,7 @@ def _to_int_tuple(s: str):
def get_encoder_model(params: AttributeDict) -> nn.Module:
if hasattr(params, "pretrained_dir"):
logging.info(f"Loading {params.pretrained_dir}")
pretrained = torch.load(params.pretrained_dir)
pretrained = torch.load(params.pretrained_dir, weights_only=False)
encoder = HubertModel(params)
encoder.load_state_dict(pretrained["model"])
else:

View File

@ -28,7 +28,7 @@ consisting of words and tokens (i.e., phones) and does the following:
4. Generate L.pt, in k2 format. It can be loaded by
d = torch.load("L.pt")
d = torch.load("L.pt", weights_only=False)
lexicon = k2.Fsa.from_dict(d)
5. Generate L_disambig.pt, in k2 format.

View File

@ -264,7 +264,7 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoint = torch.load(args.checkpoint, map_location="cpu")
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()

View File

@ -234,7 +234,7 @@ def main():
logging.info("Creating model")
model = get_transducer_model(params)
checkpoint = torch.load(args.checkpoint, map_location="cpu")
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()

View File

@ -234,7 +234,7 @@ def main():
logging.info("Creating model")
model = get_transducer_model(params)
checkpoint = torch.load(args.checkpoint, map_location="cpu")
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()

View File

@ -234,7 +234,7 @@ def main():
logging.info("Creating model")
model = get_transducer_model(params)
checkpoint = torch.load(args.checkpoint, map_location="cpu")
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()

View File

@ -962,7 +962,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:

View File

@ -962,7 +962,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:

View File

@ -451,7 +451,7 @@ def _to_int_tuple(s: str):
def get_encoder_model(params: AttributeDict) -> nn.Module:
if hasattr(params, "pretrained_dir"):
logging.info(f"Loading {params.pretrained_dir}")
pretrained = torch.load(params.pretrained_dir)
pretrained = torch.load(params.pretrained_dir, weights_only=False)
encoder = HubertModel(params)
encoder.load_state_dict(pretrained["model"])
else:

View File

@ -451,7 +451,7 @@ def _to_int_tuple(s: str):
def get_encoder_model(params: AttributeDict) -> nn.Module:
if hasattr(params, "pretrained_dir"):
logging.info(f"Loading {params.pretrained_dir}")
pretrained = torch.load(params.pretrained_dir)
pretrained = torch.load(params.pretrained_dir, weights_only=False)
encoder = HubertModel(params)
encoder.load_state_dict(pretrained["model"])
else:

View File

@ -12,7 +12,7 @@ args = parser.parse_args()
src = args.src
tgt = args.tgt
old_checkpoint = torch.load(src)
old_checkpoint = torch.load(src, weights_only=False)
new_checkpoint = OrderedDict()
new_checkpoint["model"] = old_checkpoint["model"]
torch.save(new_checkpoint, tgt)

View File

@ -28,7 +28,7 @@ consisting of words and tokens (i.e., phones) and does the following:
4. Generate L.pt, in k2 format. It can be loaded by
d = torch.load("L.pt")
d = torch.load("L.pt", weights_only=False)
lexicon = k2.Fsa.from_dict(d)
5. Generate L_disambig.pt, in k2 format.

View File

@ -960,7 +960,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:

View File

@ -750,7 +750,7 @@ def _to_int_tuple(s: str):
def get_encoder_model(params: AttributeDict) -> nn.Module:
if hasattr(params, "pretrained_dir"):
logging.info(f"Loading {params.pretrained_dir}")
pretrained = torch.load(params.pretrained_dir)
pretrained = torch.load(params.pretrained_dir, weights_only=False)
encoder = HubertModel(params)
encoder.load_state_dict(pretrained["model"])
else:

View File

@ -578,7 +578,7 @@ def main():
H = None
bpe_model = None
HLG = k2.Fsa.from_dict(
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False)
)
assert HLG.requires_grad is False

View File

@ -457,7 +457,7 @@ def main():
params.num_classes = num_classes
HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu", weights_only=False))
HLG = HLG.to(device)
assert HLG.requires_grad is False

View File

@ -78,11 +78,11 @@ def compile_HLG(lm_dir: str, lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
max_token_id = max(lexicon.tokens)
logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
H = k2.ctc_topo(max_token_id)
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt", weights_only=False))
if Path(f"{lm_dir}/{lm}.pt").is_file():
logging.info(f"Loading pre-compiled {lm}")
d = torch.load(f"{lm_dir}/{lm}.pt")
d = torch.load(f"{lm_dir}/{lm}.pt", weights_only=False)
G = k2.Fsa.from_dict(d)
else:
logging.info(f"Loading {lm}.fst.txt")

View File

@ -28,7 +28,7 @@ consisting of words and tokens (i.e., phones) and does the following:
4. Generate L.pt, in k2 format. It can be loaded by
d = torch.load("L.pt")
d = torch.load("L.pt", weights_only=False)
lexicon = k2.Fsa.from_dict(d)
5. Generate L_disambig.pt, in k2 format.

View File

@ -29,7 +29,7 @@ consisting of words and tokens (i.e., phones) and does the following:
4. Generate L.pt, in k2 format. It can be loaded by
d = torch.load("L.pt")
d = torch.load("L.pt", weights_only=False)
lexicon = k2.Fsa.from_dict(d)
5. Generate L_disambig.pt, in k2 format.

View File

@ -802,7 +802,7 @@ def main():
H = None
bpe_model = None
HLG = k2.Fsa.from_dict(
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False)
)
assert HLG.requires_grad is False
@ -842,7 +842,7 @@ def main():
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
else:
logging.info("Loading pre-compiled G_4_gram.pt")
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False)
G = k2.Fsa.from_dict(d)
if params.decoding_method in [

View File

@ -1014,7 +1014,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:

View File

@ -41,7 +41,7 @@ def get_padding(kernel_size, dilation=1):
def load_checkpoint(filepath, device):
assert os.path.isfile(filepath)
print(f"Loading '{filepath}'")
checkpoint_dict = torch.load(filepath, map_location=device)
checkpoint_dict = torch.load(filepath, map_location=device, weights_only=False)
print("Complete.")
return checkpoint_dict

View File

@ -103,7 +103,7 @@ def load_vocoder(checkpoint_path: Path) -> nn.Module:
hifigan = HiFiGAN(h).to("cpu")
hifigan.load_state_dict(
torch.load(checkpoint_path, map_location="cpu")["generator"]
torch.load(checkpoint_path, map_location="cpu", weights_only=False)["generator"]
)
_ = hifigan.eval()
hifigan.remove_weight_norm()

View File

@ -756,7 +756,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:

View File

@ -575,7 +575,7 @@ def main():
H = None
bpe_model = None
HLG = k2.Fsa.from_dict(
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False)
)
assert HLG.requires_grad is False
@ -614,7 +614,7 @@ def main():
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
else:
logging.info("Loading pre-compiled G_4_gram.pt")
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False)
G = k2.Fsa.from_dict(d)
if params.method in ["whole-lattice-rescoring", "attention-decoder"]:

View File

@ -275,7 +275,7 @@ def main():
use_feat_batchnorm=params.use_feat_batchnorm,
)
checkpoint = torch.load(args.checkpoint, map_location="cpu")
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
@ -347,7 +347,7 @@ def main():
"attention-decoder",
]:
logging.info(f"Loading HLG from {params.HLG}")
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu", weights_only=False))
HLG = HLG.to(device)
if not hasattr(HLG, "lm_scores"):
# For whole-lattice-rescoring and attention-decoder
@ -358,7 +358,7 @@ def main():
"attention-decoder",
]:
logging.info(f"Loading G from {params.G}")
G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu", weights_only=False))
# Add epsilon self-loops to G as we will compose
# it with the whole lattice later
G = G.to(device)

View File

@ -236,7 +236,7 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoint = torch.load(args.checkpoint, map_location="cpu")
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()

View File

@ -733,7 +733,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:

View File

@ -90,10 +90,10 @@ def average_checkpoints(
"""
n = len(filenames)
if "model" in torch.load(filenames[0], map_location=device):
avg = torch.load(filenames[0], map_location=device)["model"]
if "model" in torch.load(filenames[0], map_location=device, weights_only=False):
avg = torch.load(filenames[0], map_location=device, weights_only=False)["model"]
else:
avg = torch.load(filenames[0], map_location=device)
avg = torch.load(filenames[0], map_location=device, weights_only=False)
# Identify shared parameters. Two parameters are said to be shared
# if they have the same data_ptr
@ -108,10 +108,10 @@ def average_checkpoints(
uniqued_names = list(uniqued.values())
for i in range(1, n):
if "model" in torch.load(filenames[i], map_location=device):
state_dict = torch.load(filenames[i], map_location=device)["model"]
if "model" in torch.load(filenames[i], map_location=device, weights_only=False):
state_dict = torch.load(filenames[i], map_location=device, weights_only=False)["model"]
else:
state_dict = torch.load(filenames[i], map_location=device)
state_dict = torch.load(filenames[i], map_location=device, weights_only=False)
for k in uniqued_names:
avg[k] += state_dict[k]
@ -484,7 +484,7 @@ def main():
start = params.epoch - params.avg
assert start >= 1, start
checkpoint = torch.load(
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu", weights_only=False
)
if "model" not in checkpoint:
# deepspeed converted checkpoint only contains model state_dict
@ -513,7 +513,7 @@ def main():
torch.save(model.state_dict(), filename)
else:
checkpoint = torch.load(
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu", weights_only=False
)
if "model" not in checkpoint:
model.load_state_dict(checkpoint, strict=True)

View File

@ -809,7 +809,7 @@ def run(rank, world_size, args):
del model.alignment_heads
if params.pretrained_model_path:
checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
checkpoint = torch.load(params.pretrained_model_path, map_location="cpu", weights_only=False)
if "model" not in checkpoint:
model.load_state_dict(checkpoint, strict=True)
else:

View File

@ -784,7 +784,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:

View File

@ -24,7 +24,7 @@ Usage:
--exp-dir ./zipformer/exp
It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`.
You can later load it by `torch.load("epoch-28-avg-15.pt")`.
You can later load it by `torch.load("epoch-28-avg-15.pt", weights_only=False)`.
(2) use the checkpoint exp_dir/checkpoint-iter.pt
./zipformer/generate_averaged_model.py \
@ -33,7 +33,7 @@ You can later load it by `torch.load("epoch-28-avg-15.pt")`.
--exp-dir ./zipformer/exp
It will generate a file `iter-22000-avg-5.pt` in the given `exp_dir`.
You can later load it by `torch.load("iter-22000-avg-5.pt")`.
You can later load it by `torch.load("iter-22000-avg-5.pt", weights_only=False)`.
"""

View File

@ -291,7 +291,7 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoint = torch.load(args.checkpoint, map_location="cpu")
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()

View File

@ -792,7 +792,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:

View File

@ -24,7 +24,7 @@ Usage:
--exp-dir ./zipformer/exp
It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`.
You can later load it by `torch.load("epoch-28-avg-15.pt")`.
You can later load it by `torch.load("epoch-28-avg-15.pt", weights_only=False)`.
(2) use the checkpoint exp_dir/checkpoint-iter.pt
./zipformer/generate_averaged_model.py \
@ -33,7 +33,7 @@ You can later load it by `torch.load("epoch-28-avg-15.pt")`.
--exp-dir ./zipformer/exp
It will generate a file `iter-22000-avg-5.pt` in the given `exp_dir`.
You can later load it by `torch.load("iter-22000-avg-5.pt")`.
You can later load it by `torch.load("iter-22000-avg-5.pt", weights_only=False)`.
"""

View File

@ -294,7 +294,7 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoint = torch.load(args.checkpoint, map_location="cpu")
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()

View File

@ -64,7 +64,7 @@ def main():
if out_lm_data.is_file():
logging.warning(f"{out_lm_data} exists - skipping")
return
data = torch.load(in_lm_data)
data = torch.load(in_lm_data, weights_only=False)
words2bpe = data["words"]
sentences = data["sentences"]
sentence_lengths = data["sentence_lengths"]

View File

@ -37,7 +37,7 @@ def main():
sp = spm.SentencePieceProcessor()
sp.load(str(bpe_model))
data = torch.load(lm_training_data)
data = torch.load(lm_training_data, weights_only=False)
words2bpe = data["words"]
sentences = data["sentences"]

View File

@ -1008,7 +1008,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:

View File

@ -95,10 +95,10 @@ def average_checkpoints(
"""
n = len(filenames)
if "model" in torch.load(filenames[0], map_location=device):
avg = torch.load(filenames[0], map_location=device)["model"]
if "model" in torch.load(filenames[0], map_location=device, weights_only=False):
avg = torch.load(filenames[0], map_location=device, weights_only=False)["model"]
else:
avg = torch.load(filenames[0], map_location=device)
avg = torch.load(filenames[0], map_location=device, weights_only=False)
# Identify shared parameters. Two parameters are said to be shared
# if they have the same data_ptr
@ -113,10 +113,10 @@ def average_checkpoints(
uniqued_names = list(uniqued.values())
for i in range(1, n):
if "model" in torch.load(filenames[i], map_location=device):
state_dict = torch.load(filenames[i], map_location=device)["model"]
if "model" in torch.load(filenames[i], map_location=device, weights_only=False):
state_dict = torch.load(filenames[i], map_location=device, weights_only=False)["model"]
else:
state_dict = torch.load(filenames[i], map_location=device)
state_dict = torch.load(filenames[i], map_location=device, weights_only=False)
for k in uniqued_names:
avg[k] += state_dict[k]
@ -548,7 +548,7 @@ def main():
# torch.save(avg_checkpoint, filename)
else:
checkpoint = torch.load(
f"{params.exp_dir}/epoch-{params.epoch}/pytorch_model.bin",
f"{params.exp_dir}/epoch-{params.epoch}/pytorch_model.bin", weights_only=False,
map_location="cpu",
)
model.load_state_dict(checkpoint, strict=False)

View File

@ -652,7 +652,7 @@ def run(rank, world_size, args):
)
if params.pretrained_model_path:
checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
checkpoint = torch.load(params.pretrained_model_path, map_location="cpu", weights_only=False)
missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
num_param = sum([p.numel() for p in model.parameters()])
@ -704,7 +704,7 @@ def run(rank, world_size, args):
sampler_state_dict = None
if params.sampler_state_dict_path:
sampler_state_dict = torch.load(params.sampler_state_dict_path)
sampler_state_dict = torch.load(params.sampler_state_dict_path, weights_only=False)
sampler_state_dict["max_duration"] = params.max_duration
train_dl = data_module.train_dataloaders(

View File

@ -91,10 +91,10 @@ def average_checkpoints(
"""
n = len(filenames)
if "model" in torch.load(filenames[0], map_location=device):
avg = torch.load(filenames[0], map_location=device)["model"]
if "model" in torch.load(filenames[0], map_location=device, weights_only=False):
avg = torch.load(filenames[0], map_location=device, weights_only=False)["model"]
else:
avg = torch.load(filenames[0], map_location=device)
avg = torch.load(filenames[0], map_location=device, weights_only=False)
# Identify shared parameters. Two parameters are said to be shared
# if they have the same data_ptr
@ -109,10 +109,10 @@ def average_checkpoints(
uniqued_names = list(uniqued.values())
for i in range(1, n):
if "model" in torch.load(filenames[i], map_location=device):
state_dict = torch.load(filenames[i], map_location=device)["model"]
if "model" in torch.load(filenames[i], map_location=device, weights_only=False):
state_dict = torch.load(filenames[i], map_location=device, weights_only=False)["model"]
else:
state_dict = torch.load(filenames[i], map_location=device)
state_dict = torch.load(filenames[i], map_location=device, weights_only=False)
for k in uniqued_names:
avg[k] += state_dict[k]
@ -447,7 +447,7 @@ def main():
start = params.epoch - params.avg
assert start >= 1, start
checkpoint = torch.load(
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu", weights_only=False
)
if "model" not in checkpoint:
# deepspeed converted checkpoint only contains model state_dict
@ -476,7 +476,7 @@ def main():
torch.save(model.state_dict(), filename)
else:
checkpoint = torch.load(
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu", weights_only=False
)
if "model" not in checkpoint:
model.load_state_dict(checkpoint, strict=True)

Some files were not shown because too many files have changed in this diff.