Fix an error in displaying decoding process. (#12)
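
The per-batch progress message derived a percentage from dl.dataset.cuts, which fails during decoding, so it now reports only the running count of processed cuts. In addition, HLG.pt and G_4_gram.pt are loaded with map_location="cpu" and moved to the target device afterwards, so the checkpoints can also be loaded on machines without a GPU.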

Fangjun Kuang 2021-08-19 14:54:01 +08:00 committed by GitHub
parent 1c3b13c7eb
commit caa0b9e942
2 changed files with 10 additions and 12 deletions


@@ -284,7 +284,6 @@ def decode_dataset(
     results = []
 
     num_cuts = 0
-    tot_num_cuts = len(dl.dataset.cuts)
 
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
@@ -314,9 +313,7 @@ def decode_dataset(
 
         if batch_idx % 100 == 0:
             logging.info(
-                f"batch {batch_idx}, cuts processed until now is "
-                f"{num_cuts}/{tot_num_cuts} "
-                f"({float(num_cuts)/tot_num_cuts*100:.6f}%)"
+                f"batch {batch_idx}, cuts processed until now is {num_cuts}"
             )
 
     return results
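
Why the old message failed (a minimal, hypothetical repro; it assumes the dataset behind the DataLoader exposes no cuts attribute, and DecodingDataset is an illustrative stand-in, not a class from this repo):

from torch.utils.data import DataLoader, Dataset

# Hypothetical stand-in for the decoding dataset: it has no `cuts`
# attribute, so the removed expression len(dl.dataset.cuts) raises
# AttributeError the first time the progress message is formatted.
class DecodingDataset(Dataset):
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return idx

dl = DataLoader(DecodingDataset())
try:
    tot_num_cuts = len(dl.dataset.cuts)  # what the removed line did
except AttributeError as e:
    print(f"progress display failed: {e}")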
@@ -399,7 +396,9 @@ def main():
     sos_id = graph_compiler.sos_id
     eos_id = graph_compiler.eos_id
 
-    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt"))
+    HLG = k2.Fsa.from_dict(
+        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
+    )
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
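
The map_location change addresses a separate issue: torch.load restores tensors onto the device recorded in the checkpoint, so a file saved on a GPU machine fails to load on a CPU-only host ("RuntimeError: Attempting to deserialize object on a CUDA device ..."). A self-contained sketch of the load-on-CPU-then-move pattern (an in-memory buffer stands in for HLG.pt):

import io

import torch

# Save a tensor, then load it with every storage redirected to CPU; the
# explicit .to(device) afterwards mirrors HLG = HLG.to(device) above.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
buf = io.BytesIO()
torch.save(torch.ones(3), buf)
buf.seek(0)
t = torch.load(buf, map_location="cpu").to(device)
print(t.device)  # cpu unless a CUDA device is available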
@ -430,7 +429,7 @@ def main():
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
else:
logging.info("Loading pre-compiled G_4_gram.pt")
d = torch.load(params.lm_dir / "G_4_gram.pt")
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
G = k2.Fsa.from_dict(d).to(device)
if params.method in ["whole-lattice-rescoring", "attention-decoder"]:


@@ -236,7 +236,6 @@ def decode_dataset(
     results = []
 
     num_cuts = 0
-    tot_num_cuts = len(dl.dataset.cuts)
 
     results = defaultdict(list)
    for batch_idx, batch in enumerate(dl):
@@ -264,9 +263,7 @@ def decode_dataset(
 
         if batch_idx % 100 == 0:
             logging.info(
-                f"batch {batch_idx}, cuts processed until now is "
-                f"{num_cuts}/{tot_num_cuts} "
-                f"({float(num_cuts)/tot_num_cuts*100:.6f}%)"
+                f"batch {batch_idx}, cuts processed until now is {num_cuts}"
             )
 
     return results
@@ -328,7 +325,9 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(torch.load("data/lang_phone/HLG.pt"))
+    HLG = k2.Fsa.from_dict(
+        torch.load("data/lang_phone/HLG.pt", map_location="cpu")
+    )
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@ -355,7 +354,7 @@ def main():
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
else:
logging.info("Loading pre-compiled G_4_gram.pt")
d = torch.load(params.lm_dir / "G_4_gram.pt")
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
G = k2.Fsa.from_dict(d).to(device)
if params.method == "whole-lattice-rescoring":