def modified_beam_search(
    model: Transducer,
    encoder_out: torch.Tensor,
    beam: int = 4,
) -> List[int]:
    """Beam search that emits at most one symbol per encoder frame.

    For every frame, each active hypothesis is expanded once with every
    entry of the joiner output: the entry at index `blank_id` leaves the
    symbol sequence unchanged, any other index is appended as a new symbol.
    The expanded set is then pruned with `topk(beam)` before the next frame
    is processed.

    Args:
      model:
        An instance of `Transducer`.
      encoder_out:
        A tensor of shape (N, T, C) from the encoder. Support only N==1
        for now.
      beam:
        Beam size, i.e., the argument given to `topk` after each frame.
    Returns:
      Return the decoded result: the symbols of the best hypothesis
      (selected with `length_norm=True`), without the leading
      `context_size` blanks that are used to start the search.
    """

    assert encoder_out.ndim == 3

    # support only batch_size == 1 for now
    assert encoder_out.size(0) == 1, encoder_out.size(0)
    blank_id = model.decoder.blank_id
    context_size = model.decoder.context_size

    T = encoder_out.size(1)

    # The search starts from a single hypothesis that consists of
    # `context_size` blanks; they are stripped again before returning.
    B = HypothesisList()
    B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0))

    # Every joiner call below is given exactly one encoder frame and one
    # decoder output, hence both lengths are 1.
    encoder_out_len = torch.tensor([1])
    decoder_out_len = torch.tensor([1])

    # Shared by all frames and handed to run_decoder() on every call.
    decoder_cache: Dict[str, torch.Tensor] = {}
    for t in range(T):
        # fmt: off
        current_encoder_out = encoder_out[:, t:t+1, :]
        # fmt: on

        # A holds the hypotheses surviving the previous frame, B collects
        # their expansions for the current frame.
        A = B
        B = HypothesisList()

        # Keys contain the frame index `t`, so a fresh cache per frame
        # is sufficient.
        joint_cache: Dict[str, torch.Tensor] = {}

        for hyp in A:
            decoder_out = run_decoder(
                ys=hyp.ys, model=model, decoder_cache=decoder_cache
            )
            key = "_".join(map(str, hyp.ys[-context_size:]))
            key += f"-t-{t}"
            log_prob = run_joiner(
                key=key,
                model=model,
                encoder_out=current_encoder_out,
                decoder_out=decoder_out,
                encoder_out_len=encoder_out_len,
                decoder_out_len=decoder_out_len,
                joint_cache=joint_cache,
            )
            # NOTE(review): presumably a 1-D tensor with one entry per
            # symbol -- confirm against run_joiner().
            log_prob = log_prob.cpu().tolist()

            for i, v in enumerate(log_prob):
                if i == blank_id:
                    # Use [:] to make a copy
                    new_ys = hyp.ys[:]
                else:
                    new_ys = hyp.ys + [i]
                new_hyp = Hypothesis(ys=new_ys, log_prob=hyp.log_prob + v)
                B.add(new_hyp)

        # Prune once per frame, after all hypotheses have been expanded.
        B = B.topk(beam)

    best_hyp = B.get_most_probable(length_norm=True)
    ys = best_hyp.ys[context_size:]  # [context_size:] to remove blanks

    return ys
spm import torch import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule -from beam_search import beam_search, greedy_search +from beam_search import beam_search, greedy_search, modified_beam_search from conformer import Conformer from decoder import Decoder from joiner import Joiner @@ -104,6 +104,7 @@ def get_parser(): help="""Possible values are: - greedy_search - beam_search + - modified_beam_search """, ) @@ -111,7 +112,8 @@ def get_parser(): "--beam-size", type=int, default=4, - help="Used only when --decoding-method is beam_search", + help="""Used only when --decoding-method is + beam_search or modified_beam_search""", ) parser.add_argument( @@ -125,7 +127,8 @@ def get_parser(): "--max-sym-per-frame", type=int, default=3, - help="Maximum number of symbols per frame", + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", ) return parser @@ -256,6 +259,10 @@ def decode_one_batch( hyp = beam_search( model=model, encoder_out=encoder_out_i, beam=params.beam_size ) + elif params.decoding_method == "modified_beam_search": + hyp = modified_beam_search( + model=model, encoder_out=encoder_out_i, beam=params.beam_size + ) else: raise ValueError( f"Unsupported decoding method: {params.decoding_method}" @@ -389,11 +396,15 @@ def main(): params = get_params() params.update(vars(args)) - assert params.decoding_method in ("greedy_search", "beam_search") + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "modified_beam_search", + ) params.res_dir = params.exp_dir / params.decoding_method params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - if params.decoding_method == "beam_search": + if "beam_search" in params.decoding_method: params.suffix += f"-beam-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}"