From 52f1f6775de9dabc08e4aab2c97d8c4efbbed890 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Mon, 28 Mar 2022 12:33:58 +0800
Subject: [PATCH] Update beam search to support max/log_add in selecting
 duplicate hyps.

---
 .../ASR/transducer_stateless/beam_search.py | 200 +++++++++++++++++-
 .../ASR/transducer_stateless/decode.py      | 104 ++++++++-
 .../ASR/transducer_stateless/decoder.py     |   1 +
 icefall/decode.py                           |   2 +-
 4 files changed, 288 insertions(+), 19 deletions(-)

diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py
index 7b4fac31d..27ce9ddde 100644
--- a/egs/librispeech/ASR/transducer_stateless/beam_search.py
+++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py
@@ -21,6 +21,153 @@ import k2
 import torch
 from model import Transducer
 
+from icefall.decode import Nbest, one_best_decoding
+from icefall.utils import get_texts
+
+
+def fast_beam_search(
+    model: Transducer,
+    decoding_graph: k2.Fsa,
+    encoder_out: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
+    beam: float,
+    max_states: int,
+    max_contexts: int,
+    use_max: bool = False,
+) -> List[List[int]]:
+    """It limits the maximum number of symbols per frame to 1.
+
+    Args:
+      model:
+        An instance of `Transducer`.
+      decoding_graph:
+        Decoding graph used for decoding; it may be a trivial graph or an HLG.
+      encoder_out:
+        A tensor of shape (N, T, C) from the encoder.
+      encoder_out_lens:
+        A tensor of shape (N,) containing the number of frames in `encoder_out`
+        before padding.
+      beam:
+        Beam value, similar to the beam used in Kaldi.
+      max_states:
+        Max states per stream per frame.
+      max_contexts:
+        Max contexts per stream per frame.
+      use_max:
+        True to use the max operation to select the hypothesis with the largest
+        log_prob when there are duplicate hypotheses; False to use log-add.
+    Returns:
+      Return the decoded result.
+    """
+    assert encoder_out.ndim == 3
+
+    context_size = model.decoder.context_size
+    vocab_size = model.decoder.vocab_size
+
+    B, T, C = encoder_out.shape
+
+    config = k2.RnntDecodingConfig(
+        vocab_size=vocab_size,
+        decoder_history_len=context_size,
+        beam=beam,
+        max_contexts=max_contexts,
+        max_states=max_states,
+    )
+    individual_streams = []
+    for i in range(B):
+        individual_streams.append(k2.RnntDecodingStream(decoding_graph))
+    decoding_streams = k2.RnntDecodingStreams(individual_streams, config)
+
+    tmp_len = torch.tensor([1])
+
+    for t in range(T):
+        # shape is a RaggedShape of shape (B, context)
+        # contexts is a Tensor of shape (shape.NumElements(), context_size)
+        shape, contexts = decoding_streams.get_contexts()
+        # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
+        contexts = contexts.to(torch.int64)
+        # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
+        decoder_out = model.decoder(contexts, need_pad=False)
+        # current_encoder_out is of shape
+        # (shape.NumElements(), 1, encoder_out_dim)
+        # fmt: off
+        current_encoder_out = torch.index_select(
+            encoder_out[:, t:t + 1, :], 0, shape.row_ids(1)
+        )
+        # fmt: on
+        logits = model.joiner(
+            current_encoder_out,
+            decoder_out,
+            tmp_len.expand(decoder_out.size(0)),
+            tmp_len.expand(decoder_out.size(0)),
+        )
+        log_probs = logits.log_softmax(dim=-1)
+        decoding_streams.advance(log_probs)
+    decoding_streams.terminate_and_flush_to_streams()
+    lattice = decoding_streams.format_output(encoder_out_lens.tolist())
+
+    if use_max:
+        best_path = one_best_decoding(lattice)
+        hyps = get_texts(best_path)
+        return hyps
+    else:
+        num_paths = 20
+        use_double_scores = True
+        nbest_scale = 0.8
+
+        nbest = Nbest.from_lattice(
+            lattice=lattice,
+            num_paths=num_paths,
+            use_double_scores=use_double_scores,
+            nbest_scale=nbest_scale,
+        )
+        # The following code is modified from nbest.intersect()
+        word_fsa = k2.invert(nbest.fsa)
+        if hasattr(lattice, "aux_labels"):
+            # delete token IDs as they are not needed
+            del word_fsa.aux_labels
+        word_fsa.scores.zero_()
+        word_fsa_with_epsilon_loops = k2.remove_epsilon_and_add_self_loops(
+            word_fsa
+        )
+        path_to_utt_map = nbest.shape.row_ids(1)
+
+        if hasattr(lattice, "aux_labels"):
+            # lattice has token IDs as labels and word IDs as aux_labels.
+            # inv_lattice has word IDs as labels and token IDs as aux_labels
+            inv_lattice = k2.invert(lattice)
+            inv_lattice = k2.arc_sort(inv_lattice)
+        else:
+            inv_lattice = k2.arc_sort(lattice)
+
+        if inv_lattice.shape[0] == 1:
+            path_lattice = k2.intersect_device(
+                inv_lattice,
+                word_fsa_with_epsilon_loops,
+                b_to_a_map=torch.zeros_like(path_to_utt_map),
+                sorted_match_a=True,
+            )
+        else:
+            path_lattice = k2.intersect_device(
+                inv_lattice,
+                word_fsa_with_epsilon_loops,
+                b_to_a_map=path_to_utt_map,
+                sorted_match_a=True,
+            )
+
+        # path_lattice has word IDs as labels and token IDs as aux_labels
+        path_lattice = k2.top_sort(k2.connect(path_lattice))
+
+        tot_scores = path_lattice.get_tot_scores(
+            use_double_scores=True, log_semiring=True
+        )
+        ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
+        best_hyp_indexes = ragged_tot_scores.argmax()
+
+        best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes)
+        hyps = get_texts(best_path)
+        return hyps
+
 
 def greedy_search(
     model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
@@ -203,7 +350,7 @@ class HypothesisList(object):
     def data(self) -> Dict[str, Hypothesis]:
         return self._data
 
-    def add(self, hyp: Hypothesis) -> None:
+    def add(self, hyp: Hypothesis, use_max: bool = False) -> None:
         """Add a Hypothesis to `self`.
 
         If `hyp` already exists in `self`, its probability is updated using
@@ -212,13 +359,20 @@ class HypothesisList(object):
         Args:
           hyp:
             The hypothesis to be added.
+          use_max:
+            True to select the hypothesis with the larger log_prob if there
+            already exists a hypothesis whose `ys` equals `hyp.ys`;
+            False to use log-add.
         """
         key = hyp.key
         if key in self:
             old_hyp = self._data[key]  # shallow copy
-            torch.logaddexp(
-                old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
-            )
+            if use_max:
+                old_hyp.log_prob = max(old_hyp.log_prob, hyp.log_prob)
+            else:
+                torch.logaddexp(
+                    old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
+                )
         else:
             self._data[key] = hyp
 
@@ -415,6 +569,7 @@ def modified_beam_search(
     model: Transducer,
     encoder_out: torch.Tensor,
     beam: int = 4,
+    use_max: bool = False,
 ) -> List[List[int]]:
     """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
 
@@ -425,6 +580,10 @@ def modified_beam_search(
         Output from the encoder. Its shape is (N, T, C).
       beam:
         Number of active paths during the beam search.
+      use_max:
+        If True, it uses the max operation to select the hypothesis with the
+        larger log_prob when two hypotheses have the same token sequence.
+        If False, it uses log-add.
     Returns:
       Return a list-of-list of token IDs. ans[i] is the decoding results
       for the i-th utterance.
@@ -443,7 +602,8 @@ def modified_beam_search(
             Hypothesis(
                 ys=[blank_id] * context_size,
                 log_prob=torch.zeros(1, dtype=torch.float32, device=device),
-            )
+            ),
+            use_max=use_max,
         )
 
     encoder_out_len = torch.tensor([1])
@@ -519,7 +679,7 @@ def modified_beam_search(
             new_log_prob = topk_log_probs[k]
 
             new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
-            B[i].add(new_hyp)
+            B[i].add(new_hyp, use_max=use_max)
 
     best_hyps = [b.get_most_probable(length_norm=True) for b in B]
     ans = [h.ys[context_size:] for h in best_hyps]
@@ -531,6 +691,7 @@ def _deprecated_modified_beam_search(
     model: Transducer,
     encoder_out: torch.Tensor,
     beam: int = 4,
+    use_max: bool = False,
 ) -> List[int]:
     """It limits the maximum number of symbols per frame to 1.
 
@@ -545,6 +706,10 @@ def _deprecated_modified_beam_search(
         A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
       beam:
         Beam size.
+      use_max:
+        If True, it uses the max operation to select the hypothesis with the
+        larger log_prob when two hypotheses have the same token sequence.
+        If False, it uses log-add.
     Returns:
       Return the decoded result.
     """
@@ -565,7 +730,8 @@ def _deprecated_modified_beam_search(
         Hypothesis(
             ys=[blank_id] * context_size,
             log_prob=torch.zeros(1, dtype=torch.float32, device=device),
-        )
+        ),
+        use_max=use_max,
     )
 
     encoder_out_len = torch.tensor([1])
@@ -624,7 +790,7 @@ def _deprecated_modified_beam_search(
                 new_ys.append(new_token)
             new_log_prob = topk_log_probs[i]
             new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
-            B.add(new_hyp)
+            B.add(new_hyp, use_max=use_max)
 
     best_hyp = B.get_most_probable(length_norm=True)
     ys = best_hyp.ys[context_size:]  # [context_size:] to remove blanks
@@ -636,6 +802,7 @@ def beam_search(
     model: Transducer,
     encoder_out: torch.Tensor,
     beam: int = 4,
+    use_max: bool = False,
 ) -> List[int]:
     """
     It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
@@ -649,6 +816,10 @@ def beam_search(
         A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
       beam:
         Beam size.
+      use_max:
+        If True, it uses the max operation to select the hypothesis with the
+        larger log_prob when two hypotheses have the same token sequence.
+        If False, it uses log-add.
     Returns:
       Return the decoded result.
     """
@@ -675,7 +846,8 @@ def beam_search(
         Hypothesis(
             ys=[blank_id] * context_size,
             log_prob=torch.zeros(1, dtype=torch.float32, device=device),
-        )
+        ),
+        use_max=use_max,
     )
 
     max_sym_per_utt = 20000
@@ -721,7 +893,10 @@ def beam_search(
             new_y_star_log_prob = y_star.log_prob + skip_log_prob
 
             # ys[:] returns a copy of ys
-            B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))
+            B.add(
+                Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob),
+                use_max=use_max,
+            )
 
             # Second, process other non-blank labels
             values, indices = log_prob.topk(beam + 1)
@@ -733,7 +908,10 @@ def beam_search(
 
                 new_ys = y_star.ys + [i]
                 new_log_prob = y_star.log_prob + values[idx]
-                A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))
+                A.add(
+                    Hypothesis(ys=new_ys, log_prob=new_log_prob),
+                    use_max=use_max,
+                )
 
         # Check whether B contains more than "beam" elements more probable
         # than the most probable in A
diff --git a/egs/librispeech/ASR/transducer_stateless/decode.py b/egs/librispeech/ASR/transducer_stateless/decode.py
index ac66c9b49..7be52b183 100755
--- a/egs/librispeech/ASR/transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless/decode.py
@@ -22,7 +22,8 @@ Usage:
     --epoch 14 \
     --avg 7 \
     --exp-dir ./transducer_stateless/exp \
-    --max-duration 100 \
+    --max-duration 1000 \
+    --max-sym-per-frame 1 \
     --decoding-method greedy_search
 
 (2) beam search
@@ -30,7 +31,7 @@ Usage:
     --epoch 14 \
     --avg 7 \
     --exp-dir ./transducer_stateless/exp \
-    --max-duration 100 \
+    --max-duration 1000 \
     --decoding-method beam_search \
     --beam-size 4
 
@@ -39,9 +40,20 @@ Usage:
     --epoch 14 \
     --avg 7 \
     --exp-dir ./transducer_stateless/exp \
-    --max-duration 100 \
+    --max-duration 1000 \
     --decoding-method modified_beam_search \
     --beam-size 4
+
+(4) fast beam search
+./transducer_stateless/decode.py \
+    --epoch 14 \
+    --avg 7 \
+    --exp-dir ./transducer_stateless/exp \
+    --max-duration 1000 \
+    --decoding-method fast_beam_search \
+    --beam 4 \
+    --max-contexts 4 \
+    --max-states 8
 """
 
 
@@ -49,14 +61,16 @@ import argparse
 import logging
 from collections import defaultdict
 from pathlib import Path
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple
 
+import k2
 import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from beam_search import (
     beam_search,
+    fast_beam_search,
     greedy_search,
     greedy_search_batch,
     modified_beam_search,
@@ -68,6 +82,7 @@ from icefall.utils import (
     AttributeDict,
     setup_logger,
     store_transcripts,
+    str2bool,
     write_error_stats,
 )
 
@@ -115,6 +130,7 @@ def get_parser():
           - greedy_search
           - beam_search
           - modified_beam_search
+          - fast_beam_search
         """,
     )
 
@@ -126,6 +142,32 @@ def get_parser():
         beam_search or modified_beam_search""",
     )
 
+    parser.add_argument(
+        "--beam",
+        type=float,
+        default=4,
+        help="""A floating point value to calculate the cutoff score during beam
+        search (i.e., `cutoff = max-score - beam`), which is the same as the
+        `beam` in Kaldi.
+        Used only when --decoding-method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-contexts",
+        type=int,
+        default=4,
+        help="""Used only when --decoding-method is
+        fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-states",
+        type=int,
+        default=8,
+        help="""Used only when --decoding-method is
+        fast_beam_search""",
+    )
+
     parser.add_argument(
         "--context-size",
         type=int,
@@ -141,6 +183,17 @@ def get_parser():
         Used only when --decoding-method is greedy_search""",
     )
 
+    parser.add_argument(
+        "--use-max",
+        type=str2bool,
+        default=False,
+        help="""If True, use the max operation to select the hypothesis that
+        has the max log_prob in case of duplicate hypotheses.
+        If False, use log-add.
+        Used only for beam_search, modified_beam_search, and fast_beam_search.
+        """,
+    )
+
     return parser
 
 
@@ -149,6 +202,7 @@ def decode_one_batch(
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
     batch: dict,
+    decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -190,7 +244,18 @@ def decode_one_batch(
     )
 
     hyp_list: List[List[int]] = []
-    if (
+    if params.decoding_method == "fast_beam_search":
+        hyp_list = fast_beam_search(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            use_max=params.use_max,
+        )
+    elif (
         params.decoding_method == "greedy_search"
         and params.max_sym_per_frame == 1
     ):
@@ -203,6 +268,7 @@ def decode_one_batch(
             model=model,
             encoder_out=encoder_out,
             beam=params.beam_size,
+            use_max=params.use_max,
         )
     else:
         batch_size = encoder_out.size(0)
@@ -221,6 +287,7 @@ def decode_one_batch(
                     model=model,
                     encoder_out=encoder_out_i,
                     beam=params.beam_size,
+                    use_max=params.use_max,
                 )
             else:
                 raise ValueError(
@@ -232,6 +299,14 @@ def decode_one_batch(
 
     if params.decoding_method == "greedy_search":
         return {"greedy_search": hyps}
+    elif params.decoding_method == "fast_beam_search":
+        return {
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
+        }
     else:
         return {f"beam_{params.beam_size}": hyps}
 
@@ -241,6 +316,7 @@ def decode_dataset(
     params: AttributeDict,
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
+    decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
     """Decode dataset.
@@ -281,6 +357,7 @@ def decode_dataset(
             model=model,
             sp=sp,
             batch=batch,
+            decoding_graph=decoding_graph,
         )
 
         for name, hyps in hyps_dict.items():
@@ -358,15 +435,22 @@ def main():
     params.update(vars(args))
 
     assert params.decoding_method in (
-        "greedy_search",
         "beam_search",
+        "fast_beam_search",
+        "greedy_search",
         "modified_beam_search",
     )
     params.res_dir = params.exp_dir / params.decoding_method
 
     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
-    if "beam_search" in params.decoding_method:
+    if "fast_beam_search" == params.decoding_method:
+        params.suffix += f"-beam-{params.beam}"
+        params.suffix += f"-max-contexts-{params.max_contexts}"
+        params.suffix += f"-max-states-{params.max_states}"
+        params.suffix += f"-use-max-{params.use_max}"
+    elif "beam_search" in params.decoding_method:
         params.suffix += f"-beam-{params.beam_size}"
+        params.suffix += f"-use-max-{params.use_max}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -408,6 +492,11 @@ def main():
     model.eval()
     model.device = device
 
+    if params.decoding_method == "fast_beam_search":
+        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+    else:
+        decoding_graph = None
+
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
@@ -428,6 +517,7 @@ def main():
             params=params,
             model=model,
             sp=sp,
+            decoding_graph=decoding_graph,
         )
 
         save_results(
diff --git a/egs/librispeech/ASR/transducer_stateless/decoder.py b/egs/librispeech/ASR/transducer_stateless/decoder.py
index b82fed37b..fbc2373a9 100644
--- a/egs/librispeech/ASR/transducer_stateless/decoder.py
+++ b/egs/librispeech/ASR/transducer_stateless/decoder.py
@@ -58,6 +58,7 @@ class Decoder(nn.Module):
             padding_idx=blank_id,
         )
         self.blank_id = blank_id
+        self.vocab_size = vocab_size
 
         assert context_size >= 1, context_size
         self.context_size = context_size
diff --git a/icefall/decode.py b/icefall/decode.py
index d3e420eec..f60b1738d 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -218,7 +218,7 @@ class Nbest(object):
         # word_seq is a k2.RaggedTensor sharing the same shape as `path`
         # but it contains word IDs. Note that it also contains 0s and -1s.
         # The last entry in each sublist is -1.
-        # It axes is [utt][path][word_id]
+        # Its axes are [utt][path][word_id]
         if isinstance(lattice.aux_labels, torch.Tensor):
             word_seq = k2.ragged.index(lattice.aux_labels, path)
         else:
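
To illustrate the merge rule that `HypothesisList.add` now implements: when a new hypothesis carries the same `ys` as an existing one, `use_max=False` combines the two scores with log-add (equivalent to summing the probabilities of the two paths), while `use_max=True` keeps only the better-scoring copy. A minimal standalone sketch of the two modes, with toy scores chosen purely for illustration:

    import torch

    # Two copies of the same token sequence reach the beam with
    # different log-probabilities.
    log_p1 = torch.tensor([-1.5])
    log_p2 = torch.tensor([-2.0])

    # use_max=False: log-add, log(exp(-1.5) + exp(-2.0)) ~= -1.026
    merged_log_add = torch.logaddexp(log_p1, log_p2)

    # use_max=True: keep the single best copy, i.e. -1.5
    merged_max = max(log_p1, log_p2)

Since log-add can only raise the score of a duplicate, `use_max=True` gives the more conservative estimate of a duplicate hypothesis's total probability.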
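
And a hedged sketch of how the new `fast_beam_search` might be driven from user code, mirroring what `decode_one_batch` does in this patch. Here `model`, `encoder_out`, and `encoder_out_lens` are assumed to come from an already-loaded transducer and its encoder; the beam settings mirror the defaults added to `decode.py`:

    import k2
    from beam_search import fast_beam_search

    # Assumed: `model` is a Transducer whose decoder exposes `vocab_size`,
    # and `encoder_out` (N, T, C) / `encoder_out_lens` (N,) were produced
    # by model.encoder on a padded batch.
    decoding_graph = k2.trivial_graph(
        model.decoder.vocab_size - 1, device=encoder_out.device
    )

    hyp_tokens = fast_beam_search(
        model=model,
        decoding_graph=decoding_graph,
        encoder_out=encoder_out,
        encoder_out_lens=encoder_out_lens,
        beam=4.0,        # score cutoff, as --beam
        max_contexts=4,  # as --max-contexts
        max_states=8,    # as --max-states
        use_max=False,   # log-add over duplicate hypotheses
    )
    # hyp_tokens[i] holds the decoded token IDs for the i-th utterance.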