mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Fix an error in displaying decoding process. (#12)
This commit is contained in:
parent
1c3b13c7eb
commit
caa0b9e942
@ -284,7 +284,6 @@ def decode_dataset(
|
|||||||
results = []
|
results = []
|
||||||
|
|
||||||
num_cuts = 0
|
num_cuts = 0
|
||||||
tot_num_cuts = len(dl.dataset.cuts)
|
|
||||||
|
|
||||||
results = defaultdict(list)
|
results = defaultdict(list)
|
||||||
for batch_idx, batch in enumerate(dl):
|
for batch_idx, batch in enumerate(dl):
|
||||||
@ -314,9 +313,7 @@ def decode_dataset(
|
|||||||
|
|
||||||
if batch_idx % 100 == 0:
|
if batch_idx % 100 == 0:
|
||||||
logging.info(
|
logging.info(
|
||||||
f"batch {batch_idx}, cuts processed until now is "
|
f"batch {batch_idx}, cuts processed until now is {num_cuts}"
|
||||||
f"{num_cuts}/{tot_num_cuts} "
|
|
||||||
f"({float(num_cuts)/tot_num_cuts*100:.6f}%)"
|
|
||||||
)
|
)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@ -399,7 +396,9 @@ def main():
|
|||||||
sos_id = graph_compiler.sos_id
|
sos_id = graph_compiler.sos_id
|
||||||
eos_id = graph_compiler.eos_id
|
eos_id = graph_compiler.eos_id
|
||||||
|
|
||||||
HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt"))
|
HLG = k2.Fsa.from_dict(
|
||||||
|
torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
|
||||||
|
)
|
||||||
HLG = HLG.to(device)
|
HLG = HLG.to(device)
|
||||||
assert HLG.requires_grad is False
|
assert HLG.requires_grad is False
|
||||||
|
|
||||||
@ -430,7 +429,7 @@ def main():
|
|||||||
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
|
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
|
||||||
else:
|
else:
|
||||||
logging.info("Loading pre-compiled G_4_gram.pt")
|
logging.info("Loading pre-compiled G_4_gram.pt")
|
||||||
d = torch.load(params.lm_dir / "G_4_gram.pt")
|
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
|
||||||
G = k2.Fsa.from_dict(d).to(device)
|
G = k2.Fsa.from_dict(d).to(device)
|
||||||
|
|
||||||
if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
|
if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
|
||||||
|
@ -236,7 +236,6 @@ def decode_dataset(
|
|||||||
results = []
|
results = []
|
||||||
|
|
||||||
num_cuts = 0
|
num_cuts = 0
|
||||||
tot_num_cuts = len(dl.dataset.cuts)
|
|
||||||
|
|
||||||
results = defaultdict(list)
|
results = defaultdict(list)
|
||||||
for batch_idx, batch in enumerate(dl):
|
for batch_idx, batch in enumerate(dl):
|
||||||
@ -264,9 +263,7 @@ def decode_dataset(
|
|||||||
|
|
||||||
if batch_idx % 100 == 0:
|
if batch_idx % 100 == 0:
|
||||||
logging.info(
|
logging.info(
|
||||||
f"batch {batch_idx}, cuts processed until now is "
|
f"batch {batch_idx}, cuts processed until now is {num_cuts}"
|
||||||
f"{num_cuts}/{tot_num_cuts} "
|
|
||||||
f"({float(num_cuts)/tot_num_cuts*100:.6f}%)"
|
|
||||||
)
|
)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@ -328,7 +325,9 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
HLG = k2.Fsa.from_dict(torch.load("data/lang_phone/HLG.pt"))
|
HLG = k2.Fsa.from_dict(
|
||||||
|
torch.load("data/lang_phone/HLG.pt", map_location="cpu")
|
||||||
|
)
|
||||||
HLG = HLG.to(device)
|
HLG = HLG.to(device)
|
||||||
assert HLG.requires_grad is False
|
assert HLG.requires_grad is False
|
||||||
|
|
||||||
@ -355,7 +354,7 @@ def main():
|
|||||||
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
|
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
|
||||||
else:
|
else:
|
||||||
logging.info("Loading pre-compiled G_4_gram.pt")
|
logging.info("Loading pre-compiled G_4_gram.pt")
|
||||||
d = torch.load(params.lm_dir / "G_4_gram.pt")
|
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
|
||||||
G = k2.Fsa.from_dict(d).to(device)
|
G = k2.Fsa.from_dict(d).to(device)
|
||||||
|
|
||||||
if params.method == "whole-lattice-rescoring":
|
if params.method == "whole-lattice-rescoring":
|
||||||
|
Loading…
x
Reference in New Issue
Block a user