Fix computing WERs for empty hypotheses (#118)

* Fix computing WERs when empty lattices are generated.

* Minor fixes.
This commit is contained in:
Fangjun Kuang 2021-11-17 19:25:47 +08:00 committed by GitHub
parent 336283f872
commit 0660d12e4e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -239,7 +239,7 @@ def decode_one_batch(
         is a 3-gram LM, while this G is a 4-gram LM.
     Returns:
       Return the decoding result. See above description for the format of
-      the returned dict.
+      the returned dict. Note: If it decodes to nothing, then return None.
     """
     if HLG is not None:
         device = HLG.device
@@ -392,8 +392,7 @@ def decode_one_batch(
         hyps = [[word_table[i] for i in ids] for ids in hyps]
         ans[lm_scale_str] = hyps
     else:
-        for lm_scale in lm_scale_list:
-            ans[f"{lm_scale}"] = [[] * lattice.shape[0]]
+        ans = None
     return ans
@@ -467,16 +466,29 @@ def decode_dataset(
             eos_id=eos_id,
         )
-        for lm_scale, hyps in hyps_dict.items():
-            this_batch = []
-            assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
-                ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
-            results[lm_scale].extend(this_batch)
+        if hyps_dict is not None:
+            for lm_scale, hyps in hyps_dict.items():
+                this_batch = []
+                assert len(hyps) == len(texts)
+                for hyp_words, ref_text in zip(hyps, texts):
+                    ref_words = ref_text.split()
+                    this_batch.append((ref_words, hyp_words))
+                results[lm_scale].extend(this_batch)
+        else:
+            assert (
+                len(results) > 0
+            ), "It should not decode to empty in the first batch!"
+            this_batch = []
+            hyp_words = []
+            for ref_text in texts:
+                ref_words = ref_text.split()
+                this_batch.append((ref_words, hyp_words))
+            for lm_scale in results.keys():
+                results[lm_scale].extend(this_batch)

-        num_cuts += len(batch["supervisions"]["text"])
+        num_cuts += len(texts)
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"