Add cut_id to recognition results

This commit is contained in:
pkufool 2022-08-07 17:19:19 +08:00
parent dc6499a052
commit aa078cc6d7
2 changed files with 18 additions and 14 deletions

View File

@ -551,6 +551,7 @@ def decode_dataset(
results = defaultdict(list) results = defaultdict(list)
for batch_idx, batch in enumerate(dl): for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"] texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch( hyps_dict = decode_one_batch(
params=params, params=params,
@ -564,9 +565,9 @@ def decode_dataset(
for name, hyps in hyps_dict.items(): for name, hyps in hyps_dict.items():
this_batch = [] this_batch = []
assert len(hyps) == len(texts) assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts): for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split() ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words)) this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch) results[name].extend(this_batch)
@ -632,6 +633,8 @@ def main():
LibriSpeechAsrDataModule.add_arguments(parser) LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args() args = parser.parse_args()
args.exp_dir = Path(args.exp_dir) args.exp_dir = Path(args.exp_dir)
# we need cut ids to display recognition results.
args.return_cuts = True
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))

View File

@ -321,7 +321,7 @@ def load_alignments(filename: str) -> Tuple[int, Dict[str, List[int]]]:
def store_transcripts( def store_transcripts(
filename: Pathlike, texts: Iterable[Tuple[str, str]] filename: Pathlike, texts: Iterable[Tuple[str, str, str]]
) -> None: ) -> None:
"""Save predicted results and reference transcripts to a file. """Save predicted results and reference transcripts to a file.
@ -329,15 +329,15 @@ def store_transcripts(
filename: filename:
File to save the results to. File to save the results to.
texts: texts:
An iterable of tuples. The first element is the reference transcript An iterable of tuples. The first element is the cur_id, the second is
while the second element is the predicted result. the reference transcript and the third element is the predicted result.
Returns: Returns:
Return None. Return None.
""" """
with open(filename, "w") as f: with open(filename, "w") as f:
for ref, hyp in texts: for cut_id, ref, hyp in texts:
print(f"ref={ref}", file=f) print(f"{cut_id}:\tref={ref}", file=f)
print(f"hyp={hyp}", file=f) print(f"{cut_id}:\thyp={hyp}", file=f)
def write_error_stats( def write_error_stats(
@ -372,8 +372,8 @@ def write_error_stats(
The reference word `SIR` is missing in the predicted The reference word `SIR` is missing in the predicted
results (a deletion error). results (a deletion error).
results: results:
An iterable of tuples. The first element is the reference transcript An iterable of tuples. The first element is the cur_id, the second is
while the second element is the predicted result. the reference transcript and the third element is the predicted result.
enable_log: enable_log:
If True, also print detailed WER to the console. If True, also print detailed WER to the console.
Otherwise, it is written only to the given file. Otherwise, it is written only to the given file.
@ -389,7 +389,7 @@ def write_error_stats(
words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0]) words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
num_corr = 0 num_corr = 0
ERR = "*" ERR = "*"
for ref, hyp in results: for cut_id, ref, hyp in results:
ali = kaldialign.align(ref, hyp, ERR) ali = kaldialign.align(ref, hyp, ERR)
for ref_word, hyp_word in ali: for ref_word, hyp_word in ali:
if ref_word == ERR: if ref_word == ERR:
@ -405,7 +405,7 @@ def write_error_stats(
else: else:
words[ref_word][0] += 1 words[ref_word][0] += 1
num_corr += 1 num_corr += 1
ref_len = sum([len(r) for r, _ in results]) ref_len = sum([len(r) for _, r, _ in results])
sub_errs = sum(subs.values()) sub_errs = sum(subs.values())
ins_errs = sum(ins.values()) ins_errs = sum(ins.values())
del_errs = sum(dels.values()) del_errs = sum(dels.values())
@ -434,7 +434,7 @@ def write_error_stats(
print("", file=f) print("", file=f)
print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f) print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
for ref, hyp in results: for cut_id, ref, hyp in results:
ali = kaldialign.align(ref, hyp, ERR) ali = kaldialign.align(ref, hyp, ERR)
combine_successive_errors = True combine_successive_errors = True
if combine_successive_errors: if combine_successive_errors:
@ -461,7 +461,8 @@ def write_error_stats(
] ]
print( print(
" ".join( f"{cut_id}:\t"
+ " ".join(
( (
ref_word ref_word
if ref_word == hyp_word if ref_word == hyp_word