Mirror of https://github.com/k2-fsa/icefall.git
Fix computing WERs for empty hypotheses (#118)
* Fix computing WERs when empty lattices are generated.
* Minor fixes.
parent 336283f872
commit 0660d12e4e
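For background on why the fix matters for scoring: when an utterance decodes to nothing, the natural fallback is to score an empty hypothesis against its reference, which counts every reference word as a deletion instead of dropping the utterance or crashing the WER computation. Below is a minimal, self-contained sketch of that behaviour (a plain word-level Levenshtein distance written only for illustration; it is not the scorer icefall uses):

# Illustration only: an empty hypothesis scores as pure deletions.
def word_errors(ref_words, hyp_words):
    """Word-level Levenshtein distance (illustrative, not icefall's scorer)."""
    m, n = len(ref_words), len(hyp_words)
    # dp[i][j] = edit distance between ref_words[:i] and hyp_words[:j]
    dp = [[0] * (n + 1) for _ in range(m + 1)]
    for i in range(m + 1):
        dp[i][0] = i  # all deletions
    for j in range(n + 1):
        dp[0][j] = j  # all insertions
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            cost = 0 if ref_words[i - 1] == hyp_words[j - 1] else 1
            dp[i][j] = min(
                dp[i - 1][j] + 1,        # deletion
                dp[i][j - 1] + 1,        # insertion
                dp[i - 1][j - 1] + cost,  # substitution or match
            )
    return dp[m][n]

ref = "the quick brown fox".split()
print(word_errors(ref, []))  # 4 -> the empty hypothesis counts as 4 deletions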
@@ -239,7 +239,7 @@ def decode_one_batch(
         is a 3-gram LM, while this G is a 4-gram LM.
     Returns:
       Return the decoding result. See above description for the format of
-      the returned dict.
+      the returned dict. Note: If it decodes to nothing, then return None.
     """
     if HLG is not None:
         device = HLG.device
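The docstring change above records the new contract: decode_one_batch may now return None when nothing could be decoded. A hedged sketch of a caller that honours that contract follows; run_one_batch, decode_fn, and the argument names are hypothetical stand-ins for the real call site in decode.py:

from typing import Dict, List, Optional

def run_one_batch(
    decode_fn,               # stands in for decode_one_batch; may return None
    texts: List[str],        # reference transcripts for this batch
    lm_scales: List[float],  # LM scales being tracked
) -> Dict[str, List[List[str]]]:
    """Hypothetical wrapper: fall back to empty hypotheses when decoding fails."""
    hyps_dict: Optional[Dict[str, List[List[str]]]] = decode_fn()
    if hyps_dict is None:
        # One empty word list per utterance, for every LM scale we track.
        hyps_dict = {f"{s}": [[] for _ in texts] for s in lm_scales}
    return hyps_dict

# Usage: simulate a batch where decoding produced nothing.
print(run_one_batch(lambda: None, ["hello world", "foo bar baz"], [0.6, 0.7]))
# -> {'0.6': [[], []], '0.7': [[], []]}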
@@ -392,8 +392,7 @@ def decode_one_batch(
             hyps = [[word_table[i] for i in ids] for ids in hyps]
             ans[lm_scale_str] = hyps
     else:
-        for lm_scale in lm_scale_list:
-            ans[f"{lm_scale}"] = [[] * lattice.shape[0]]
+        ans = None
     return ans
 
 
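For reference, the branch removed above also had a subtle bug: [[] * lattice.shape[0]] multiplies the contents of the inner empty list rather than repeating the outer list, so it always yields a single empty hypothesis no matter how many utterances are in the batch. A quick illustration in plain Python (n stands in for lattice.shape[0]):

n = 3  # pretend the batch (lattice.shape[0]) holds three utterances

buggy = [[] * n]                  # [] * 3 == [], so this is just [[]]
correct = [[] for _ in range(n)]  # one empty hypothesis per utterance

print(buggy)    # [[]]         -> length 1, would trip `assert len(hyps) == len(texts)`
print(correct)  # [[], [], []] -> length 3

Returning None and handling the empty case in decode_dataset sidesteps this entirely.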
@@ -467,16 +466,29 @@ def decode_dataset(
             eos_id=eos_id,
         )
 
-        for lm_scale, hyps in hyps_dict.items():
-            this_batch = []
-            assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
-                ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
-
-            results[lm_scale].extend(this_batch)
+        if hyps_dict is not None:
+            for lm_scale, hyps in hyps_dict.items():
+                this_batch = []
+                assert len(hyps) == len(texts)
+                for hyp_words, ref_text in zip(hyps, texts):
+                    ref_words = ref_text.split()
+                    this_batch.append((ref_words, hyp_words))
+
+                results[lm_scale].extend(this_batch)
+        else:
+            assert (
+                len(results) > 0
+            ), "It should not decode to empty in the first batch!"
+            this_batch = []
+            hyp_words = []
+            for ref_text in texts:
+                ref_words = ref_text.split()
+                this_batch.append((ref_words, hyp_words))
+
+            for lm_scale in results.keys():
+                results[lm_scale].extend(this_batch)
 
-        num_cuts += len(batch["supervisions"]["text"])
+        num_cuts += len(texts)
 
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
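With this change, a batch that decodes to nothing still contributes reference-length deletions under every LM scale already being tracked, and the assertion covers the corner case where the very first batch fails (results would have no LM-scale keys to extend). A self-contained sketch of that bookkeeping, assuming results is a defaultdict(list) keyed by LM scale as in the surrounding script; add_batch itself is illustrative:

from collections import defaultdict
from typing import Dict, List, Tuple

results: Dict[str, List[Tuple[List[str], List[str]]]] = defaultdict(list)

def add_batch(hyps_dict, texts):
    """Append (ref_words, hyp_words) pairs, falling back to empty hypotheses."""
    if hyps_dict is not None:
        for lm_scale, hyps in hyps_dict.items():
            results[lm_scale].extend(
                (ref.split(), hyp) for ref, hyp in zip(texts, hyps)
            )
    else:
        # Only valid once at least one batch has established the LM-scale keys.
        assert len(results) > 0, "It should not decode to empty in the first batch!"
        this_batch = [(ref.split(), []) for ref in texts]
        for lm_scale in results.keys():
            results[lm_scale].extend(this_batch)

add_batch({"0.6": [["hello", "world"]]}, ["hello world"])  # normal batch
add_batch(None, ["good morning"])                          # batch that decoded to nothing
print(dict(results))
# {'0.6': [(['hello', 'world'], ['hello', 'world']), (['good', 'morning'], [])]}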