mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 01:22:22 +00:00
add weights_only=False to torch.load (#1984)
This commit is contained in:
parent
89728dd4f8
commit
da87e7fc99
@ -41,7 +41,7 @@ To give you an idea of what ``tdnn/exp/pretrained.pt`` contains, we can use the
|
||||
.. code-block:: python3
|
||||
|
||||
>>> import torch
|
||||
>>> m = torch.load("tdnn/exp/pretrained.pt")
|
||||
>>> m = torch.load("tdnn/exp/pretrained.pt", weights_only=False)
|
||||
>>> list(m.keys())
|
||||
['model']
|
||||
>>> list(m["model"].keys())
|
||||
|
@ -28,7 +28,7 @@ consisting of words and tokens (i.e., phones) and does the following:
|
||||
|
||||
4. Generate L.pt, in k2 format. It can be loaded by
|
||||
|
||||
d = torch.load("L.pt")
|
||||
d = torch.load("L.pt", weights_only=False)
|
||||
lexicon = k2.Fsa.from_dict(d)
|
||||
|
||||
5. Generate L_disambig.pt, in k2 format.
|
||||
|
@ -224,7 +224,7 @@ def main():
|
||||
logging.info("Creating model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
@ -503,7 +503,7 @@ def main():
|
||||
else:
|
||||
H = None
|
||||
HLG = k2.Fsa.from_dict(
|
||||
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
|
||||
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False)
|
||||
)
|
||||
assert HLG.requires_grad is False
|
||||
|
||||
|
@ -249,7 +249,7 @@ def main():
|
||||
use_feat_batchnorm=params.use_feat_batchnorm,
|
||||
)
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
@ -315,7 +315,7 @@ def main():
|
||||
hyps = [[token_sym_table[i] for i in ids] for ids in token_ids]
|
||||
elif params.method in ["1best", "attention-decoder"]:
|
||||
logging.info(f"Loading HLG from {params.HLG}")
|
||||
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
|
||||
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu", weights_only=False))
|
||||
HLG = HLG.to(device)
|
||||
if not hasattr(HLG, "lm_scores"):
|
||||
# For whole-lattice-rescoring and attention-decoder
|
||||
|
@ -516,7 +516,7 @@ def main():
|
||||
else:
|
||||
H = None
|
||||
HLG = k2.Fsa.from_dict(
|
||||
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
|
||||
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False)
|
||||
)
|
||||
assert HLG.requires_grad is False
|
||||
|
||||
|
@ -28,7 +28,7 @@ consisting of words and tokens (i.e., phones) and does the following:
|
||||
|
||||
4. Generate L.pt, in k2 format. It can be loaded by
|
||||
|
||||
d = torch.load("L.pt")
|
||||
d = torch.load("L.pt", weights_only=False)
|
||||
lexicon = k2.Fsa.from_dict(d)
|
||||
|
||||
5. Generate L_disambig.pt, in k2 format.
|
||||
|
@ -227,7 +227,7 @@ def main():
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
@ -228,7 +228,7 @@ def main():
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
@ -773,7 +773,7 @@ def main():
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
torch.load(lg_filename, map_location=device, weights_only=False)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
|
@ -237,7 +237,7 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
@ -337,7 +337,7 @@ def main():
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
|
||||
HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu", weights_only=False))
|
||||
HLG = HLG.to(device)
|
||||
assert HLG.requires_grad is False
|
||||
|
||||
|
@ -139,13 +139,13 @@ def main():
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
)
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(checkpoint["model"])
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
logging.info(f"Loading HLG from {params.HLG}")
|
||||
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
|
||||
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu", weights_only=False))
|
||||
HLG = HLG.to(device)
|
||||
if not hasattr(HLG, "lm_scores"):
|
||||
# For whole-lattice-rescoring and attention-decoder
|
||||
|
@ -245,7 +245,7 @@ def main():
|
||||
logging.info("Creating model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
@ -225,7 +225,7 @@ def main():
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
@ -225,7 +225,7 @@ def main():
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(checkpoint["model"])
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
@ -89,10 +89,10 @@ def average_checkpoints(
|
||||
"""
|
||||
n = len(filenames)
|
||||
|
||||
if "model" in torch.load(filenames[0], map_location=device):
|
||||
avg = torch.load(filenames[0], map_location=device)["model"]
|
||||
if "model" in torch.load(filenames[0], map_location=device, weights_only=False):
|
||||
avg = torch.load(filenames[0], map_location=device, weights_only=False)["model"]
|
||||
else:
|
||||
avg = torch.load(filenames[0], map_location=device)
|
||||
avg = torch.load(filenames[0], map_location=device, weights_only=False)
|
||||
|
||||
# Identify shared parameters. Two parameters are said to be shared
|
||||
# if they have the same data_ptr
|
||||
@ -107,10 +107,10 @@ def average_checkpoints(
|
||||
uniqued_names = list(uniqued.values())
|
||||
|
||||
for i in range(1, n):
|
||||
if "model" in torch.load(filenames[i], map_location=device):
|
||||
state_dict = torch.load(filenames[i], map_location=device)["model"]
|
||||
if "model" in torch.load(filenames[i], map_location=device, weights_only=False):
|
||||
state_dict = torch.load(filenames[i], map_location=device, weights_only=False)["model"]
|
||||
else:
|
||||
state_dict = torch.load(filenames[i], map_location=device)
|
||||
state_dict = torch.load(filenames[i], map_location=device, weights_only=False)
|
||||
for k in uniqued_names:
|
||||
avg[k] += state_dict[k]
|
||||
|
||||
@ -440,7 +440,7 @@ def main():
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
checkpoint = torch.load(
|
||||
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
|
||||
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu", weights_only=False
|
||||
)
|
||||
if "model" not in checkpoint:
|
||||
# deepspeed converted checkpoint only contains model state_dict
|
||||
@ -469,7 +469,7 @@ def main():
|
||||
torch.save(model.state_dict(), filename)
|
||||
else:
|
||||
checkpoint = torch.load(
|
||||
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
|
||||
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu", weights_only=False
|
||||
)
|
||||
if "model" not in checkpoint:
|
||||
model.load_state_dict(checkpoint, strict=True)
|
||||
|
@ -761,7 +761,7 @@ def main():
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
torch.load(lg_filename, map_location=device, weights_only=False)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
|
@ -783,7 +783,7 @@ def main():
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
torch.load(lg_filename, map_location=device, weights_only=False)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
|
@ -298,7 +298,7 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
@ -728,7 +728,7 @@ def main():
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
torch.load(lg_filename, map_location=device, weights_only=False)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
|
@ -226,7 +226,7 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
@ -28,7 +28,7 @@ consisting of words and tokens (i.e., phones) and does the following:
|
||||
|
||||
4. Generate L.pt, in k2 format. It can be loaded by
|
||||
|
||||
d = torch.load("L.pt")
|
||||
d = torch.load("L.pt", weights_only=False)
|
||||
lexicon = k2.Fsa.from_dict(d)
|
||||
|
||||
5. Generate L_disambig.pt, in k2 format.
|
||||
|
@ -238,7 +238,7 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
@ -28,7 +28,7 @@ consisting of words and tokens (i.e., phones) and does the following:
|
||||
|
||||
4. Generate L.pt, in k2 format. It can be loaded by
|
||||
|
||||
d = torch.load("L.pt")
|
||||
d = torch.load("L.pt", weights_only=False)
|
||||
lexicon = k2.Fsa.from_dict(d)
|
||||
|
||||
5. Generate L_disambig.pt, in k2 format.
|
||||
|
@ -224,7 +224,7 @@ def main():
|
||||
logging.info("Creating model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
@ -672,7 +672,7 @@ def main():
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
torch.load(lg_filename, map_location=device, weights_only=False)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
|
@ -1263,7 +1263,7 @@ def run(rank, world_size, args):
|
||||
logging.info(
|
||||
f"Initializing model with checkpoint from {params.model_init_ckpt}"
|
||||
)
|
||||
init_ckpt = torch.load(params.model_init_ckpt, map_location=device)
|
||||
init_ckpt = torch.load(params.model_init_ckpt, map_location=device, weights_only=False)
|
||||
model.load_state_dict(init_ckpt["model"], strict=False)
|
||||
|
||||
if world_size > 1:
|
||||
|
@ -1254,7 +1254,7 @@ def run(rank, world_size, args):
|
||||
logging.info(
|
||||
f"Initializing model with checkpoint from {params.model_init_ckpt}"
|
||||
)
|
||||
init_ckpt = torch.load(params.model_init_ckpt, map_location=device)
|
||||
init_ckpt = torch.load(params.model_init_ckpt, map_location=device, weights_only=False)
|
||||
model.load_state_dict(init_ckpt["model"], strict=False)
|
||||
|
||||
if world_size > 1:
|
||||
|
@ -141,7 +141,7 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
@ -115,7 +115,7 @@ def load_vocoder(checkpoint_path: Path) -> nn.Module:
|
||||
|
||||
hifigan = HiFiGAN(h).to("cpu")
|
||||
hifigan.load_state_dict(
|
||||
torch.load(checkpoint_path, map_location="cpu")["generator"]
|
||||
torch.load(checkpoint_path, map_location="cpu", weights_only=False)["generator"]
|
||||
)
|
||||
_ = hifigan.eval()
|
||||
hifigan.remove_weight_norm()
|
||||
|
@ -73,11 +73,11 @@ def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
|
||||
max_token_id = max(lexicon.tokens)
|
||||
logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
|
||||
H = k2.ctc_topo(max_token_id)
|
||||
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
|
||||
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt", weights_only=False))
|
||||
|
||||
if Path(f"{lang_dir}/lm/{lm}.pt").is_file():
|
||||
logging.info(f"Loading pre-compiled {lm}")
|
||||
d = torch.load(f"{lang_dir}/lm/{lm}.pt")
|
||||
d = torch.load(f"{lang_dir}/lm/{lm}.pt", weights_only=False)
|
||||
G = k2.Fsa.from_dict(d)
|
||||
else:
|
||||
logging.info(f"Loading {lm}.fst.txt")
|
||||
|
@ -68,11 +68,11 @@ def compile_LG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
|
||||
An FSA representing LG.
|
||||
"""
|
||||
lexicon = Lexicon(lang_dir)
|
||||
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
|
||||
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt", weights_only=False))
|
||||
|
||||
if Path(f"{lang_dir}/lm/{lm}.pt").is_file():
|
||||
logging.info(f"Loading pre-compiled {lm}")
|
||||
d = torch.load(f"{lang_dir}/lm/{lm}.pt")
|
||||
d = torch.load(f"{lang_dir}/lm/{lm}.pt", weights_only=False)
|
||||
G = k2.Fsa.from_dict(d)
|
||||
else:
|
||||
logging.info(f"Loading {lm}.fst.txt")
|
||||
|
@ -910,7 +910,7 @@ def main():
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
torch.load(lg_filename, map_location=device, weights_only=False)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
|
@ -247,7 +247,7 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
@ -767,7 +767,7 @@ def main():
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
torch.load(lg_filename, map_location=device, weights_only=False)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
|
@ -627,7 +627,7 @@ def load_model_params(
|
||||
|
||||
"""
|
||||
logging.info(f"Loading checkpoint from {ckpt}")
|
||||
checkpoint = torch.load(ckpt, map_location="cpu")
|
||||
checkpoint = torch.load(ckpt, map_location="cpu", weights_only=False)
|
||||
|
||||
# if module list is empty, load the whole model from ckpt
|
||||
if not init_modules:
|
||||
|
@ -25,7 +25,7 @@ Usage:
|
||||
--exp-dir ./pruned_transducer_stateless7/exp
|
||||
|
||||
It will generate a file `epoch-28-avg-15-use-averaged-model.pt` in the given `exp_dir`.
|
||||
You can later load it by `torch.load("epoch-28-avg-15-use-averaged-model.pt")`.
|
||||
You can later load it by `torch.load("epoch-28-avg-15-use-averaged-model.pt", weights_only=False)`.
|
||||
|
||||
(2) use the averaged model with checkpoint exp_dir/checkpoint-iter.pt
|
||||
./pruned_transducer_stateless7/generate_model_from_checkpoint.py \
|
||||
@ -35,7 +35,7 @@ You can later load it by `torch.load("epoch-28-avg-15-use-averaged-model.pt")`.
|
||||
--exp-dir ./pruned_transducer_stateless7/exp
|
||||
|
||||
It will generate a file `iter-22000-avg-5-use-averaged-model.pt` in the given `exp_dir`.
|
||||
You can later load it by `torch.load("iter-22000-avg-5-use-averaged-model.pt")`.
|
||||
You can later load it by `torch.load("iter-22000-avg-5-use-averaged-model.pt", weights_only=False)`.
|
||||
|
||||
(3) use the original model with checkpoint exp_dir/epoch-xxx.pt
|
||||
./pruned_transducer_stateless7/generate_model_from_checkpoint.py \
|
||||
@ -45,7 +45,7 @@ You can later load it by `torch.load("iter-22000-avg-5-use-averaged-model.pt")`.
|
||||
--exp-dir ./pruned_transducer_stateless7/exp
|
||||
|
||||
It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`.
|
||||
You can later load it by `torch.load("epoch-28-avg-15.pt")`.
|
||||
You can later load it by `torch.load("epoch-28-avg-15.pt", weights_only=False)`.
|
||||
|
||||
(4) use the original model with checkpoint exp_dir/checkpoint-iter.pt
|
||||
./pruned_transducer_stateless7/generate_model_from_checkpoint.py \
|
||||
@ -55,7 +55,7 @@ You can later load it by `torch.load("epoch-28-avg-15.pt")`.
|
||||
--exp-dir ./pruned_transducer_stateless7/exp
|
||||
|
||||
It will generate a file `iter-22000-avg-5.pt` in the given `exp_dir`.
|
||||
You can later load it by `torch.load("iter-22000-avg-5.pt")`.
|
||||
You can later load it by `torch.load("iter-22000-avg-5.pt", weights_only=False)`.
|
||||
"""
|
||||
|
||||
|
||||
|
@ -987,7 +987,7 @@ def main():
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
torch.load(lg_filename, map_location=device, weights_only=False)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
|
@ -756,7 +756,7 @@ def main():
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
torch.load(lg_filename, map_location=device, weights_only=False)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
|
@ -791,7 +791,7 @@ def main():
|
||||
|
||||
if params.decoding_graph:
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(params.decoding_graph, map_location=device)
|
||||
torch.load(params.decoding_graph, map_location=device, weights_only=False)
|
||||
)
|
||||
elif "fast_beam_search" in params.decoding_method:
|
||||
if params.decoding_method == "fast_beam_search_nbest_LG":
|
||||
@ -800,7 +800,7 @@ def main():
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
torch.load(lg_filename, map_location=device, weights_only=False)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
|
@ -239,7 +239,7 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
@ -561,7 +561,7 @@ def main():
|
||||
decoding_graph = None
|
||||
if params.decoding_graph:
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(params.decoding_graph, map_location=device)
|
||||
torch.load(params.decoding_graph, map_location=device, weights_only=False)
|
||||
)
|
||||
elif params.decoding_method == "fast_beam_search":
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
|
@ -47,7 +47,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
|
||||
max_token_id = max(lexicon.tokens)
|
||||
logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
|
||||
H = k2.ctc_topo(max_token_id)
|
||||
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
|
||||
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt", weights_only=False))
|
||||
|
||||
logging.info("Loading G.fst.txt")
|
||||
with open(lang_dir / "G.fst.txt") as f:
|
||||
|
@ -14,7 +14,7 @@ consisting of words and tokens (i.e., phones) and does the following:
|
||||
|
||||
4. Generate L.pt, in k2 format. It can be loaded by
|
||||
|
||||
d = torch.load("L.pt")
|
||||
d = torch.load("L.pt", weights_only=False)
|
||||
lexicon = k2.Fsa.from_dict(d)
|
||||
|
||||
5. Generate L_disambig.pt, in k2 format.
|
||||
|
@ -589,7 +589,7 @@ def main():
|
||||
H = None
|
||||
bpe_model = None
|
||||
HLG = k2.Fsa.from_dict(
|
||||
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
|
||||
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False)
|
||||
)
|
||||
assert HLG.requires_grad is False
|
||||
|
||||
@ -628,7 +628,7 @@ def main():
|
||||
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
|
||||
else:
|
||||
logging.info("Loading pre-compiled G_4_gram.pt")
|
||||
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
|
||||
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False)
|
||||
G = k2.Fsa.from_dict(d)
|
||||
|
||||
if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
|
||||
|
@ -668,7 +668,7 @@ def main():
|
||||
H = None
|
||||
bpe_model = None
|
||||
HLG = k2.Fsa.from_dict(
|
||||
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
|
||||
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False)
|
||||
)
|
||||
assert HLG.requires_grad is False
|
||||
|
||||
@ -707,7 +707,7 @@ def main():
|
||||
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
|
||||
else:
|
||||
logging.info("Loading pre-compiled G_4_gram.pt")
|
||||
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
|
||||
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False)
|
||||
G = k2.Fsa.from_dict(d)
|
||||
|
||||
if params.decoding_method == "whole-lattice-rescoring":
|
||||
|
@ -1000,7 +1000,7 @@ def main():
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
torch.load(lg_filename, map_location=device, weights_only=False)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
|
@ -1001,7 +1001,7 @@ def main():
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
torch.load(lg_filename, map_location=device, weights_only=False)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
|
@ -183,7 +183,7 @@ def load_model_params(
|
||||
|
||||
"""
|
||||
logging.info(f"Loading checkpoint from {ckpt}")
|
||||
checkpoint = torch.load(ckpt, map_location="cpu")
|
||||
checkpoint = torch.load(ckpt, map_location="cpu", weights_only=False)
|
||||
|
||||
# if module list is empty, load the whole model from ckpt
|
||||
if not init_modules:
|
||||
|
@ -938,7 +938,7 @@ def main():
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
torch.load(lg_filename, map_location=device, weights_only=False)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
|
@ -666,7 +666,7 @@ def main():
|
||||
H = None
|
||||
bpe_model = None
|
||||
HLG = k2.Fsa.from_dict(
|
||||
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
|
||||
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False)
|
||||
)
|
||||
assert HLG.requires_grad is False
|
||||
|
||||
@ -705,7 +705,7 @@ def main():
|
||||
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
|
||||
else:
|
||||
logging.info("Loading pre-compiled G_4_gram.pt")
|
||||
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
|
||||
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False)
|
||||
G = k2.Fsa.from_dict(d)
|
||||
|
||||
if params.decoding_method == "whole-lattice-rescoring":
|
||||
|
@ -989,7 +989,7 @@ def main():
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
torch.load(lg_filename, map_location=device, weights_only=False)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
|
@ -177,7 +177,7 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
@ -1286,7 +1286,7 @@ def run(rank, world_size, args):
|
||||
logging.info(
|
||||
f"Initializing model with checkpoint from {params.model_init_ckpt}"
|
||||
)
|
||||
init_ckpt = torch.load(params.model_init_ckpt, map_location=device)
|
||||
init_ckpt = torch.load(params.model_init_ckpt, map_location=device, weights_only=False)
|
||||
model.load_state_dict(init_ckpt["model"], strict=False)
|
||||
|
||||
if world_size > 1:
|
||||
|
@ -1175,7 +1175,7 @@ def run(rank, world_size, args):
|
||||
logging.info(
|
||||
f"Initializing model with checkpoint from {params.model_init_ckpt}"
|
||||
)
|
||||
init_ckpt = torch.load(params.model_init_ckpt, map_location=device)
|
||||
init_ckpt = torch.load(params.model_init_ckpt, map_location=device, weights_only=False)
|
||||
model.load_state_dict(init_ckpt["model"], strict=True)
|
||||
|
||||
if world_size > 1:
|
||||
|
@ -252,7 +252,7 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
@ -960,7 +960,7 @@ def main():
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
torch.load(lg_filename, map_location=device, weights_only=False)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
|
@ -750,7 +750,7 @@ def _to_int_tuple(s: str):
|
||||
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||
if hasattr(params, "pretrained_dir"):
|
||||
logging.info(f"Loading {params.pretrained_dir}")
|
||||
pretrained = torch.load(params.pretrained_dir)
|
||||
pretrained = torch.load(params.pretrained_dir, weights_only=False)
|
||||
encoder = HubertModel(params)
|
||||
encoder.load_state_dict(pretrained["model"])
|
||||
else:
|
||||
|
@ -28,7 +28,7 @@ consisting of words and tokens (i.e., phones) and does the following:
|
||||
|
||||
4. Generate L.pt, in k2 format. It can be loaded by
|
||||
|
||||
d = torch.load("L.pt")
|
||||
d = torch.load("L.pt", weights_only=False)
|
||||
lexicon = k2.Fsa.from_dict(d)
|
||||
|
||||
5. Generate L_disambig.pt, in k2 format.
|
||||
|
@ -264,7 +264,7 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
@ -234,7 +234,7 @@ def main():
|
||||
logging.info("Creating model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
@ -234,7 +234,7 @@ def main():
|
||||
logging.info("Creating model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
@ -234,7 +234,7 @@ def main():
|
||||
logging.info("Creating model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
@ -962,7 +962,7 @@ def main():
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
torch.load(lg_filename, map_location=device, weights_only=False)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
|
@ -962,7 +962,7 @@ def main():
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
torch.load(lg_filename, map_location=device, weights_only=False)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
|
@ -451,7 +451,7 @@ def _to_int_tuple(s: str):
|
||||
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||
if hasattr(params, "pretrained_dir"):
|
||||
logging.info(f"Loading {params.pretrained_dir}")
|
||||
pretrained = torch.load(params.pretrained_dir)
|
||||
pretrained = torch.load(params.pretrained_dir, weights_only=False)
|
||||
encoder = HubertModel(params)
|
||||
encoder.load_state_dict(pretrained["model"])
|
||||
else:
|
||||
|
@ -451,7 +451,7 @@ def _to_int_tuple(s: str):
|
||||
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||
if hasattr(params, "pretrained_dir"):
|
||||
logging.info(f"Loading {params.pretrained_dir}")
|
||||
pretrained = torch.load(params.pretrained_dir)
|
||||
pretrained = torch.load(params.pretrained_dir, weights_only=False)
|
||||
encoder = HubertModel(params)
|
||||
encoder.load_state_dict(pretrained["model"])
|
||||
else:
|
||||
|
@ -12,7 +12,7 @@ args = parser.parse_args()
|
||||
src = args.src
|
||||
tgt = args.tgt
|
||||
|
||||
old_checkpoint = torch.load(src)
|
||||
old_checkpoint = torch.load(src, weights_only=False)
|
||||
new_checkpoint = OrderedDict()
|
||||
new_checkpoint["model"] = old_checkpoint["model"]
|
||||
torch.save(new_checkpoint, tgt)
|
||||
|
@ -28,7 +28,7 @@ consisting of words and tokens (i.e., phones) and does the following:
|
||||
|
||||
4. Generate L.pt, in k2 format. It can be loaded by
|
||||
|
||||
d = torch.load("L.pt")
|
||||
d = torch.load("L.pt", weights_only=False)
|
||||
lexicon = k2.Fsa.from_dict(d)
|
||||
|
||||
5. Generate L_disambig.pt, in k2 format.
|
||||
|
@ -960,7 +960,7 @@ def main():
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
torch.load(lg_filename, map_location=device, weights_only=False)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
|
@ -750,7 +750,7 @@ def _to_int_tuple(s: str):
|
||||
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||
if hasattr(params, "pretrained_dir"):
|
||||
logging.info(f"Loading {params.pretrained_dir}")
|
||||
pretrained = torch.load(params.pretrained_dir)
|
||||
pretrained = torch.load(params.pretrained_dir, weights_only=False)
|
||||
encoder = HubertModel(params)
|
||||
encoder.load_state_dict(pretrained["model"])
|
||||
else:
|
||||
|
@ -578,7 +578,7 @@ def main():
|
||||
H = None
|
||||
bpe_model = None
|
||||
HLG = k2.Fsa.from_dict(
|
||||
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
|
||||
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False)
|
||||
)
|
||||
assert HLG.requires_grad is False
|
||||
|
||||
|
@ -457,7 +457,7 @@ def main():
|
||||
|
||||
params.num_classes = num_classes
|
||||
|
||||
HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
|
||||
HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu", weights_only=False))
|
||||
HLG = HLG.to(device)
|
||||
assert HLG.requires_grad is False
|
||||
|
||||
|
@ -78,11 +78,11 @@ def compile_HLG(lm_dir: str, lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
|
||||
max_token_id = max(lexicon.tokens)
|
||||
logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
|
||||
H = k2.ctc_topo(max_token_id)
|
||||
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
|
||||
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt", weights_only=False))
|
||||
|
||||
if Path(f"{lm_dir}/{lm}.pt").is_file():
|
||||
logging.info(f"Loading pre-compiled {lm}")
|
||||
d = torch.load(f"{lm_dir}/{lm}.pt")
|
||||
d = torch.load(f"{lm_dir}/{lm}.pt", weights_only=False)
|
||||
G = k2.Fsa.from_dict(d)
|
||||
else:
|
||||
logging.info(f"Loading {lm}.fst.txt")
|
||||
|
@ -28,7 +28,7 @@ consisting of words and tokens (i.e., phones) and does the following:
|
||||
|
||||
4. Generate L.pt, in k2 format. It can be loaded by
|
||||
|
||||
d = torch.load("L.pt")
|
||||
d = torch.load("L.pt", weights_only=False)
|
||||
lexicon = k2.Fsa.from_dict(d)
|
||||
|
||||
5. Generate L_disambig.pt, in k2 format.
|
||||
|
@ -29,7 +29,7 @@ consisting of words and tokens (i.e., phones) and does the following:
|
||||
|
||||
4. Generate L.pt, in k2 format. It can be loaded by
|
||||
|
||||
d = torch.load("L.pt")
|
||||
d = torch.load("L.pt", weights_only=False)
|
||||
lexicon = k2.Fsa.from_dict(d)
|
||||
|
||||
5. Generate L_disambig.pt, in k2 format.
|
||||
|
@ -802,7 +802,7 @@ def main():
|
||||
H = None
|
||||
bpe_model = None
|
||||
HLG = k2.Fsa.from_dict(
|
||||
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
|
||||
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False)
|
||||
)
|
||||
assert HLG.requires_grad is False
|
||||
|
||||
@ -842,7 +842,7 @@ def main():
|
||||
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
|
||||
else:
|
||||
logging.info("Loading pre-compiled G_4_gram.pt")
|
||||
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
|
||||
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False)
|
||||
G = k2.Fsa.from_dict(d)
|
||||
|
||||
if params.decoding_method in [
|
||||
|
@ -1014,7 +1014,7 @@ def main():
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
torch.load(lg_filename, map_location=device, weights_only=False)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
|
@ -41,7 +41,7 @@ def get_padding(kernel_size, dilation=1):
|
||||
def load_checkpoint(filepath, device):
|
||||
assert os.path.isfile(filepath)
|
||||
print(f"Loading '{filepath}'")
|
||||
checkpoint_dict = torch.load(filepath, map_location=device)
|
||||
checkpoint_dict = torch.load(filepath, map_location=device, weights_only=False)
|
||||
print("Complete.")
|
||||
return checkpoint_dict
|
||||
|
||||
|
@ -103,7 +103,7 @@ def load_vocoder(checkpoint_path: Path) -> nn.Module:
|
||||
|
||||
hifigan = HiFiGAN(h).to("cpu")
|
||||
hifigan.load_state_dict(
|
||||
torch.load(checkpoint_path, map_location="cpu")["generator"]
|
||||
torch.load(checkpoint_path, map_location="cpu", weights_only=False)["generator"]
|
||||
)
|
||||
_ = hifigan.eval()
|
||||
hifigan.remove_weight_norm()
|
||||
|
@ -756,7 +756,7 @@ def main():
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
torch.load(lg_filename, map_location=device, weights_only=False)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
|
@ -575,7 +575,7 @@ def main():
|
||||
H = None
|
||||
bpe_model = None
|
||||
HLG = k2.Fsa.from_dict(
|
||||
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
|
||||
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False)
|
||||
)
|
||||
assert HLG.requires_grad is False
|
||||
|
||||
@ -614,7 +614,7 @@ def main():
|
||||
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
|
||||
else:
|
||||
logging.info("Loading pre-compiled G_4_gram.pt")
|
||||
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
|
||||
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False)
|
||||
G = k2.Fsa.from_dict(d)
|
||||
|
||||
if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
|
||||
|
@ -275,7 +275,7 @@ def main():
|
||||
use_feat_batchnorm=params.use_feat_batchnorm,
|
||||
)
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
@ -347,7 +347,7 @@ def main():
|
||||
"attention-decoder",
|
||||
]:
|
||||
logging.info(f"Loading HLG from {params.HLG}")
|
||||
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
|
||||
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu", weights_only=False))
|
||||
HLG = HLG.to(device)
|
||||
if not hasattr(HLG, "lm_scores"):
|
||||
# For whole-lattice-rescoring and attention-decoder
|
||||
@ -358,7 +358,7 @@ def main():
|
||||
"attention-decoder",
|
||||
]:
|
||||
logging.info(f"Loading G from {params.G}")
|
||||
G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
|
||||
G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu", weights_only=False))
|
||||
# Add epsilon self-loops to G as we will compose
|
||||
# it with the whole lattice later
|
||||
G = G.to(device)
|
||||
|
@ -236,7 +236,7 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
@ -733,7 +733,7 @@ def main():
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
torch.load(lg_filename, map_location=device, weights_only=False)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
|
@ -90,10 +90,10 @@ def average_checkpoints(
|
||||
"""
|
||||
n = len(filenames)
|
||||
|
||||
if "model" in torch.load(filenames[0], map_location=device):
|
||||
avg = torch.load(filenames[0], map_location=device)["model"]
|
||||
if "model" in torch.load(filenames[0], map_location=device, weights_only=False):
|
||||
avg = torch.load(filenames[0], map_location=device, weights_only=False)["model"]
|
||||
else:
|
||||
avg = torch.load(filenames[0], map_location=device)
|
||||
avg = torch.load(filenames[0], map_location=device, weights_only=False)
|
||||
|
||||
# Identify shared parameters. Two parameters are said to be shared
|
||||
# if they have the same data_ptr
|
||||
@ -108,10 +108,10 @@ def average_checkpoints(
|
||||
uniqued_names = list(uniqued.values())
|
||||
|
||||
for i in range(1, n):
|
||||
if "model" in torch.load(filenames[i], map_location=device):
|
||||
state_dict = torch.load(filenames[i], map_location=device)["model"]
|
||||
if "model" in torch.load(filenames[i], map_location=device, weights_only=False):
|
||||
state_dict = torch.load(filenames[i], map_location=device, weights_only=False)["model"]
|
||||
else:
|
||||
state_dict = torch.load(filenames[i], map_location=device)
|
||||
state_dict = torch.load(filenames[i], map_location=device, weights_only=False)
|
||||
for k in uniqued_names:
|
||||
avg[k] += state_dict[k]
|
||||
|
||||
@ -484,7 +484,7 @@ def main():
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
checkpoint = torch.load(
|
||||
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
|
||||
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu", weights_only=False
|
||||
)
|
||||
if "model" not in checkpoint:
|
||||
# deepspeed converted checkpoint only contains model state_dict
|
||||
@ -513,7 +513,7 @@ def main():
|
||||
torch.save(model.state_dict(), filename)
|
||||
else:
|
||||
checkpoint = torch.load(
|
||||
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
|
||||
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu", weights_only=False
|
||||
)
|
||||
if "model" not in checkpoint:
|
||||
model.load_state_dict(checkpoint, strict=True)
|
||||
|
@ -809,7 +809,7 @@ def run(rank, world_size, args):
|
||||
del model.alignment_heads
|
||||
|
||||
if params.pretrained_model_path:
|
||||
checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
|
||||
checkpoint = torch.load(params.pretrained_model_path, map_location="cpu", weights_only=False)
|
||||
if "model" not in checkpoint:
|
||||
model.load_state_dict(checkpoint, strict=True)
|
||||
else:
|
||||
|
@ -784,7 +784,7 @@ def main():
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
torch.load(lg_filename, map_location=device, weights_only=False)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
|
@ -24,7 +24,7 @@ Usage:
|
||||
--exp-dir ./zipformer/exp
|
||||
|
||||
It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`.
|
||||
You can later load it by `torch.load("epoch-28-avg-15.pt")`.
|
||||
You can later load it by `torch.load("epoch-28-avg-15.pt", weights_only=False)`.
|
||||
|
||||
(2) use the checkpoint exp_dir/checkpoint-iter.pt
|
||||
./zipformer/generate_averaged_model.py \
|
||||
@ -33,7 +33,7 @@ You can later load it by `torch.load("epoch-28-avg-15.pt")`.
|
||||
--exp-dir ./zipformer/exp
|
||||
|
||||
It will generate a file `iter-22000-avg-5.pt` in the given `exp_dir`.
|
||||
You can later load it by `torch.load("iter-22000-avg-5.pt")`.
|
||||
You can later load it by `torch.load("iter-22000-avg-5.pt", weights_only=False)`.
|
||||
"""
|
||||
|
||||
|
||||
|
@ -291,7 +291,7 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
@ -792,7 +792,7 @@ def main():
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
torch.load(lg_filename, map_location=device, weights_only=False)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
|
@ -24,7 +24,7 @@ Usage:
|
||||
--exp-dir ./zipformer/exp
|
||||
|
||||
It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`.
|
||||
You can later load it by `torch.load("epoch-28-avg-15.pt")`.
|
||||
You can later load it by `torch.load("epoch-28-avg-15.pt", weights_only=False)`.
|
||||
|
||||
(2) use the checkpoint exp_dir/checkpoint-iter.pt
|
||||
./zipformer/generate_averaged_model.py \
|
||||
@ -33,7 +33,7 @@ You can later load it by `torch.load("epoch-28-avg-15.pt")`.
|
||||
--exp-dir ./zipformer/exp
|
||||
|
||||
It will generate a file `iter-22000-avg-5.pt` in the given `exp_dir`.
|
||||
You can later load it by `torch.load("iter-22000-avg-5.pt")`.
|
||||
You can later load it by `torch.load("iter-22000-avg-5.pt", weights_only=False)`.
|
||||
"""
|
||||
|
||||
|
||||
|
@ -294,7 +294,7 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
@ -64,7 +64,7 @@ def main():
|
||||
if out_lm_data.is_file():
|
||||
logging.warning(f"{out_lm_data} exists - skipping")
|
||||
return
|
||||
data = torch.load(in_lm_data)
|
||||
data = torch.load(in_lm_data, weights_only=False)
|
||||
words2bpe = data["words"]
|
||||
sentences = data["sentences"]
|
||||
sentence_lengths = data["sentence_lengths"]
|
||||
|
@ -37,7 +37,7 @@ def main():
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(str(bpe_model))
|
||||
|
||||
data = torch.load(lm_training_data)
|
||||
data = torch.load(lm_training_data, weights_only=False)
|
||||
words2bpe = data["words"]
|
||||
sentences = data["sentences"]
|
||||
|
||||
|
@ -1008,7 +1008,7 @@ def main():
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
torch.load(lg_filename, map_location=device, weights_only=False)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
|
@ -95,10 +95,10 @@ def average_checkpoints(
|
||||
"""
|
||||
n = len(filenames)
|
||||
|
||||
if "model" in torch.load(filenames[0], map_location=device):
|
||||
avg = torch.load(filenames[0], map_location=device)["model"]
|
||||
if "model" in torch.load(filenames[0], map_location=device, weights_only=False):
|
||||
avg = torch.load(filenames[0], map_location=device, weights_only=False)["model"]
|
||||
else:
|
||||
avg = torch.load(filenames[0], map_location=device)
|
||||
avg = torch.load(filenames[0], map_location=device, weights_only=False)
|
||||
|
||||
# Identify shared parameters. Two parameters are said to be shared
|
||||
# if they have the same data_ptr
|
||||
@ -113,10 +113,10 @@ def average_checkpoints(
|
||||
uniqued_names = list(uniqued.values())
|
||||
|
||||
for i in range(1, n):
|
||||
if "model" in torch.load(filenames[i], map_location=device):
|
||||
state_dict = torch.load(filenames[i], map_location=device)["model"]
|
||||
if "model" in torch.load(filenames[i], map_location=device, weights_only=False):
|
||||
state_dict = torch.load(filenames[i], map_location=device, weights_only=False)["model"]
|
||||
else:
|
||||
state_dict = torch.load(filenames[i], map_location=device)
|
||||
state_dict = torch.load(filenames[i], map_location=device, weights_only=False)
|
||||
for k in uniqued_names:
|
||||
avg[k] += state_dict[k]
|
||||
|
||||
@ -548,7 +548,7 @@ def main():
|
||||
# torch.save(avg_checkpoint, filename)
|
||||
else:
|
||||
checkpoint = torch.load(
|
||||
f"{params.exp_dir}/epoch-{params.epoch}/pytorch_model.bin",
|
||||
f"{params.exp_dir}/epoch-{params.epoch}/pytorch_model.bin", weights_only=False,
|
||||
map_location="cpu",
|
||||
)
|
||||
model.load_state_dict(checkpoint, strict=False)
|
||||
|
@ -652,7 +652,7 @@ def run(rank, world_size, args):
|
||||
)
|
||||
|
||||
if params.pretrained_model_path:
|
||||
checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
|
||||
checkpoint = torch.load(params.pretrained_model_path, map_location="cpu", weights_only=False)
|
||||
missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
@ -704,7 +704,7 @@ def run(rank, world_size, args):
|
||||
|
||||
sampler_state_dict = None
|
||||
if params.sampler_state_dict_path:
|
||||
sampler_state_dict = torch.load(params.sampler_state_dict_path)
|
||||
sampler_state_dict = torch.load(params.sampler_state_dict_path, weights_only=False)
|
||||
sampler_state_dict["max_duration"] = params.max_duration
|
||||
|
||||
train_dl = data_module.train_dataloaders(
|
||||
|
@ -91,10 +91,10 @@ def average_checkpoints(
|
||||
"""
|
||||
n = len(filenames)
|
||||
|
||||
if "model" in torch.load(filenames[0], map_location=device):
|
||||
avg = torch.load(filenames[0], map_location=device)["model"]
|
||||
if "model" in torch.load(filenames[0], map_location=device, weights_only=False):
|
||||
avg = torch.load(filenames[0], map_location=device, weights_only=False)["model"]
|
||||
else:
|
||||
avg = torch.load(filenames[0], map_location=device)
|
||||
avg = torch.load(filenames[0], map_location=device, weights_only=False)
|
||||
|
||||
# Identify shared parameters. Two parameters are said to be shared
|
||||
# if they have the same data_ptr
|
||||
@ -109,10 +109,10 @@ def average_checkpoints(
|
||||
uniqued_names = list(uniqued.values())
|
||||
|
||||
for i in range(1, n):
|
||||
if "model" in torch.load(filenames[i], map_location=device):
|
||||
state_dict = torch.load(filenames[i], map_location=device)["model"]
|
||||
if "model" in torch.load(filenames[i], map_location=device, weights_only=False):
|
||||
state_dict = torch.load(filenames[i], map_location=device, weights_only=False)["model"]
|
||||
else:
|
||||
state_dict = torch.load(filenames[i], map_location=device)
|
||||
state_dict = torch.load(filenames[i], map_location=device, weights_only=False)
|
||||
for k in uniqued_names:
|
||||
avg[k] += state_dict[k]
|
||||
|
||||
@ -447,7 +447,7 @@ def main():
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
checkpoint = torch.load(
|
||||
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
|
||||
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu", weights_only=False
|
||||
)
|
||||
if "model" not in checkpoint:
|
||||
# deepspeed converted checkpoint only contains model state_dict
|
||||
@ -476,7 +476,7 @@ def main():
|
||||
torch.save(model.state_dict(), filename)
|
||||
else:
|
||||
checkpoint = torch.load(
|
||||
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
|
||||
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu", weights_only=False
|
||||
)
|
||||
if "model" not in checkpoint:
|
||||
model.load_state_dict(checkpoint, strict=True)
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user