Merge remote-tracking branch 'dan/master' into nbest-oracle

commit f841581fff
Author: Fangjun Kuang
Date:   2021-08-19 16:26:23 +08:00

2 changed files with 10 additions and 10 deletions


@@ -346,8 +346,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             logging.info(
-                f"batch {batch_idx}, cuts processed until now is "
-                f"{num_cuts} "
+                f"batch {batch_idx}, cuts processed until now is {num_cuts}"
             )

     return results
@@ -430,7 +429,9 @@ def main():
     sos_id = graph_compiler.sos_id
     eos_id = graph_compiler.eos_id

-    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt"))
+    HLG = k2.Fsa.from_dict(
+        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
+    )
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
@@ -461,7 +462,7 @@ def main():
         torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
     else:
         logging.info("Loading pre-compiled G_4_gram.pt")
-        d = torch.load(params.lm_dir / "G_4_gram.pt")
+        d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
     G = k2.Fsa.from_dict(d).to(device)
     if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
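
Both hunks in this file make the same change: the serialized FSA or LM is loaded onto the CPU first and only then moved to the selected device. Below is a minimal sketch of the pattern, assuming a hypothetical path "path/to/HLG.pt" in place of the recipe's params.lang_dir; without map_location="cpu", torch.load tries to restore tensors onto the device they were saved from (for example cuda:1), which raises an error on a machine that lacks that GPU.

import torch
import k2

# Fall back to the CPU so the sketch also runs on GPU-less machines.
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")

# Load onto the CPU regardless of the device the tensors were saved on.
d = torch.load("path/to/HLG.pt", map_location="cpu")  # hypothetical path

# Rebuild the FSA on the CPU, then move it to the target device explicitly.
HLG = k2.Fsa.from_dict(d)
HLG = HLG.to(device)

The same load-then-move order is applied to G_4_gram.pt above, so every checkpoint enters the process on the CPU before a single, explicit transfer.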


@@ -236,7 +236,6 @@ def decode_dataset(
     results = []

     num_cuts = 0
-    tot_num_cuts = len(dl.dataset.cuts)

     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
@@ -264,9 +263,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             logging.info(
-                f"batch {batch_idx}, cuts processed until now is "
-                f"{num_cuts}/{tot_num_cuts} "
-                f"({float(num_cuts)/tot_num_cuts*100:.6f}%)"
+                f"batch {batch_idx}, cuts processed until now is {num_cuts}"
             )

     return results
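
This hunk pairs with the tot_num_cuts deletion above: once len(dl.dataset.cuts) is gone, the log can no longer report a total or a percentage, only the running count. One plausible motivation (not stated in the commit) is that taking len() of a lazily-loaded lhotse CutSet can be expensive or unsupported. A minimal sketch of the resulting logging, with log_decode_progress as a hypothetical helper name:

import logging

def log_decode_progress(batch_idx: int, num_cuts: int) -> None:
    # Emit the same message as the new diff lines: running count only,
    # no total and no percentage, every 100 batches.
    if batch_idx % 100 == 0:
        logging.info(f"batch {batch_idx}, cuts processed until now is {num_cuts}")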
@@ -328,7 +325,9 @@ def main():
     logging.info(f"device: {device}")

-    HLG = k2.Fsa.from_dict(torch.load("data/lang_phone/HLG.pt"))
+    HLG = k2.Fsa.from_dict(
+        torch.load("data/lang_phone/HLG.pt", map_location="cpu")
+    )
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
@@ -355,7 +354,7 @@ def main():
         torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
     else:
         logging.info("Loading pre-compiled G_4_gram.pt")
-        d = torch.load(params.lm_dir / "G_4_gram.pt")
+        d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
     G = k2.Fsa.from_dict(d).to(device)
     if params.method == "whole-lattice-rescoring":