Fix computing WERs for empty hypotheses (#118)

* Fix computing WERs when empty lattices are generated.

* Minor fixes.
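
Background for the fix: when decoding yields an empty lattice there are no hypotheses to score, and the usual scoring convention is that an empty hypothesis turns every reference word into a deletion. A minimal pure-Python sketch of that convention (the helper word_errors is illustrative only, not code from this repo):

    # Word-level edit distance; with an empty hypothesis every
    # reference word becomes a deletion, i.e. 100% WER for that cut.
    def word_errors(ref_words, hyp_words):
        m, n = len(ref_words), len(hyp_words)
        dp = [[0] * (n + 1) for _ in range(m + 1)]
        for i in range(m + 1):
            dp[i][0] = i  # delete all i reference words
        for j in range(n + 1):
            dp[0][j] = j  # insert all j hypothesis words
        for i in range(1, m + 1):
            for j in range(1, n + 1):
                cost = 0 if ref_words[i - 1] == hyp_words[j - 1] else 1
                dp[i][j] = min(
                    dp[i - 1][j] + 1,  # deletion
                    dp[i][j - 1] + 1,  # insertion
                    dp[i - 1][j - 1] + cost,  # substitution or match
                )
        return dp[m][n]

    assert word_errors("the cat sat".split(), []) == 3  # all deletions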
Fangjun Kuang 2021-11-17 19:25:47 +08:00 committed by GitHub
parent 336283f872
commit 0660d12e4e

@@ -239,7 +239,7 @@ def decode_one_batch(
       is a 3-gram LM, while this G is a 4-gram LM.
     Returns:
       Return the decoding result. See above description for the format of
-      the returned dict.
+      the returned dict. Note: If it decodes to nothing, then return None.
     """
     if HLG is not None:
         device = HLG.device
@@ -392,8 +392,7 @@ def decode_one_batch(
             hyps = [[word_table[i] for i in ids] for ids in hyps]
             ans[lm_scale_str] = hyps
     else:
-        for lm_scale in lm_scale_list:
-            ans[f"{lm_scale}"] = [[] * lattice.shape[0]]
+        ans = None
     return ans
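
The deleted fallback was itself broken: list repetition binds as ([] * n), which is always the empty list, so the expression produced a single-element list [[]] regardless of batch size and would trip the len(hyps) == len(texts) check downstream. A quick plain-Python illustration:

    n = 3  # stand-in for lattice.shape[0]
    assert [[] * n] == [[]]  # old code: length 1, not n
    assert [[] for _ in range(n)] == [[], [], []]  # what was intended

Returning None instead moves the empty-lattice handling to the caller, which knows the batch size from the reference texts.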
@@ -467,16 +466,29 @@ def decode_dataset(
             eos_id=eos_id,
         )
 
-        for lm_scale, hyps in hyps_dict.items():
-            this_batch = []
-            assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
-                ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
-
-            results[lm_scale].extend(this_batch)
+        if hyps_dict is not None:
+            for lm_scale, hyps in hyps_dict.items():
+                this_batch = []
+                assert len(hyps) == len(texts)
+                for hyp_words, ref_text in zip(hyps, texts):
+                    ref_words = ref_text.split()
+                    this_batch.append((ref_words, hyp_words))
+
+                results[lm_scale].extend(this_batch)
+        else:
+            assert (
+                len(results) > 0
+            ), "It should not decode to empty in the first batch!"
+            this_batch = []
+            hyp_words = []
+            for ref_text in texts:
+                ref_words = ref_text.split()
+                this_batch.append((ref_words, hyp_words))
+
+            for lm_scale in results.keys():
+                results[lm_scale].extend(this_batch)
 
-        num_cuts += len(batch["supervisions"]["text"])
+        num_cuts += len(texts)
 
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
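
The else branch pads the whole failed batch with empty hypotheses so its cuts still count toward the WER (all deletions), reusing the lm_scale keys accumulated from earlier batches; that is why an empty result on the very first batch is asserted away. A small sketch of the padded pairs (the texts values below are made up):

    texts = ["hello world", "good morning"]  # hypothetical references
    this_batch = [(ref_text.split(), []) for ref_text in texts]
    assert this_batch == [(["hello", "world"], []),
                          (["good", "morning"], [])]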