From d5c78a2238a4b6baeddf2f741c23cd0de11f5d2f Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Tue, 22 Mar 2022 10:32:22 +0800
Subject: [PATCH] Implement greedy search in batch mode for transducer
 decoding. (#262)

---
 .../beam_search.py                            | 60 ++++++++++++++++++-
 .../ASR/pruned_transducer_stateless/decode.py | 11 ++++
 2 files changed, 70 insertions(+), 1 deletion(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py
index 651854999..05b027214 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py
@@ -106,7 +106,7 @@ def fast_beam_search(
 def greedy_search(
     model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
 ) -> List[int]:
-    """
+    """Greedy search for a single utterance.
     Args:
       model:
         An instance of `Transducer`.
@@ -178,6 +178,64 @@ def greedy_search(
     return hyp
 
 
+def greedy_search_batch(
+    model: Transducer, encoder_out: torch.Tensor
+) -> List[List[int]]:
+    """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
+    Args:
+      model:
+        The transducer model.
+      encoder_out:
+        Output from the encoder. Its shape is (N, T, C), where N >= 1.
+    Returns:
+      Return a list of lists of integers containing the decoded results.
+      len(ans) equals encoder_out.size(0).
+    """
+    assert encoder_out.ndim == 3
+    assert encoder_out.size(0) >= 1, encoder_out.size(0)
+
+    device = model.device
+
+    batch_size = encoder_out.size(0)
+    T = encoder_out.size(1)
+
+    blank_id = model.decoder.blank_id
+    context_size = model.decoder.context_size
+
+    hyps = [[blank_id] * context_size for _ in range(batch_size)]
+
+    decoder_input = torch.tensor(
+        hyps,
+        device=device,
+        dtype=torch.int64,
+    )  # (batch_size, context_size)
+
+    decoder_out = model.decoder(decoder_input, need_pad=False)
+    # decoder_out: (batch_size, 1, decoder_out_dim)
+    for t in range(T):
+        current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2)  # noqa
+        # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
+        logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1))
+        # logits' shape: (batch_size, 1, 1, vocab_size)
+
+        logits = logits.squeeze(1).squeeze(1)  # (batch_size, vocab_size)
+        assert logits.ndim == 2, logits.shape
+        y = logits.argmax(dim=1).tolist()
+        emitted = False
+        for i, v in enumerate(y):
+            if v != blank_id:
+                hyps[i].append(v)
+                emitted = True
+        if emitted:
+            # Re-run the decoder only when at least one hypothesis grew
+            # in this frame; otherwise decoder_out is unchanged.
+            decoder_input = [h[-context_size:] for h in hyps]
+            decoder_input = torch.tensor(decoder_input, device=device)
+            decoder_out = model.decoder(decoder_input, need_pad=False)
+
+    ans = [h[context_size:] for h in hyps]
+    return ans
+
+
 @dataclass
 class Hypothesis:
     # The predicted tokens so far.
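A note on usage: greedy_search_batch only touches model.device, model.decoder (through its blank_id and context_size attributes), and model.joiner, so it can be exercised in isolation. The sketch below drives it with minimal stand-in modules to make the expected tensor shapes concrete. The Dummy* classes, their dimensions, and the bare "from beam_search import" (which assumes you run from the pruned_transducer_stateless directory) are illustrative assumptions, not part of icefall.

    import torch
    import torch.nn as nn

    from beam_search import greedy_search_batch  # the function added above


    class DummyDecoder(nn.Module):
        # Stand-in for the stateless decoder: embeds the last context_size
        # tokens and averages the embeddings into one frame per utterance.
        def __init__(self, vocab_size: int, dim: int, context_size: int, blank_id: int):
            super().__init__()
            self.embedding = nn.Embedding(vocab_size, dim)
            self.blank_id = blank_id
            self.context_size = context_size

        def forward(self, y: torch.Tensor, need_pad: bool = False) -> torch.Tensor:
            # y: (batch_size, context_size) -> (batch_size, 1, dim)
            return self.embedding(y).mean(dim=1, keepdim=True)


    class DummyJoiner(nn.Module):
        # Stand-in for the joiner: combines an encoder frame with a decoder
        # frame and projects the result to vocabulary logits.
        def __init__(self, dim: int, vocab_size: int):
            super().__init__()
            self.out = nn.Linear(dim, vocab_size)

        def forward(self, enc: torch.Tensor, dec: torch.Tensor) -> torch.Tensor:
            # enc: (N, 1, 1, dim), dec: (N, 1, 1, dim) -> (N, 1, 1, vocab_size)
            return self.out(torch.tanh(enc + dec))


    class DummyTransducer(nn.Module):
        # Exposes exactly the attributes greedy_search_batch reads:
        # .device, .decoder (with .blank_id / .context_size), and .joiner.
        def __init__(self, vocab_size: int = 500, dim: int = 64, context_size: int = 2):
            super().__init__()
            self.decoder = DummyDecoder(vocab_size, dim, context_size, blank_id=0)
            self.joiner = DummyJoiner(dim, vocab_size)

        @property
        def device(self) -> torch.device:
            return next(self.parameters()).device


    model = DummyTransducer()
    encoder_out = torch.randn(4, 30, 64)  # (N, T, C): 4 utterances, 30 frames each
    hyps = greedy_search_batch(model=model, encoder_out=encoder_out)
    assert len(hyps) == 4  # one token-id list per utterance, blanks removed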
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
index ad76411c0..c43af9741 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
@@ -71,6 +71,7 @@ from beam_search import (
     beam_search,
     fast_beam_search,
     greedy_search,
+    greedy_search_batch,
     modified_beam_search,
 )
 from train import get_params, get_transducer_model
@@ -261,6 +262,16 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
+        hyp_tokens = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
     else:
         batch_size = encoder_out.size(0)
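The new branch can hand greedy_search_batch's return value straight to sp.decode because SentencePieceProcessor.decode accepts a list of id lists and returns one string per utterance. A minimal sketch, continuing from the stand-in model above; the BPE model path is a placeholder, not a path this patch introduces.

    import sentencepiece as spm

    sp = spm.SentencePieceProcessor()
    sp.load("path/to/bpe.model")  # placeholder: any trained BPE model

    # hyp_tokens: List[List[int]], as returned by greedy_search_batch
    hyp_tokens = greedy_search_batch(model=model, encoder_out=encoder_out)
    texts = sp.decode(hyp_tokens)            # one decoded string per utterance
    hyps = [text.split() for text in texts]  # word lists, as decode_one_batch builds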