diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py
index 889a0a474..470bdd682 100755
--- a/egs/librispeech/ASR/conformer_ctc/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc/decode.py
@@ -21,6 +21,7 @@ from icefall.dataset.librispeech import LibriSpeechAsrDataModule
 from icefall.decode import (
     get_lattice,
     nbest_decoding,
+    nbest_oracle,
     one_best_decoding,
     rescore_with_attention_decoder,
     rescore_with_n_best_list,
@@ -56,6 +57,15 @@ def get_parser():
         "consecutive checkpoints before the checkpoint specified by "
         "'--epoch'. ",
     )
+
+    parser.add_argument(
+        "--scale",
+        type=float,
+        default=1.0,
+        help="The scale to be applied to `lattice.scores`. "
+        "A smaller value results in more unique paths.",
+    )
+
     return parser
 
 
@@ -85,10 +95,12 @@ def get_params() -> AttributeDict:
             # - nbest-rescoring
             # - whole-lattice-rescoring
             # - attention-decoder
+            # - nbest-oracle
            # "method": "whole-lattice-rescoring",
             "method": "attention-decoder",
+            # "method": "nbest-oracle",
             # num_paths is used when method is "nbest", "nbest-rescoring",
-            # and attention-decoder
+            # attention-decoder, and nbest-oracle
             "num_paths": 100,
         }
     )
@@ -179,6 +191,19 @@ def decode_one_batch(
         subsampling_factor=params.subsampling_factor,
     )
 
+    if params.method == "nbest-oracle":
+        # Note: You can also pass rescored lattices to it.
+        # We choose the HLG decoded lattice for speed reasons
+        # as HLG decoding is faster and the oracle WER
+        # is slightly worse than that of rescored lattices.
+        return nbest_oracle(
+            lattice=lattice,
+            num_paths=params.num_paths,
+            ref_texts=supervisions["text"],
+            lexicon=lexicon,
+            scale=params.scale,
+        )
+
     if params.method in ["1best", "nbest"]:
         if params.method == "1best":
             best_path = one_best_decoding(
@@ -284,7 +309,6 @@ def decode_dataset(
     results = []
 
     num_cuts = 0
-    tot_num_cuts = len(dl.dataset.cuts)
 
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
@@ -315,8 +339,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             logging.info(
                 f"batch {batch_idx}, cuts processed until now is "
-                f"{num_cuts}/{tot_num_cuts} "
-                f"({float(num_cuts)/tot_num_cuts*100:.6f}%)"
+                f"{num_cuts} "
             )
 
     return results
diff --git a/icefall/decode.py b/icefall/decode.py
index 0e9baf2e4..43524a02a 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -2,9 +2,12 @@ import logging
 from typing import Dict, List, Optional, Tuple, Union
 
 import k2
+import kaldialign
 import torch
 import torch.nn as nn
 
+from icefall.lexicon import Lexicon
+
 
 def _intersect_device(
     a_fsas: k2.Fsa,
@@ -376,7 +379,7 @@ def rescore_with_n_best_list(
     #
     # num_repeats is also a k2.RaggedInt with 2 axes containing the
     # multiplicities of each path.
-    # num_repeats.num_elements() == unique_word_seqs.num_elements()
+    # num_repeats.num_elements() == unique_word_seqs.tot_size(1)
     #
     # Since k2.ragged.unique_sequences will reorder paths within a seq,
     # `new2old` is a 1-D torch.Tensor mapping from the output path index
@@ -549,6 +552,76 @@ def rescore_with_whole_lattice(
     return ans
 
 
+def nbest_oracle(
+    lattice: k2.Fsa,
+    num_paths: int,
+    ref_texts: List[str],
+    lexicon: Lexicon,
+    scale: float = 1.0,
+) -> Dict[str, List[List[str]]]:
+    """Select the best hypothesis given a lattice and a reference transcript.
+
+    The basic idea is to extract n paths from the given lattice, unique them,
+    and select the one that has the minimum edit distance with the corresponding
+    reference transcript as the decoding output.
+
+    The decoding result returned from this function is the best result that
+    we can obtain using n-best decoding with all kinds of rescoring techniques.
+
+    Args:
+      lattice:
+        An FsaVec. It can be the return value of :func:`get_lattice`.
+        Note: We assume its aux_labels contain word IDs.
+      num_paths:
+        The size of `n` in n-best.
+      ref_texts:
+        A list of reference transcripts. Each entry contains
+        space-separated words.
+      lexicon:
+        It is used to convert word IDs to word symbols.
+      scale:
+        It's the scale applied to lattice.scores. A smaller value
+        yields more unique paths.
+    Return:
+      Return a dict whose single key encodes the parameters used when calling
+      this function and whose value contains the decoding output, with
+      `len(ans_dict[key]) == len(ref_texts)`.
+    """
+    saved_scores = lattice.scores.clone()
+
+    lattice.scores *= scale
+    path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)
+    lattice.scores = saved_scores
+
+    word_seq = k2.index(lattice.aux_labels, path)
+    word_seq = k2.ragged.remove_values_leq(word_seq, 0)
+    unique_word_seq, _, _ = k2.ragged.unique_sequences(
+        word_seq, need_num_repeats=False, need_new2old_indexes=False
+    )
+    unique_word_ids = k2.ragged.to_list(unique_word_seq)
+    assert len(unique_word_ids) == len(ref_texts)
+    # unique_word_ids[i] contains all hypotheses of the i-th utterance
+
+    results = []
+    for hyps, ref in zip(unique_word_ids, ref_texts):
+        # Note: hyps is a list of lists of ints.
+        # Each sublist contains a hypothesis.
+        ref_words = ref.strip().split()
+        # CAUTION: We don't convert ref_words to ref_word_ids
+        # since there may exist OOV words in ref_words
+        best_hyp_words = None
+        min_error = float("inf")
+        for hyp_words in hyps:
+            hyp_words = [lexicon.word_table[i] for i in hyp_words]
+            this_error = kaldialign.edit_distance(ref_words, hyp_words)["total"]
+            if this_error < min_error:
+                min_error = this_error
+                best_hyp_words = hyp_words
+        results.append(best_hyp_words)
+
+    return {f"nbest_{num_paths}_scale_{scale}_oracle": results}
+
+
 def rescore_with_attention_decoder(
     lattice: k2.Fsa,
     num_paths: int,
@@ -605,7 +678,7 @@ def rescore_with_attention_decoder(
     #
     # num_repeats is also a k2.RaggedInt with 2 axes containing the
     # multiplicities of each path.
-    # num_repeats.num_elements() == unique_word_seqs.num_elements()
+    # num_repeats.num_elements() == unique_word_seqs.tot_size(1)
     #
     # Since k2.ragged.unique_sequences will reorder paths within a seq,
     # `new2old` is a 1-D torch.Tensor mapping from the output path index
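
For reference, the sketch below shows one way the new `nbest_oracle` helper could be used to report an oracle WER for a decoded batch. It is only an illustrative sketch and not part of the patch: the `lattice`, `ref_texts`, and `lexicon` objects are assumed to come from an existing decoding setup such as the one in `decode_one_batch`, and `oracle_wer` is a hypothetical helper name.

# Illustrative sketch (not part of the patch): turning the output of
# `nbest_oracle` into an oracle WER.  Assumes `lattice`, `ref_texts`, and
# `lexicon` are obtained from an existing decoding pipeline.
import kaldialign

from icefall.decode import nbest_oracle


def oracle_wer(lattice, ref_texts, lexicon, num_paths=100, scale=1.0):
    """Compute the oracle WER (in percent) of an n-best list for one batch."""
    ans = nbest_oracle(
        lattice=lattice,
        num_paths=num_paths,
        ref_texts=ref_texts,
        lexicon=lexicon,
        scale=scale,
    )
    # The returned dict has a single key that encodes num_paths and scale;
    # its value holds the best hypothesis (a list of words) per utterance.
    hyps = next(iter(ans.values()))

    num_errors = 0
    num_ref_words = 0
    for hyp_words, ref in zip(hyps, ref_texts):
        ref_words = ref.strip().split()
        num_errors += kaldialign.edit_distance(ref_words, hyp_words)["total"]
        num_ref_words += len(ref_words)
    return 100.0 * num_errors / num_ref_words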