mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
Add cut_id to recognition results
This commit is contained in:
parent
dc6499a052
commit
aa078cc6d7
@ -551,6 +551,7 @@ def decode_dataset(
|
|||||||
results = defaultdict(list)
|
results = defaultdict(list)
|
||||||
for batch_idx, batch in enumerate(dl):
|
for batch_idx, batch in enumerate(dl):
|
||||||
texts = batch["supervisions"]["text"]
|
texts = batch["supervisions"]["text"]
|
||||||
|
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||||
|
|
||||||
hyps_dict = decode_one_batch(
|
hyps_dict = decode_one_batch(
|
||||||
params=params,
|
params=params,
|
||||||
@ -564,9 +565,9 @@ def decode_dataset(
|
|||||||
for name, hyps in hyps_dict.items():
|
for name, hyps in hyps_dict.items():
|
||||||
this_batch = []
|
this_batch = []
|
||||||
assert len(hyps) == len(texts)
|
assert len(hyps) == len(texts)
|
||||||
for hyp_words, ref_text in zip(hyps, texts):
|
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||||
ref_words = ref_text.split()
|
ref_words = ref_text.split()
|
||||||
this_batch.append((ref_words, hyp_words))
|
this_batch.append((cut_id, ref_words, hyp_words))
|
||||||
|
|
||||||
results[name].extend(this_batch)
|
results[name].extend(this_batch)
|
||||||
|
|
||||||
@ -632,6 +633,8 @@ def main():
|
|||||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
LibriSpeechAsrDataModule.add_arguments(parser)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
# we need cut ids to display recognition results.
|
||||||
|
args.return_cuts = True
|
||||||
|
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
@ -321,7 +321,7 @@ def load_alignments(filename: str) -> Tuple[int, Dict[str, List[int]]]:
|
|||||||
|
|
||||||
|
|
||||||
def store_transcripts(
|
def store_transcripts(
|
||||||
filename: Pathlike, texts: Iterable[Tuple[str, str]]
|
filename: Pathlike, texts: Iterable[Tuple[str, str, str]]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Save predicted results and reference transcripts to a file.
|
"""Save predicted results and reference transcripts to a file.
|
||||||
|
|
||||||
@ -329,15 +329,15 @@ def store_transcripts(
|
|||||||
filename:
|
filename:
|
||||||
File to save the results to.
|
File to save the results to.
|
||||||
texts:
|
texts:
|
||||||
An iterable of tuples. The first element is the reference transcript
|
An iterable of tuples. The first element is the cur_id, the second is
|
||||||
while the second element is the predicted result.
|
the reference transcript and the third element is the predicted result.
|
||||||
Returns:
|
Returns:
|
||||||
Return None.
|
Return None.
|
||||||
"""
|
"""
|
||||||
with open(filename, "w") as f:
|
with open(filename, "w") as f:
|
||||||
for ref, hyp in texts:
|
for cut_id, ref, hyp in texts:
|
||||||
print(f"ref={ref}", file=f)
|
print(f"{cut_id}:\tref={ref}", file=f)
|
||||||
print(f"hyp={hyp}", file=f)
|
print(f"{cut_id}:\thyp={hyp}", file=f)
|
||||||
|
|
||||||
|
|
||||||
def write_error_stats(
|
def write_error_stats(
|
||||||
@ -372,8 +372,8 @@ def write_error_stats(
|
|||||||
The reference word `SIR` is missing in the predicted
|
The reference word `SIR` is missing in the predicted
|
||||||
results (a deletion error).
|
results (a deletion error).
|
||||||
results:
|
results:
|
||||||
An iterable of tuples. The first element is the reference transcript
|
An iterable of tuples. The first element is the cur_id, the second is
|
||||||
while the second element is the predicted result.
|
the reference transcript and the third element is the predicted result.
|
||||||
enable_log:
|
enable_log:
|
||||||
If True, also print detailed WER to the console.
|
If True, also print detailed WER to the console.
|
||||||
Otherwise, it is written only to the given file.
|
Otherwise, it is written only to the given file.
|
||||||
@ -389,7 +389,7 @@ def write_error_stats(
|
|||||||
words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
|
words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
|
||||||
num_corr = 0
|
num_corr = 0
|
||||||
ERR = "*"
|
ERR = "*"
|
||||||
for ref, hyp in results:
|
for cut_id, ref, hyp in results:
|
||||||
ali = kaldialign.align(ref, hyp, ERR)
|
ali = kaldialign.align(ref, hyp, ERR)
|
||||||
for ref_word, hyp_word in ali:
|
for ref_word, hyp_word in ali:
|
||||||
if ref_word == ERR:
|
if ref_word == ERR:
|
||||||
@ -405,7 +405,7 @@ def write_error_stats(
|
|||||||
else:
|
else:
|
||||||
words[ref_word][0] += 1
|
words[ref_word][0] += 1
|
||||||
num_corr += 1
|
num_corr += 1
|
||||||
ref_len = sum([len(r) for r, _ in results])
|
ref_len = sum([len(r) for _, r, _ in results])
|
||||||
sub_errs = sum(subs.values())
|
sub_errs = sum(subs.values())
|
||||||
ins_errs = sum(ins.values())
|
ins_errs = sum(ins.values())
|
||||||
del_errs = sum(dels.values())
|
del_errs = sum(dels.values())
|
||||||
@ -434,7 +434,7 @@ def write_error_stats(
|
|||||||
|
|
||||||
print("", file=f)
|
print("", file=f)
|
||||||
print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
|
print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
|
||||||
for ref, hyp in results:
|
for cut_id, ref, hyp in results:
|
||||||
ali = kaldialign.align(ref, hyp, ERR)
|
ali = kaldialign.align(ref, hyp, ERR)
|
||||||
combine_successive_errors = True
|
combine_successive_errors = True
|
||||||
if combine_successive_errors:
|
if combine_successive_errors:
|
||||||
@ -461,7 +461,8 @@ def write_error_stats(
|
|||||||
]
|
]
|
||||||
|
|
||||||
print(
|
print(
|
||||||
" ".join(
|
f"{cut_id}:\t"
|
||||||
|
+ " ".join(
|
||||||
(
|
(
|
||||||
ref_word
|
ref_word
|
||||||
if ref_word == hyp_word
|
if ref_word == hyp_word
|
||||||
|
Loading…
x
Reference in New Issue
Block a user