Add cut_id to recognition results

This commit is contained in:
pkufool 2022-08-07 17:19:19 +08:00
parent dc6499a052
commit aa078cc6d7
2 changed files with 18 additions and 14 deletions

View File

@ -551,6 +551,7 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
@ -564,9 +565,9 @@ def decode_dataset(
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
@ -632,6 +633,8 @@ def main():
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
# we need cut ids to display recognition results.
args.return_cuts = True
params = get_params()
params.update(vars(args))

View File

@ -321,7 +321,7 @@ def load_alignments(filename: str) -> Tuple[int, Dict[str, List[int]]]:
def store_transcripts(
filename: Pathlike, texts: Iterable[Tuple[str, str]]
filename: Pathlike, texts: Iterable[Tuple[str, str, str]]
) -> None:
"""Save predicted results and reference transcripts to a file.
@ -329,15 +329,15 @@ def store_transcripts(
filename:
File to save the results to.
texts:
An iterable of tuples. The first element is the reference transcript
while the second element is the predicted result.
An iterable of tuples. The first element is the cur_id, the second is
the reference transcript and the third element is the predicted result.
Returns:
Return None.
"""
with open(filename, "w") as f:
for ref, hyp in texts:
print(f"ref={ref}", file=f)
print(f"hyp={hyp}", file=f)
for cut_id, ref, hyp in texts:
print(f"{cut_id}:\tref={ref}", file=f)
print(f"{cut_id}:\thyp={hyp}", file=f)
def write_error_stats(
@ -372,8 +372,8 @@ def write_error_stats(
The reference word `SIR` is missing in the predicted
results (a deletion error).
results:
An iterable of tuples. The first element is the reference transcript
while the second element is the predicted result.
An iterable of tuples. The first element is the cur_id, the second is
the reference transcript and the third element is the predicted result.
enable_log:
If True, also print detailed WER to the console.
Otherwise, it is written only to the given file.
@ -389,7 +389,7 @@ def write_error_stats(
words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
num_corr = 0
ERR = "*"
for ref, hyp in results:
for cut_id, ref, hyp in results:
ali = kaldialign.align(ref, hyp, ERR)
for ref_word, hyp_word in ali:
if ref_word == ERR:
@ -405,7 +405,7 @@ def write_error_stats(
else:
words[ref_word][0] += 1
num_corr += 1
ref_len = sum([len(r) for r, _ in results])
ref_len = sum([len(r) for _, r, _ in results])
sub_errs = sum(subs.values())
ins_errs = sum(ins.values())
del_errs = sum(dels.values())
@ -434,7 +434,7 @@ def write_error_stats(
print("", file=f)
print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
for ref, hyp in results:
for cut_id, ref, hyp in results:
ali = kaldialign.align(ref, hyp, ERR)
combine_successive_errors = True
if combine_successive_errors:
@ -461,7 +461,8 @@ def write_error_stats(
]
print(
" ".join(
f"{cut_id}:\t"
+ " ".join(
(
ref_word
if ref_word == hyp_word