mirror of https://github.com/k2-fsa/icefall.git

Support computing nbest oracle WER.

commit 401c1c5143 (parent 1c3b13c7eb)

@@ -21,6 +21,7 @@ from icefall.dataset.librispeech import LibriSpeechAsrDataModule
 from icefall.decode import (
     get_lattice,
     nbest_decoding,
+    nbest_oracle,
     one_best_decoding,
     rescore_with_attention_decoder,
     rescore_with_n_best_list,
@@ -56,6 +57,15 @@ def get_parser():
         "consecutive checkpoints before the checkpoint specified by "
         "'--epoch'. ",
     )
+
+    parser.add_argument(
+        "--scale",
+        type=float,
+        default=1.0,
+        help="The scale to be applied to `lattice.scores`. "
+        "A smaller value results in more unique paths",
+    )
+
     return parser
 
 
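
A note on the new `--scale` flag: k2.random_paths samples paths with probability governed by the lattice scores, so shrinking the scores flattens that distribution and the sampled paths deduplicate into more distinct word sequences. A minimal sketch of the pattern, assuming `lattice` is a k2 FsaVec such as the one returned by get_lattice (the helper name is ours):

import k2

def sample_paths_with_scale(lattice: k2.Fsa, num_paths: int, scale: float):
    # Scale the scores only for sampling; restore them afterwards so that
    # any later rescoring still sees the original scores.
    saved = lattice.scores.clone()
    lattice.scores *= scale  # scale < 1.0 flattens the distribution
    paths = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)
    lattice.scores = saved
    return paths
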
@@ -85,10 +95,12 @@ def get_params() -> AttributeDict:
             # - nbest-rescoring
             # - whole-lattice-rescoring
             # - attention-decoder
+            # - nbest-oracle
             # "method": "whole-lattice-rescoring",
             "method": "attention-decoder",
+            # "method": "nbest-oracle",
             # num_paths is used when method is "nbest", "nbest-rescoring",
-            # and attention-decoder
+            # attention-decoder, and nbest-oracle
             "num_paths": 100,
         }
     )
@@ -179,6 +191,19 @@ def decode_one_batch(
         subsampling_factor=params.subsampling_factor,
     )
 
+    if params.method == "nbest-oracle":
+        # Note: You can also pass rescored lattices to it.
+        # We choose the HLG-decoded lattice for speed reasons,
+        # as HLG decoding is faster and its oracle WER
+        # is only slightly worse than that of rescored lattices.
+        return nbest_oracle(
+            lattice=lattice,
+            num_paths=params.num_paths,
+            ref_texts=supervisions["text"],
+            lexicon=lexicon,
+            scale=params.scale,
+        )
+
     if params.method in ["1best", "nbest"]:
         if params.method == "1best":
             best_path = one_best_decoding(
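
As the comment in this hunk notes, a rescored lattice can be fed to nbest_oracle instead of the raw HLG lattice, trading decoding speed for a slightly better oracle WER. A hypothetical sketch of that variant; rescore_with_whole_lattice itself is real, but its argument names and the key format of its return value (assumed here to be a dict keyed by LM scale) are assumptions to check against the actual code:

from icefall.decode import nbest_oracle, rescore_with_whole_lattice

def oracle_on_rescored_lattice(lattice, G, params, supervisions, lexicon):
    # Hypothetical helper: rescore the lattice with the full LM first,
    # then run the oracle selection on the rescored lattice.
    rescored = rescore_with_whole_lattice(
        lattice=lattice,
        G_with_epsilon_loops=G,  # assumed argument name
        lm_scale_list=[1.0],     # assumed argument name
    )
    return nbest_oracle(
        lattice=rescored["lm_scale_1.0"],  # assumed key format
        num_paths=params.num_paths,
        ref_texts=supervisions["text"],
        lexicon=lexicon,
        scale=params.scale,
    )
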
@@ -284,7 +309,6 @@ def decode_dataset(
     results = []
 
     num_cuts = 0
-    tot_num_cuts = len(dl.dataset.cuts)
 
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
@@ -315,8 +339,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             logging.info(
                 f"batch {batch_idx}, cuts processed until now is "
-                f"{num_cuts}/{tot_num_cuts} "
-                f"({float(num_cuts)/tot_num_cuts*100:.6f}%)"
+                f"{num_cuts} "
             )
     return results
 
@@ -2,9 +2,12 @@ import logging
 from typing import Dict, List, Optional, Tuple, Union
 
 import k2
+import kaldialign
 import torch
 import torch.nn as nn
 
+from icefall.lexicon import Lexicon
+
 
 def _intersect_device(
     a_fsas: k2.Fsa,
@@ -376,7 +379,7 @@ def rescore_with_n_best_list(
     #
     # num_repeats is also a k2.RaggedInt with 2 axes containing the
     # multiplicities of each path.
-    # num_repeats.num_elements() == unique_word_seqs.num_elements()
+    # num_repeats.num_elements() == unique_word_seqs.tot_size(1)
     #
     # Since k2.ragged.unique_sequences will reorder paths within a seq,
     # `new2old` is a 1-D torch.Tensor mapping from the output path index
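
The corrected comment encodes an invariant: num_repeats stores one multiplicity per unique path, so num_repeats.num_elements() matches unique_word_seqs.tot_size(1), the total number of unique paths across all seqs, not the total number of words. A plain-Python sketch of the deduplication with made-up toy data, using nested lists as stand-ins for k2 ragged tensors:

from collections import Counter

# Toy stand-in for word_seqs with axes [seq][path][word]:
word_seqs = [
    [[5, 9], [5, 9], [7]],   # seq 0: path [5, 9] appears twice
    [[3, 4, 8], [3, 4, 8]],  # seq 1: a single unique path, multiplicity 2
]

unique_word_seqs = []
num_repeats = []
for seq in word_seqs:
    counts = Counter(tuple(p) for p in seq)  # dedup, keeping multiplicities
    unique_word_seqs.append([list(p) for p in counts])
    num_repeats.append(list(counts.values()))

# One multiplicity per unique path, i.e. the analogue of
# num_repeats.num_elements() == unique_word_seqs.tot_size(1):
assert sum(len(r) for r in num_repeats) == sum(len(s) for s in unique_word_seqs)
print(unique_word_seqs)  # [[[5, 9], [7]], [[3, 4, 8]]]
print(num_repeats)       # [[2, 1], [2]]
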
@@ -549,6 +552,76 @@ def rescore_with_whole_lattice(
     return ans
 
 
+def nbest_oracle(
+    lattice: k2.Fsa,
+    num_paths: int,
+    ref_texts: List[str],
+    lexicon: Lexicon,
+    scale: float = 1.0,
+) -> Dict[str, List[List[str]]]:
+    """Select the best hypothesis given a lattice and a reference transcript.
+
+    The basic idea is to extract `num_paths` paths from the given lattice,
+    deduplicate them, and select the one with the smallest edit distance
+    relative to the corresponding reference transcript as the decoding output.
+
+    The decoding result returned by this function is the best result that
+    we can obtain using n-best decoding with any kind of rescoring technique.
+
+    Args:
+      lattice:
+        An FsaVec. It can be the return value of :func:`get_lattice`.
+        Note: We assume its aux_labels contain word IDs.
+      num_paths:
+        The size of `n` in n-best.
+      ref_texts:
+        A list of reference transcripts. Each entry contains
+        space-separated words.
+      lexicon:
+        It is used to convert word IDs to word symbols.
+      scale:
+        The scale applied to `lattice.scores`. A smaller value
+        yields more unique paths.
+    Return:
+      Return a dict with a single key that encodes the parameters used
+      to call this function; its value is the decoding output, one word
+      list per utterance, so `len(ans_dict[key]) == len(ref_texts)`.
+    """
+    saved_scores = lattice.scores.clone()
+
+    lattice.scores *= scale
+    path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)
+    lattice.scores = saved_scores
+
+    word_seq = k2.index(lattice.aux_labels, path)
+    word_seq = k2.ragged.remove_values_leq(word_seq, 0)
+    unique_word_seq, _, _ = k2.ragged.unique_sequences(
+        word_seq, need_num_repeats=False, need_new2old_indexes=False
+    )
+    unique_word_ids = k2.ragged.to_list(unique_word_seq)
+    assert len(unique_word_ids) == len(ref_texts)
+    # unique_word_ids[i] contains all hypotheses of the i-th utterance
+
+    results = []
+    for hyps, ref in zip(unique_word_ids, ref_texts):
+        # Note: hyps is a list of lists of ints;
+        # each sublist contains one hypothesis.
+        ref_words = ref.strip().split()
+        # CAUTION: We don't convert ref_words to ref_word_ids
+        # since there may exist OOV words in ref_words.
+        best_hyp_words = None
+        min_error = float("inf")
+        for hyp_words in hyps:
+            hyp_words = [lexicon.word_table[i] for i in hyp_words]
+            this_error = kaldialign.edit_distance(ref_words, hyp_words)["total"]
+            if this_error < min_error:
+                min_error = this_error
+                best_hyp_words = hyp_words
+        results.append(best_hyp_words)
+
+    return {f"nbest_{num_paths}_scale_{scale}_oracle": results}
+
+
 def rescore_with_attention_decoder(
     lattice: k2.Fsa,
     num_paths: int,
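
To see the selection step of nbest_oracle in isolation, here is a self-contained toy run of its inner loop. The kaldialign.edit_distance call is the same one the function uses; the reference and hypotheses are made up, with the word-ID-to-word mapping already applied:

import kaldialign

ref_words = "THE CAT SAT".split()
# Hypothetical deduplicated n-best hypotheses for one utterance:
hyps = [
    ["THE", "CAT", "SAT", "DOWN"],  # 1 insertion
    ["A", "CAT", "SAT"],            # 1 substitution
    ["THE", "CAT", "SAT"],          # exact match
]

best_hyp_words = None
min_error = float("inf")
for hyp_words in hyps:
    this_error = kaldialign.edit_distance(ref_words, hyp_words)["total"]
    if this_error < min_error:
        min_error = this_error
        best_hyp_words = hyp_words

print(best_hyp_words, min_error)  # ['THE', 'CAT', 'SAT'] 0
# The oracle WER for this utterance is min_error / len(ref_words) = 0.0.
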
@@ -605,7 +678,7 @@ def rescore_with_attention_decoder(
     #
     # num_repeats is also a k2.RaggedInt with 2 axes containing the
     # multiplicities of each path.
-    # num_repeats.num_elements() == unique_word_seqs.num_elements()
+    # num_repeats.num_elements() == unique_word_seqs.tot_size(1)
     #
     # Since k2.ragged.unique_sequences will reorder paths within a seq,
     # `new2old` is a 1-D torch.Tensor mapping from the output path index