Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-09 18:12:19 +00:00

Support computing nbest oracle WER.

parent 1c3b13c7eb
commit 401c1c5143
@@ -21,6 +21,7 @@ from icefall.dataset.librispeech import LibriSpeechAsrDataModule
 from icefall.decode import (
     get_lattice,
     nbest_decoding,
+    nbest_oracle,
     one_best_decoding,
     rescore_with_attention_decoder,
     rescore_with_n_best_list,
@@ -56,6 +57,15 @@ def get_parser():
         "consecutive checkpoints before the checkpoint specified by "
         "'--epoch'. ",
     )

+    parser.add_argument(
+        "--scale",
+        type=float,
+        default=1.0,
+        help="The scale to be applied to `lattice.scores`. "
+        "A smaller value results in more unique paths.",
+    )
+
     return parser

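A quick illustration (not part of the diff) of why a smaller --scale produces more unique paths: lattice.scores are log-scale scores, and k2.random_paths samples paths according to them, so shrinking the scores flattens the sampling distribution. The score values below are made up.

import torch

scores = torch.tensor([4.0, 2.0, 0.5])  # made-up scores of competing paths
for scale in (1.0, 0.5, 0.1):
    probs = torch.softmax(scores * scale, dim=0)
    print(scale, probs)
# scale=1.0 -> sampling is peaked on the best path, so draws often collide
# scale=0.1 -> nearly uniform, so the sampled paths are mostly distinct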
@@ -85,10 +95,12 @@ def get_params() -> AttributeDict:
             # - nbest-rescoring
             # - whole-lattice-rescoring
             # - attention-decoder
+            # - nbest-oracle
             # "method": "whole-lattice-rescoring",
             "method": "attention-decoder",
+            # "method": "nbest-oracle",
             # num_paths is used when method is "nbest", "nbest-rescoring",
-            # and attention-decoder
+            # attention-decoder, and nbest-oracle
             "num_paths": 100,
         }
     )
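For orientation, a minimal sketch (not from the commit) of the method values the surrounding code recognizes after this change; a plain dict stands in for the file's AttributeDict.

params = {"method": "nbest-oracle", "num_paths": 100, "scale": 1.0}
assert params["method"] in (
    "1best", "nbest", "nbest-rescoring",
    "whole-lattice-rescoring", "attention-decoder", "nbest-oracle",
)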
@@ -179,6 +191,19 @@ def decode_one_batch(
         subsampling_factor=params.subsampling_factor,
     )

+    if params.method == "nbest-oracle":
+        # Note: You can also pass rescored lattices to it.
+        # We choose the HLG-decoded lattice for speed reasons,
+        # as HLG decoding is faster and its oracle WER is only
+        # slightly worse than that of rescored lattices.
+        return nbest_oracle(
+            lattice=lattice,
+            num_paths=params.num_paths,
+            ref_texts=supervisions["text"],
+            lexicon=lexicon,
+            scale=params.scale,
+        )
+
     if params.method in ["1best", "nbest"]:
         if params.method == "1best":
             best_path = one_best_decoding(
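Note that the early return above changes the shape of decode_one_batch's result for this method: the dict is keyed by the n-best parameters rather than by an LM scale. A hedged sketch with invented texts:

num_paths, scale = 100, 0.5
ans = {
    f"nbest_{num_paths}_scale_{scale}_oracle": [
        ["HELLO", "WORLD"],   # oracle hypothesis for utterance 0
        ["GOOD", "MORNING"],  # oracle hypothesis for utterance 1
    ]
}
assert list(ans) == ["nbest_100_scale_0.5_oracle"]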
@@ -284,7 +309,6 @@ def decode_dataset(
     results = []

     num_cuts = 0
-    tot_num_cuts = len(dl.dataset.cuts)

     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
@@ -315,8 +339,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             logging.info(
                 f"batch {batch_idx}, cuts processed until now is "
-                f"{num_cuts}/{tot_num_cuts} "
-                f"({float(num_cuts)/tot_num_cuts*100:.6f}%)"
+                f"{num_cuts} "
             )
     return results

The remaining hunks are in icefall/decode.py, the module the script above imports from:

@@ -2,9 +2,12 @@ import logging
 from typing import Dict, List, Optional, Tuple, Union

 import k2
+import kaldialign
 import torch
 import torch.nn as nn

+from icefall.lexicon import Lexicon
+

 def _intersect_device(
     a_fsas: k2.Fsa,
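The newly imported kaldialign package provides the edit distance used by nbest_oracle further below. A small standalone example of its return value (the words are invented):

import kaldialign

ref = ["HELLO", "WORLD"]
hyp = ["HELLO", "WORD"]
stats = kaldialign.edit_distance(ref, hyp)
# Insertion/deletion/substitution counts plus their sum:
# {'ins': 0, 'del': 0, 'sub': 1, 'total': 1}
assert stats["total"] == 1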
@@ -376,7 +379,7 @@ def rescore_with_n_best_list(
     #
     # num_repeats is also a k2.RaggedInt with 2 axes containing the
     # multiplicities of each path.
-    # num_repeats.num_elements() == unique_word_seqs.num_elements()
+    # num_repeats.num_elements() == unique_word_seqs.tot_size(1)
     #
     # Since k2.ragged.unique_sequences will reorder paths within a seq,
     # `new2old` is a 1-D torch.Tensor mapping from the output path index
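A plain-Python sketch (no k2) of the invariant the corrected comment states: after uniquing, num_repeats holds one multiplicity per unique path, so its element count equals the number of unique paths, i.e. unique_word_seqs.tot_size(1), not the total number of words. The word IDs are made up.

from collections import Counter

# Two sequences; each path is a tuple of word IDs.
seqs = [
    [(5, 7), (5, 7), (9,)],  # path (5, 7) was sampled twice
    [(3, 4, 8)],
]
unique = [list(Counter(s)) for s in seqs]            # unique paths per seq
repeats = [list(Counter(s).values()) for s in seqs]  # their multiplicities
num_unique_paths = sum(len(u) for u in unique)       # tot_size(1) analogue
assert sum(len(r) for r in repeats) == num_unique_paths  # 3 paths, not 6 words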
@@ -549,6 +552,76 @@ def rescore_with_whole_lattice(
     return ans


+def nbest_oracle(
+    lattice: k2.Fsa,
+    num_paths: int,
+    ref_texts: List[str],
+    lexicon: Lexicon,
+    scale: float = 1.0,
+) -> Dict[str, List[List[str]]]:
+    """Select the best hypothesis given a lattice and a reference transcript.
+
+    The basic idea is to extract `num_paths` paths from the given lattice,
+    unique them, and select the one that has the minimum edit distance with
+    the corresponding reference transcript as the decoding output.
+
+    The decoding result returned from this function is the best result that
+    we can obtain using n-best decoding with all kinds of rescoring
+    techniques.
+
+    Args:
+      lattice:
+        An FsaVec. It can be the return value of :func:`get_lattice`.
+        Note: We assume its aux_labels contain word IDs.
+      num_paths:
+        The size of `n` in n-best.
+      ref_texts:
+        A list of reference transcripts. Each entry contains space-separated
+        words.
+      lexicon:
+        It is used to convert word IDs to word symbols.
+      scale:
+        The scale applied to `lattice.scores`. A smaller value yields more
+        unique paths.
+    Return:
+      Return a dict. Its key encodes the parameters used when calling this
+      function, while its value contains the decoding output.
+      `len(ans_dict) == len(ref_texts)`
+    """
+    saved_scores = lattice.scores.clone()
+
+    lattice.scores *= scale
+    path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)
+    lattice.scores = saved_scores
+
+    word_seq = k2.index(lattice.aux_labels, path)
+    word_seq = k2.ragged.remove_values_leq(word_seq, 0)
+    unique_word_seq, _, _ = k2.ragged.unique_sequences(
+        word_seq, need_num_repeats=False, need_new2old_indexes=False
+    )
+    unique_word_ids = k2.ragged.to_list(unique_word_seq)
+    assert len(unique_word_ids) == len(ref_texts)
+    # unique_word_ids[i] contains all hypotheses of the i-th utterance
+
+    results = []
+    for hyps, ref in zip(unique_word_ids, ref_texts):
+        # Note: hyps is a list of lists of ints.
+        # Each sublist contains a hypothesis.
+        ref_words = ref.strip().split()
+        # CAUTION: We don't convert ref_words to ref_word_ids
+        # since there may exist OOV words in ref_words.
+        best_hyp_words = None
+        min_error = float("inf")
+        for hyp_words in hyps:
+            hyp_words = [lexicon.word_table[i] for i in hyp_words]
+            this_error = kaldialign.edit_distance(ref_words, hyp_words)["total"]
+            if this_error < min_error:
+                min_error = this_error
+                best_hyp_words = hyp_words
+        results.append(best_hyp_words)
+
+    return {f"nbest_{num_paths}_scale_{scale}_oracle": results}
+
+
 def rescore_with_attention_decoder(
     lattice: k2.Fsa,
     num_paths: int,
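Finally, a self-contained sketch of the selection loop at the heart of nbest_oracle, with k2 removed and the hypothesis lists written by hand (all strings invented):

import kaldialign

ref_words = "HELLO WORLD".split()
hyps = [["HELLO", "WORD"], ["HELLO", "WORLD"], ["OH", "WORLD"]]

best_hyp_words, min_error = None, float("inf")
for hyp_words in hyps:
    this_error = kaldialign.edit_distance(ref_words, hyp_words)["total"]
    if this_error < min_error:
        min_error, best_hyp_words = this_error, hyp_words

print(best_hyp_words, min_error)  # ['HELLO', 'WORLD'] 0 -- the oracle pick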
@@ -605,7 +678,7 @@ def rescore_with_attention_decoder(
     #
     # num_repeats is also a k2.RaggedInt with 2 axes containing the
     # multiplicities of each path.
-    # num_repeats.num_elements() == unique_word_seqs.num_elements()
+    # num_repeats.num_elements() == unique_word_seqs.tot_size(1)
     #
    # Since k2.ragged.unique_sequences will reorder paths within a seq,
     # `new2old` is a 1-D torch.Tensor mapping from the output path index