diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py
index 49004d2ba..d7e52da30 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py
@@ -229,7 +229,11 @@ def greedy_search_batch(
         if emitted:
             # update decoder output
             decoder_input = [h[-context_size:] for h in hyps]
-            decoder_input = torch.tensor(decoder_input, device=device)
+            decoder_input = torch.tensor(
+                decoder_input,
+                device=device,
+                dtype=torch.int64,
+            )
             decoder_out = model.decoder(decoder_input, need_pad=False)
 
     ans = [h[context_size:] for h in hyps]
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
index 811e74ad7..8e924bf96 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
@@ -192,7 +192,7 @@ def get_parser():
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
-        default=3,
+        default=1,
         help="""Maximum number of symbols per frame.
         Used only when --decoding_method is greedy_search""",
     )
diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py
index c5efb733d..6a82e99db 100644
--- a/egs/librispeech/ASR/transducer_stateless/beam_search.py
+++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py
@@ -24,7 +24,7 @@ from model import Transducer
 def greedy_search(
     model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
 ) -> List[int]:
-    """
+    """Greedy search for a single utterance.
     Args:
       model:
         An instance of `Transducer`.
@@ -80,7 +80,7 @@ def greedy_search(
         logits = model.joiner(
             current_encoder_out, decoder_out, encoder_out_len, decoder_out_len
         )
-        # logits is (1, 1, 1, vocab_size)
+        # logits is (1, vocab_size)
 
         y = logits.argmax().item()
         if y != blank_id:
@@ -101,6 +101,75 @@ def greedy_search(
     return hyp
 
 
+def greedy_search_batch(
+    model: Transducer, encoder_out: torch.Tensor
+) -> List[List[int]]:
+    """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
+    Args:
+      model:
+        The transducer model.
+      encoder_out:
+        Output from the encoder. Its shape is (N, T, C), where N >= 1.
+    Returns:
+      Return a list-of-list of token IDs containing the decoded results.
+      len(ans) equals to encoder_out.size(0).
+    """
+    assert encoder_out.ndim == 3
+    assert encoder_out.size(0) >= 1, encoder_out.size(0)
+
+    device = model.device
+
+    batch_size = encoder_out.size(0)
+    T = encoder_out.size(1)
+
+    blank_id = model.decoder.blank_id
+    context_size = model.decoder.context_size
+
+    hyps = [[blank_id] * context_size for _ in range(batch_size)]
+
+    decoder_input = torch.tensor(
+        hyps,
+        device=device,
+        dtype=torch.int64,
+    )  # (batch_size, context_size)
+    decoder_out = model.decoder(decoder_input, need_pad=False)
+    # decoder_out: (batch_size, 1, decoder_out_dim)
+
+    encoder_out_len = torch.ones(batch_size, dtype=torch.int32)
+    decoder_out_len = torch.ones(batch_size, dtype=torch.int32)
+
+    for t in range(T):
+        current_encoder_out = encoder_out[:, t : t + 1, :]  # noqa
+        # current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
+        logits = model.joiner(
+            current_encoder_out, decoder_out, encoder_out_len, decoder_out_len
+        )  # (batch_size, vocab_size)
+
+        assert logits.ndim == 2, logits.shape
+        y = logits.argmax(dim=1).tolist()
+        emitted = False
+        for i, v in enumerate(y):
+            if v != blank_id:
+                hyps[i].append(v)
+                emitted = True
+
+        if emitted:
+            # update decoder output
+            decoder_input = [h[-context_size:] for h in hyps]
+            decoder_input = torch.tensor(
+                decoder_input,
+                device=device,
+                dtype=torch.int64,
+            )  # (batch_size, context_size)
+            decoder_out = model.decoder(
+                decoder_input,
+                need_pad=False,
+            )  # (batch_size, 1, decoder_out_dim)
+
+    ans = [h[context_size:] for h in hyps]
+    return ans
+
+
 @dataclass
 class Hypothesis:
     # The predicted tokens so far.
@@ -252,9 +321,11 @@ def run_decoder(
 
     device = model.device
 
-    decoder_input = torch.tensor([ys[-context_size:]], device=device).reshape(
-        1, context_size
-    )
+    decoder_input = torch.tensor(
+        [ys[-context_size:]],
+        device=device,
+        dtype=torch.int64,
+    ).reshape(1, context_size)
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
     decoder_cache[key] = decoder_out
@@ -341,12 +412,6 @@ def modified_beam_search(
 
     device = model.device
 
-    decoder_input = torch.tensor(
-        [blank_id] * context_size, device=device
-    ).reshape(1, context_size)
-
-    decoder_out = model.decoder(decoder_input, need_pad=False)
-
     T = encoder_out.size(1)
 
     B = HypothesisList()
diff --git a/egs/librispeech/ASR/transducer_stateless/decode.py b/egs/librispeech/ASR/transducer_stateless/decode.py
index 6803d1721..8aba2289f 100755
--- a/egs/librispeech/ASR/transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless/decode.py
@@ -55,8 +55,13 @@ import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from beam_search import beam_search, greedy_search, modified_beam_search
-from train import get_transducer_model, get_params
+from beam_search import (
+    beam_search,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
+from train import get_params, get_transducer_model
 
 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.utils import (
@@ -131,7 +136,7 @@ def get_parser():
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
-        default=3,
+        default=1,
         help="""Maximum number of symbols per frame.
         Used only when --decoding_method is greedy_search""",
     )
@@ -183,32 +188,47 @@ def decode_one_batch(
     encoder_out, encoder_out_lens = model.encoder(
         x=feature, x_lens=feature_lens
     )
-    hyps = []
-    batch_size = encoder_out.size(0)
+    hyp_list: List[List[int]] = []
 
-    for i in range(batch_size):
-        # fmt: off
-        encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
-        # fmt: on
-        if params.decoding_method == "greedy_search":
-            hyp = greedy_search(
-                model=model,
-                encoder_out=encoder_out_i,
-                max_sym_per_frame=params.max_sym_per_frame,
-            )
-        elif params.decoding_method == "beam_search":
-            hyp = beam_search(
-                model=model, encoder_out=encoder_out_i, beam=params.beam_size
-            )
-        elif params.decoding_method == "modified_beam_search":
-            hyp = modified_beam_search(
-                model=model, encoder_out=encoder_out_i, beam=params.beam_size
-            )
-        else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.decoding_method}"
-            )
-        hyps.append(sp.decode(hyp).split())
+    if (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
+        hyp_list = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+        )
+    else:
+        batch_size = encoder_out.size(0)
+        for i in range(batch_size):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.decoding_method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.decoding_method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            elif params.decoding_method == "modified_beam_search":
+                hyp = modified_beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported decoding method: {params.decoding_method}"
+                )
+            hyp_list.append(hyp)
+
+    hyps = [sp.decode(hyp).split() for hyp in hyp_list]
 
     if params.decoding_method == "greedy_search":
         return {"greedy_search": hyps}