mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
Fix an error in displaying decoding process.
This commit is contained in:
parent
1c3b13c7eb
commit
a87a39da8c
@ -284,7 +284,6 @@ def decode_dataset(
|
||||
results = []
|
||||
|
||||
num_cuts = 0
|
||||
tot_num_cuts = len(dl.dataset.cuts)
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
@ -314,9 +313,7 @@ def decode_dataset(
|
||||
|
||||
if batch_idx % 100 == 0:
|
||||
logging.info(
|
||||
f"batch {batch_idx}, cuts processed until now is "
|
||||
f"{num_cuts}/{tot_num_cuts} "
|
||||
f"({float(num_cuts)/tot_num_cuts*100:.6f}%)"
|
||||
f"batch {batch_idx}, cuts processed until now is {num_cuts}"
|
||||
)
|
||||
return results
|
||||
|
||||
@ -399,7 +396,9 @@ def main():
|
||||
sos_id = graph_compiler.sos_id
|
||||
eos_id = graph_compiler.eos_id
|
||||
|
||||
HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt"))
|
||||
HLG = k2.Fsa.from_dict(
|
||||
torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
|
||||
)
|
||||
HLG = HLG.to(device)
|
||||
assert HLG.requires_grad is False
|
||||
|
||||
@ -430,7 +429,7 @@ def main():
|
||||
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
|
||||
else:
|
||||
logging.info("Loading pre-compiled G_4_gram.pt")
|
||||
d = torch.load(params.lm_dir / "G_4_gram.pt")
|
||||
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
|
||||
G = k2.Fsa.from_dict(d).to(device)
|
||||
|
||||
if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
|
||||
|
@ -236,7 +236,6 @@ def decode_dataset(
|
||||
results = []
|
||||
|
||||
num_cuts = 0
|
||||
tot_num_cuts = len(dl.dataset.cuts)
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
@ -264,9 +263,7 @@ def decode_dataset(
|
||||
|
||||
if batch_idx % 100 == 0:
|
||||
logging.info(
|
||||
f"batch {batch_idx}, cuts processed until now is "
|
||||
f"{num_cuts}/{tot_num_cuts} "
|
||||
f"({float(num_cuts)/tot_num_cuts*100:.6f}%)"
|
||||
f"batch {batch_idx}, cuts processed until now is {num_cuts}"
|
||||
)
|
||||
return results
|
||||
|
||||
@ -328,7 +325,9 @@ def main():
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
HLG = k2.Fsa.from_dict(torch.load("data/lang_phone/HLG.pt"))
|
||||
HLG = k2.Fsa.from_dict(
|
||||
torch.load("data/lang_phone/HLG.pt", map_location="cpu")
|
||||
)
|
||||
HLG = HLG.to(device)
|
||||
assert HLG.requires_grad is False
|
||||
|
||||
@ -355,7 +354,7 @@ def main():
|
||||
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
|
||||
else:
|
||||
logging.info("Loading pre-compiled G_4_gram.pt")
|
||||
d = torch.load(params.lm_dir / "G_4_gram.pt")
|
||||
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
|
||||
G = k2.Fsa.from_dict(d).to(device)
|
||||
|
||||
if params.method == "whole-lattice-rescoring":
|
||||
|
Loading…
x
Reference in New Issue
Block a user