diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 414556d24..51ceac0f2 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -779,74 +779,6 @@ def greedy_search_batch( ) -def deprecated_greedy_search_batch_for_cross_attn( - model: nn.Module, encoder_out: torch.Tensor -) -> List[List[int]]: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C), where N >= 1. - Returns: - Return a list-of-list of token IDs containing the decoded results. - len(ans) equals to encoder_out.size(0). - """ - assert encoder_out.ndim == 3 - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - device = next(model.parameters()).device - - batch_size = encoder_out.size(0) - T = encoder_out.size(1) - - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - - hyps = [[blank_id] * context_size for _ in range(batch_size)] - - decoder_input = torch.tensor( - hyps, - device=device, - dtype=torch.int64, - ) # (batch_size, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - encoder_out = model.joiner.encoder_proj(encoder_out) - - # decoder_out: (batch_size, 1, decoder_out_dim) - for t in range(T): - current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa - # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) - logits = model.joiner( - current_encoder_out, decoder_out.unsqueeze(1), project_input=False - ) - # logits'shape (batch_size, 1, 1, vocab_size) - logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size) - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v not in (blank_id, unk_id): - hyps[i].append(v) - emitted = True - if emitted: - # update decoder output - decoder_input = [h[-context_size:] for h in hyps] - decoder_input = torch.tensor( - decoder_input, - device=device, - dtype=torch.int64, - ) - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - ans = [h[context_size:] for h in hyps] - return ans - - def deprecated_greedy_search_batch_for_cross_attn( model: nn.Module, encoder_out: torch.Tensor,