From b1c3705fbe0ab75f2ffe5dcd4a0611e4dd6343c2 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sun, 24 Apr 2022 15:10:30 +0800 Subject: [PATCH 1/2] Compute the Nbest oracle WER for RNN-T decoding. --- .../beam_search.py | 173 +++++++++++++++++- .../pruned_transducer_stateless3/decode.py | 74 +++++++- 2 files changed, 229 insertions(+), 18 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 86e34be61..10bd7bf7e 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -22,11 +22,11 @@ import k2 import torch from model import Transducer -from icefall.decode import one_best_decoding +from icefall.decode import Nbest, one_best_decoding from icefall.utils import get_texts -def fast_beam_search( +def fast_beam_search_one_best( model: Transducer, decoding_graph: k2.Fsa, encoder_out: torch.Tensor, @@ -37,6 +37,9 @@ def fast_beam_search( ) -> List[List[int]]: """It limits the maximum number of symbols per frame to 1. + A lattice is first obtained using fast beam search, and then + the shortest path within the lattice is used as the final output. + Args: model: An instance of `Transducer`. @@ -56,6 +59,153 @@ def fast_beam_search( Returns: Return the decoded result. 
""" + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + ) + + best_path = one_best_decoding(lattice) + hyps = get_texts(best_path) + return hyps + + +def fast_beam_search_nbest_oracle( + model: Transducer, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + num_paths: int, + ref_texts: List[List[int]], + use_double_scores: bool = True, + nbest_scale: float = 0.5, +) -> List[List[int]]: + """It limits the maximum number of symbols per frame to 1. + + A lattice is first obtained using modified beam search, and then + we select `num_paths` linear paths from the lattice. The path + that has the minimum edit distance with the given reference transcript + is used as the output. + + This is the best result we can achieve for any nbest based rescoring + methods. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a HLG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + num_paths: + Number of paths to extract from the decoded lattice. + ref_texts: + A list-of-list of integers containing the reference transcripts. + If the decoding_graph is a trivial_graph, the integer ID is the + BPE token ID. + use_double_scores: + True to use double precision for computation. False to use + single precision. + nbest_scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + + Returns: + Return the decoded result. 
+ """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + ) + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + + # We assume the labels of nbest.fsa are token IDs and the aux_labels + # are word IDs. + word_fsa = k2.invert(nbest.fsa) + word_ids = get_texts(word_fsa, return_ragged=True) + + hyps = k2.levenshtein_graph(word_ids) + refs = k2.levenshtein_graph(ref_texts, device=hyps.device) + + levenshtein_alignment = k2.levenshtein_alignment( + refs=refs, + hyps=hyps, + hyp_to_ref_map=nbest.shape.row_ids(1), + sorted_match_ref=True, + ) + + tot_scores = levenshtein_alignment.get_tot_scores( + use_double_scores=False, log_semiring=False + ) + ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) + + max_indexes = ragged_tot_scores.argmax() + + best_path = k2.index_fsa(nbest.fsa, max_indexes) + + hyps = get_texts(best_path) + return hyps + + +def fast_beam_search( + model: Transducer, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, +) -> k2.Fsa: + """It limits the maximum number of symbols per frame to 1. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a HLG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts per stream per frame. + Returns: + Return an FsaVec with axes [utt][state][arc] containing the decoded + lattice. 
Note: When the input graph is a TrivialGraph, the returned + lattice is actually an acceptor. + """ assert encoder_out.ndim == 3 context_size = model.decoder.context_size @@ -104,9 +254,7 @@ def fast_beam_search( decoding_streams.terminate_and_flush_to_streams() lattice = decoding_streams.format_output(encoder_out_lens.tolist()) - best_path = one_best_decoding(lattice) - hyps = get_texts(best_path) - return hyps + return lattice def greedy_search( @@ -131,6 +279,7 @@ def greedy_search( blank_id = model.decoder.blank_id context_size = model.decoder.context_size + unk_id = getattr(model, "unk_id", blank_id) device = model.device @@ -171,7 +320,7 @@ def greedy_search( # logits is (1, 1, 1, vocab_size) y = logits.argmax().item() - if y != blank_id: + if y not in (blank_id, unk_id): hyp.append(y) decoder_input = torch.tensor( [hyp[-context_size:]], device=device @@ -212,6 +361,7 @@ def greedy_search_batch( T = encoder_out.size(1) blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) context_size = model.decoder.context_size hyps = [[blank_id] * context_size for _ in range(batch_size)] @@ -240,7 +390,7 @@ def greedy_search_batch( y = logits.argmax(dim=1).tolist() emitted = False for i, v in enumerate(y): - if v != blank_id: + if v not in (blank_id, unk_id): hyps[i].append(v) emitted = True if emitted: @@ -433,6 +583,7 @@ def modified_beam_search( T = encoder_out.size(1) blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) context_size = model.decoder.context_size device = model.device B = [HypothesisList() for _ in range(batch_size)] @@ -515,7 +666,7 @@ def modified_beam_search( new_ys = hyp.ys[:] new_token = topk_token_indexes[k] - if new_token != blank_id: + if new_token not in (blank_id, unk_id): new_ys.append(new_token) new_log_prob = topk_log_probs[k] @@ -556,6 +707,7 @@ def _deprecated_modified_beam_search( # support only batch_size == 1 for now assert encoder_out.size(0) == 1, encoder_out.size(0) blank_id = 
model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) context_size = model.decoder.context_size device = model.device @@ -626,7 +778,7 @@ def _deprecated_modified_beam_search( hyp = A[topk_hyp_indexes[i]] new_ys = hyp.ys[:] new_token = topk_token_indexes[i] - if new_token != blank_id: + if new_token not in (blank_id, unk_id): new_ys.append(new_token) new_log_prob = topk_log_probs[i] new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) @@ -663,6 +815,7 @@ def beam_search( # support only batch_size == 1 for now assert encoder_out.size(0) == 1, encoder_out.size(0) blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) context_size = model.decoder.context_size device = model.device @@ -748,7 +901,7 @@ def beam_search( # Second, process other non-blank labels values, indices = log_prob.topk(beam + 1) for i, v in zip(indices.tolist(), values.tolist()): - if i == blank_id: + if i in (blank_id, unk_id): continue new_ys = y_star.ys + [i] new_log_prob = y_star.log_prob + v diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index bbc51301f..44bcc2843 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -69,7 +69,8 @@ import torch.nn as nn from asr_datamodule import AsrDataModule from beam_search import ( beam_search, - fast_beam_search, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, greedy_search, greedy_search_batch, modified_beam_search, @@ -145,6 +146,7 @@ def get_parser(): - beam_search - modified_beam_search - fast_beam_search + - fast_beam_search_nbest_oracle """, ) @@ -164,7 +166,8 @@ def get_parser(): help="""A floating point value to calculate the cutoff score during beam search (i.e., `cutoff = max-score - beam`), which is the same as the `beam` in Kaldi. 
- Used only when --decoding-method is fast_beam_search""", + Used only when --decoding-method is + fast_beam_search or fast_beam_search_nbest_oracle""", ) parser.add_argument( @@ -172,7 +175,7 @@ def get_parser(): type=int, default=4, help="""Used only when --decoding-method is - fast_beam_search""", + fast_beam_search or fast_beam_search_nbest_oracle""", ) parser.add_argument( @@ -180,7 +183,7 @@ def get_parser(): type=int, default=8, help="""Used only when --decoding-method is - fast_beam_search""", + fast_beam_search or fast_beam_search_nbest_oracle""", ) parser.add_argument( @@ -198,6 +201,23 @@ def get_parser(): Used only when --decoding_method is greedy_search""", ) + parser.add_argument( + "--num-paths", + type=int, + default=100, + help="""Number of paths for computed nbest oracle WER + when the decoding method is fast_beam_search_nbest_oracle. + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding_method is fast_beam_search_nbest_oracle. + """, + ) return parser @@ -231,7 +251,8 @@ def decode_one_batch( for the format of the `batch`. decoding_graph: The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. + only when --decoding_method is + fast_beam_search or fast_beam_search_nbest_oracle. Returns: Return the decoding result. See above description for the format of the returned dict. 
@@ -252,7 +273,7 @@ def decode_one_batch( hyps = [] if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search( + hyp_tokens = fast_beam_search_one_best( model=model, decoding_graph=decoding_graph, encoder_out=encoder_out, @@ -263,6 +284,21 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) elif ( params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1 @@ -316,6 +352,16 @@ def decode_one_batch( f"max_states_{params.max_states}" ): hyps } + elif params.decoding_method == "fast_beam_search_nbest_oracle": + return { + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}_" + f"num_paths_{params.num_paths}_" + f"nbest_scale_{params.nbest_scale}" + ): hyps + } else: return {f"beam_size_{params.beam_size}": hyps} @@ -450,15 +496,22 @@ def main(): "greedy_search", "beam_search", "fast_beam_search", + "fast_beam_search_nbest_oracle", "modified_beam_search", ) params.res_dir = params.exp_dir / params.decoding_method params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search": params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" + elif params.decoding_method == "fast_beam_search_nbest_oracle": + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + 
params.suffix += f"-max-states-{params.max_states}" + params.suffix += f"-num-paths-{params.num_paths}" + params.suffix += f"-nbest-scale-{params.nbest_scale}" elif "beam_search" in params.decoding_method: params.suffix += f"-beam-{params.beam_size}" else: @@ -479,6 +532,7 @@ def main(): # is defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("") + params.unk_id = sp.unk_id() params.vocab_size = sp.get_piece_size() logging.info(params) @@ -506,8 +560,12 @@ def main(): model.to(device) model.eval() model.device = device + model.unk_id = params.unk_id - if params.decoding_method == "fast_beam_search": + if params.decoding_method in ( + "fast_beam_search", + "fast_beam_search_nbest_oracle", + ): decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None From b54d9a256d936f71d751ff7d5b4b4c02216bcf2a Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sun, 24 Apr 2022 15:25:34 +0800 Subject: [PATCH 2/2] Minor fixes. --- .../ASR/pruned_transducer_stateless2/beam_search.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 10bd7bf7e..ad492aaa5 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -146,12 +146,7 @@ def fast_beam_search_nbest_oracle( nbest_scale=nbest_scale, ) - # We assume the labels of nbest.fsa are token IDs and the aux_labels - # are word IDs. - word_fsa = k2.invert(nbest.fsa) - word_ids = get_texts(word_fsa, return_ragged=True) - - hyps = k2.levenshtein_graph(word_ids) + hyps = nbest.build_levenshtein_graphs() refs = k2.levenshtein_graph(ref_texts, device=hyps.device) levenshtein_alignment = k2.levenshtein_alignment(