diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
index fc1285dc7..9cf86ed60 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
@@ -335,7 +335,9 @@ def greedy_search(
 
 
 def greedy_search_batch(
-    model: Transducer, encoder_out: torch.Tensor
+    model: Transducer,
+    encoder_out: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
 ) -> List[List[int]]:
     """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
     Args:
@@ -343,6 +345,9 @@ def greedy_search_batch(
         The transducer model.
       encoder_out:
         Output from the encoder. Its shape is (N, T, C), where N >= 1.
+      encoder_out_lens:
+        A 1-D tensor of shape (N,), containing number of valid frames in
+        encoder_out before padding.
     Returns:
       Return a list-of-list of token IDs containing the decoded results.
      len(ans) equals to encoder_out.size(0).
@@ -350,31 +355,49 @@
     assert encoder_out.ndim == 3
     assert encoder_out.size(0) >= 1, encoder_out.size(0)
 
-    device = next(model.parameters()).device
+    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
+        input=encoder_out,
+        lengths=encoder_out_lens.cpu(),
+        batch_first=True,
+        enforce_sorted=False,
+    )
 
-    batch_size = encoder_out.size(0)
-    T = encoder_out.size(1)
+    device = next(model.parameters()).device
 
     blank_id = model.decoder.blank_id
     unk_id = getattr(model, "unk_id", blank_id)
     context_size = model.decoder.context_size
 
-    hyps = [[blank_id] * context_size for _ in range(batch_size)]
+    batch_size_list = packed_encoder_out.batch_sizes.tolist()
+    N = encoder_out.size(0)
+    assert torch.all(encoder_out_lens > 0), encoder_out_lens
+    assert N == batch_size_list[0], (N, batch_size_list)
+
+    hyps = [[blank_id] * context_size for _ in range(N)]
 
     decoder_input = torch.tensor(
         hyps,
         device=device,
         dtype=torch.int64,
-    )  # (batch_size, context_size)
+    )  # (N, context_size)
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
     decoder_out = model.joiner.decoder_proj(decoder_out)
-    encoder_out = model.joiner.encoder_proj(encoder_out)
+    # decoder_out: (N, 1, decoder_out_dim)
 
-    # decoder_out: (batch_size, 1, decoder_out_dim)
-    for t in range(T):
-        current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2)  # noqa
+    encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
+
+    offset = 0
+    for batch_size in batch_size_list:
+        start = offset
+        end = offset + batch_size
+        current_encoder_out = encoder_out.data[start:end]
+        current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
         # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
+        offset = end
+
+        decoder_out = decoder_out[:batch_size]
+
         logits = model.joiner(
             current_encoder_out, decoder_out.unsqueeze(1), project_input=False
         )
@@ -390,7 +413,7 @@
                 emitted = True
         if emitted:
             # update decoder output
-            decoder_input = [h[-context_size:] for h in hyps]
+            decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
            decoder_input = torch.tensor(
                 decoder_input,
                 device=device,
@@ -399,7 +422,12 @@
             decoder_out = model.decoder(decoder_input, need_pad=False)
             decoder_out = model.joiner.decoder_proj(decoder_out)
 
-    ans = [h[context_size:] for h in hyps]
+    sorted_ans = [h[context_size:] for h in hyps]
+    ans = []
+    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
+    for i in range(N):
+        ans.append(sorted_ans[unsorted_indices[i]])
+
     return ans
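Review note on the greedy_search_batch hunks above, before moving on to modified_beam_search below: `torch.nn.utils.rnn.pack_padded_sequence` with `enforce_sorted=False` reorders the utterances by decreasing length and stores frames time-major, so `batch_sizes[t]` is the number of utterances that still have a valid frame at step `t`, and `unsorted_indices[i]` is the row of original utterance `i` in the length-sorted batch; the final `sorted_ans[unsorted_indices[i]]` remap relies on exactly that. A minimal self-contained sketch (toy tensors and values, not part of the patch):

```python
import torch

# Three toy "utterances" with 2, 4, and 3 valid frames, padded to T=4.
encoder_out = torch.arange(12, dtype=torch.float32).reshape(3, 4, 1)
encoder_out_lens = torch.tensor([2, 4, 3])

packed = torch.nn.utils.rnn.pack_padded_sequence(
    input=encoder_out,
    lengths=encoder_out_lens.cpu(),
    batch_first=True,
    enforce_sorted=False,
)

print(packed.batch_sizes)       # tensor([3, 3, 2, 1]): active utterances per frame
print(packed.unsorted_indices)  # tensor([2, 0, 1]): row of original utterance i
                                # in the length-sorted batch

# The patched decoding loop slices packed.data exactly like this:
offset = 0
for batch_size in packed.batch_sizes.tolist():
    current_frames = packed.data[offset : offset + batch_size]  # (batch_size, C)
    # Rows >= batch_size belong to utterances that have run out of frames.
    offset += batch_size
```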
@@ -557,6 +585,7 @@ def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
 def modified_beam_search(
     model: Transducer,
     encoder_out: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
     beam: int = 4,
 ) -> List[List[int]]:
     """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
@@ -566,6 +595,9 @@ def modified_beam_search(
         The transducer model.
       encoder_out:
         Output from the encoder. Its shape is (N, T, C).
+      encoder_out_lens:
+        A 1-D tensor of shape (N,), containing number of valid frames in
+        encoder_out before padding.
       beam:
         Number of active paths during the beam search.
     Returns:
@@ -573,16 +605,26 @@
       for the i-th utterance.
     """
     assert encoder_out.ndim == 3, encoder_out.shape
-
-    batch_size = encoder_out.size(0)
-    T = encoder_out.size(1)
+    assert encoder_out.size(0) >= 1, encoder_out.size(0)
+    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
+        input=encoder_out,
+        lengths=encoder_out_lens.cpu(),
+        batch_first=True,
+        enforce_sorted=False,
+    )
 
     blank_id = model.decoder.blank_id
     unk_id = getattr(model, "unk_id", blank_id)
     context_size = model.decoder.context_size
     device = next(model.parameters()).device
-    B = [HypothesisList() for _ in range(batch_size)]
-    for i in range(batch_size):
+
+    batch_size_list = packed_encoder_out.batch_sizes.tolist()
+    N = encoder_out.size(0)
+    assert torch.all(encoder_out_lens > 0), encoder_out_lens
+    assert N == batch_size_list[0], (N, batch_size_list)
+
+    B = [HypothesisList() for _ in range(N)]
+    for i in range(N):
         B[i].add(
             Hypothesis(
                 ys=[blank_id] * context_size,
@@ -590,11 +632,20 @@
             )
         )
 
-    encoder_out = model.joiner.encoder_proj(encoder_out)
+    encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
 
-    for t in range(T):
-        current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2)  # noqa
+    offset = 0
+    finalized_B = []
+    for batch_size in batch_size_list:
+        start = offset
+        end = offset + batch_size
+        current_encoder_out = encoder_out.data[start:end]
+        current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
         # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
+        offset = end
+
+        finalized_B = B[batch_size:] + finalized_B
+        B = B[:batch_size]
 
         hyps_shape = _get_hyps_shape(B).to(device)
 
@@ -668,8 +719,14 @@
                 new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
                 B[i].add(new_hyp)
 
-    best_hyps = [b.get_most_probable(length_norm=True) for b in B]
-    ans = [h.ys[context_size:] for h in best_hyps]
+    B = B + finalized_B
+    best_hyps = [b.get_most_probable(length_norm=True) for b in B]
+
+    sorted_ans = [h.ys[context_size:] for h in best_hyps]
+    ans = []
+    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
+    for i in range(N):
+        ans.append(sorted_ans[unsorted_indices[i]])
 
     return ans
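The `finalized_B` bookkeeping above deserves a note: because `batch_sizes` is non-increasing, the utterances whose frames are exhausted at each step are always a suffix of the length-sorted batch. A toy sketch with plain lists (the names are illustrative stand-ins, not patch code) showing that the peel-off-and-prepend dance restores the sorted order at the end:

```python
# Stand-ins for one HypothesisList per utterance, in length-sorted order.
B = ["hyps_utt0", "hyps_utt1", "hyps_utt2"]
batch_size_list = [3, 2, 1]  # e.g. sorted valid-frame counts [3, 2, 1]

finalized_B = []
for batch_size in batch_size_list:
    # Utterances past batch_size are finished; set them aside, order kept.
    finalized_B = B[batch_size:] + finalized_B
    B = B[:batch_size]  # only still-active utterances take part in this step

B = B + finalized_B
print(B)  # ['hyps_utt0', 'hyps_utt1', 'hyps_utt2']: sorted order restored,
          # ready for the unsorted_indices remap at the end of the function
```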
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py
index 5d946003a..97ae79845 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py
@@ -270,6 +270,7 @@ def decode_one_batch(
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -277,6 +278,7 @@ def decode_one_batch(
         hyp_tokens = modified_beam_search(
             model=model,
             encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
         )
         for hyp in sp.decode(hyp_tokens):
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
index 9a6b5a117..346ba4471 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
@@ -307,6 +307,7 @@ def decode_one_batch(
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -314,6 +315,7 @@ def decode_one_batch(
         hyp_tokens = modified_beam_search(
             model=model,
             encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
         )
         for hyp in sp.decode(hyp_tokens):
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
index 1f4a22213..4097e55f5 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
@@ -284,6 +284,7 @@ def decode_one_batch(
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -291,6 +292,7 @@ def decode_one_batch(
         hyp_tokens = modified_beam_search(
             model=model,
             encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
         )
         for hyp in sp.decode(hyp_tokens):
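The decode.py changes are mechanical: each call site forwards the lengths that decode_one_batch already gets back from the encoder. A hedged sketch of the surrounding context (variable names follow the recipes, but this is an illustration, not a verbatim excerpt):

```python
# Illustrative only: model, feature, feature_lens, and params come from the
# enclosing decode_one_batch() in the recipes.
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)

if params.decoding_method == "modified_beam_search":
    hyp_tokens = modified_beam_search(
        model=model,
        encoder_out=encoder_out,
        encoder_out_lens=encoder_out_lens,  # valid frame counts, not padded T
        beam=params.beam_size,
    )
```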