Support computing nbest oracle WER.

Fangjun Kuang 2021-08-18 12:54:01 +08:00
parent 1c3b13c7eb
commit 401c1c5143
2 changed files with 102 additions and 6 deletions


@@ -21,6 +21,7 @@ from icefall.dataset.librispeech import LibriSpeechAsrDataModule
from icefall.decode import (
    get_lattice,
    nbest_decoding,
    nbest_oracle,
    one_best_decoding,
    rescore_with_attention_decoder,
    rescore_with_n_best_list,
@@ -56,6 +57,15 @@ def get_parser():
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch'. ",
    )

    parser.add_argument(
        "--scale",
        type=float,
        default=1.0,
        help="The scale to be applied to `lattice.scores`. "
        "A smaller value results in more unique paths.",
    )

    return parser
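
A note on the new --scale flag: it is applied to lattice.scores before k2.random_paths samples the n-best paths, so a smaller scale flattens the score distribution and the sampler returns a more diverse set of paths. A minimal sketch of the pattern, mirroring the code added to nbest_oracle below (the helper name is illustrative, not part of this commit):

import k2

def sample_paths_with_scale(lattice: k2.Fsa, num_paths: int, scale: float):
    # Temporarily scale the lattice scores, sample paths, then restore
    # the original scores so later rescoring stages are unaffected.
    saved_scores = lattice.scores.clone()
    lattice.scores *= scale
    path = k2.random_paths(
        lattice, num_paths=num_paths, use_double_scores=True
    )
    lattice.scores = saved_scores
    return path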
@@ -85,10 +95,12 @@ def get_params() -> AttributeDict:
            # - nbest-rescoring
            # - whole-lattice-rescoring
            # - attention-decoder
            # - nbest-oracle
            # "method": "whole-lattice-rescoring",
            "method": "attention-decoder",
            # "method": "nbest-oracle",
            # num_paths is used when method is "nbest", "nbest-rescoring",
            # and attention-decoder
            # attention-decoder, and nbest-oracle
            "num_paths": 100,
        }
    )
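
Since the decoding method is chosen in get_params() rather than via a command-line flag, switching to the oracle computation is a one-line edit, or an override after the fact, e.g. (a sketch, assuming get_params is in scope):

params = get_params()
params.method = "nbest-oracle"  # compute the n-best oracle WER
params.num_paths = 100          # the size of n in n-best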
@@ -179,6 +191,19 @@ def decode_one_batch(
        subsampling_factor=params.subsampling_factor,
    )

    if params.method == "nbest-oracle":
        # Note: You can also pass rescored lattices to it.
        # We choose the HLG-decoded lattice here for speed:
        # HLG decoding is faster, and its oracle WER is only
        # slightly worse than that of rescored lattices.
        return nbest_oracle(
            lattice=lattice,
            num_paths=params.num_paths,
            ref_texts=supervisions["text"],
            lexicon=lexicon,
            scale=params.scale,
        )

    if params.method in ["1best", "nbest"]:
        if params.method == "1best":
            best_path = one_best_decoding(
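
With this branch, decode_one_batch short-circuits and returns a single-key dict whose key encodes num_paths and scale. A hypothetical illustration of the shape (the words are made up):

# Key encodes the parameters; the value holds one oracle hypothesis
# (a list of words) per utterance in the batch.
ans = {
    "nbest_100_scale_0.5_oracle": [
        ["HELLO", "WORLD"],   # oracle hypothesis for utterance 0
        ["GOOD", "MORNING"],  # oracle hypothesis for utterance 1
    ]
}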
@@ -284,7 +309,6 @@ def decode_dataset(
    results = []

    num_cuts = 0
    tot_num_cuts = len(dl.dataset.cuts)

    results = defaultdict(list)
    for batch_idx, batch in enumerate(dl):
@@ -315,8 +339,7 @@ def decode_dataset(
        if batch_idx % 100 == 0:
            logging.info(
                f"batch {batch_idx}, cuts processed until now is "
                f"{num_cuts}/{tot_num_cuts} "
                f"({float(num_cuts)/tot_num_cuts*100:.6f}%)"
                f"{num_cuts} "
            )

    return results
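
Downstream, each entry in results is scored by accumulating edit distances over (reference, hypothesis) pairs. A self-contained sketch of that computation using kaldialign directly; the pairing format here is an assumption for illustration, not necessarily the exact structure decode_dataset returns:

import kaldialign

def compute_wer(pairs):
    # pairs: a list of (ref_words, hyp_words) tuples (assumed format).
    total_errors = 0
    total_words = 0
    for ref_words, hyp_words in pairs:
        total_errors += kaldialign.edit_distance(ref_words, hyp_words)["total"]
        total_words += len(ref_words)
    return 100.0 * total_errors / total_words

print(compute_wer([("THE CAT".split(), "THE CAT SAT".split())]))  # 50.0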


@@ -2,9 +2,12 @@ import logging
from typing import Dict, List, Optional, Tuple, Union

import k2
import kaldialign
import torch
import torch.nn as nn

from icefall.lexicon import Lexicon


def _intersect_device(
    a_fsas: k2.Fsa,
@@ -376,7 +379,7 @@ def rescore_with_n_best_list(
    #
    # num_repeats is also a k2.RaggedInt with 2 axes containing the
    # multiplicities of each path.
    # num_repeats.num_elements() == unique_word_seqs.num_elements()
    # num_repeats.num_elements() == unique_word_seqs.tot_size(1)
    #
    # Since k2.ragged.unique_sequences will reorder paths within a seq,
    # `new2old` is a 1-D torch.Tensor mapping from the output path index
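
The corrected invariant states that num_repeats holds one multiplicity per unique path, so its element count equals the number of sequences on axis 1 of unique_word_seqs (its tot_size(1)), not its number of word IDs. A pure-Python illustration of the same bookkeeping, with plain lists standing in for k2 ragged tensors:

from collections import Counter

# Two utterances, each with sampled word-ID paths (duplicates included).
word_seqs = [
    [[5, 7], [5, 7], [5, 9]],  # utterance 0: 3 paths, 2 unique
    [[2, 4, 6]],               # utterance 1: 1 path, 1 unique
]

unique_word_seqs = []
num_repeats = []
for paths in word_seqs:
    counts = Counter(tuple(p) for p in paths)
    unique_word_seqs.append([list(p) for p in counts])
    num_repeats.append(list(counts.values()))

# One multiplicity per unique path -- the tot_size(1) analogue:
assert sum(len(r) for r in num_repeats) == sum(len(u) for u in unique_word_seqs)
print(unique_word_seqs)  # [[[5, 7], [5, 9]], [[2, 4, 6]]]
print(num_repeats)       # [[2, 1], [1]]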
@@ -549,6 +552,76 @@ def rescore_with_whole_lattice(
    return ans

def nbest_oracle(
    lattice: k2.Fsa,
    num_paths: int,
    ref_texts: List[str],
    lexicon: Lexicon,
    scale: float = 1.0,
) -> Dict[str, List[List[str]]]:
    """Select the best hypothesis given a lattice and a reference transcript.

    The basic idea is to extract `num_paths` paths from the given lattice,
    deduplicate them, and select the one that has the minimum edit distance
    to the corresponding reference transcript as the decoding output.

    The decoding result returned by this function is the best result
    obtainable from n-best decoding with any rescoring technique, i.e.,
    it is an oracle upper bound for these n paths.

    Args:
      lattice:
        An FsaVec. It can be the return value of :func:`get_lattice`.
        Note: We assume its aux_labels contain word IDs.
      num_paths:
        The size of `n` in n-best.
      ref_texts:
        A list of reference transcripts. Each entry is a string of
        space-separated words.
      lexicon:
        It is used to convert word IDs to word symbols.
      scale:
        The scale applied to `lattice.scores`. A smaller value yields
        more unique paths.
    Returns:
      Return a dict. Its key encodes the parameters used when calling
      this function, while its value contains the decoding output, one
      entry per utterance, so `len(value) == len(ref_texts)`.
    """
    saved_scores = lattice.scores.clone()
    lattice.scores *= scale
    path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)
    lattice.scores = saved_scores

    word_seq = k2.index(lattice.aux_labels, path)
    word_seq = k2.ragged.remove_values_leq(word_seq, 0)
    unique_word_seq, _, _ = k2.ragged.unique_sequences(
        word_seq, need_num_repeats=False, need_new2old_indexes=False
    )
    unique_word_ids = k2.ragged.to_list(unique_word_seq)
    assert len(unique_word_ids) == len(ref_texts)
    # unique_word_ids[i] contains all hypotheses of the i-th utterance

    results = []
    for hyps, ref in zip(unique_word_ids, ref_texts):
        # Note: hyps is a list of lists of ints;
        # each sublist contains one hypothesis.
        ref_words = ref.strip().split()
        # CAUTION: We don't convert ref_words to word IDs
        # since ref_words may contain OOV words.
        best_hyp_words = None
        min_error = float("inf")
        for hyp_words in hyps:
            hyp_words = [lexicon.word_table[i] for i in hyp_words]
            this_error = kaldialign.edit_distance(ref_words, hyp_words)["total"]
            if this_error < min_error:
                min_error = this_error
                best_hyp_words = hyp_words
        results.append(best_hyp_words)

    return {f"nbest_{num_paths}_scale_{scale}_oracle": results}
def rescore_with_attention_decoder(
    lattice: k2.Fsa,
    num_paths: int,
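
A tiny self-contained example of the oracle selection loop in nbest_oracle above, with hand-written hypotheses standing in for sampled lattice paths (the words are made up for illustration):

import kaldialign

ref_words = "THE CAT SAT".split()
hyps = [
    "THE CAT SAD".split(),     # 1 substitution
    "A CAT SAT DOWN".split(),  # 1 substitution + 1 insertion
    "THE CAT SAT".split(),     # 0 errors -> the oracle pick
]

best_hyp_words, min_error = None, float("inf")
for hyp_words in hyps:
    this_error = kaldialign.edit_distance(ref_words, hyp_words)["total"]
    if this_error < min_error:
        min_error, best_hyp_words = this_error, hyp_words

print(min_error, best_hyp_words)  # 0 ['THE', 'CAT', 'SAT']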
@@ -605,7 +678,7 @@ def rescore_with_attention_decoder(
    #
    # num_repeats is also a k2.RaggedInt with 2 axes containing the
    # multiplicities of each path.
    # num_repeats.num_elements() == unique_word_seqs.num_elements()
    # num_repeats.num_elements() == unique_word_seqs.tot_size(1)
    #
    # Since k2.ragged.unique_sequences will reorder paths within a seq,
    # `new2old` is a 1-D torch.Tensor mapping from the output path index