From 634931cb61134fef2d70b3f09d80b395fd94ec16 Mon Sep 17 00:00:00 2001
From: JinZr <60612200+JinZr@users.noreply.github.com>
Date: Thu, 27 Jul 2023 17:52:49 +0800
Subject: [PATCH] minor updates

---
 .../beam_search.py                           | 86 ++++++++++++++++++-
 .../ASR/zipformer_label_level_algn/joiner.py |  5 +-
 .../ASR/zipformer_label_level_algn/model.py  |  7 +-
 3 files changed, 91 insertions(+), 7 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
index 16279d05d..13509021e 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
@@ -779,6 +779,78 @@ def greedy_search_batch(
     )
 
 
+def deprecated_greedy_search_batch(
+    model: nn.Module, encoder_out: torch.Tensor
+) -> List[List[int]]:
+    """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
+    Args:
+      model:
+        The transducer model.
+      encoder_out:
+        Output from the encoder. Its shape is (N, T, C), where N >= 1.
+    Returns:
+      Return a list-of-list of token IDs containing the decoded results.
+      len(ans) equals to encoder_out.size(0).
+    """
+    assert encoder_out.ndim == 3
+    assert encoder_out.size(0) >= 1, encoder_out.size(0)
+
+    device = next(model.parameters()).device
+
+    batch_size = encoder_out.size(0)
+    T = encoder_out.size(1)
+
+    blank_id = model.decoder.blank_id
+    unk_id = getattr(model, "unk_id", blank_id)
+    context_size = model.decoder.context_size
+
+    hyps = [[blank_id] * context_size for _ in range(batch_size)]
+
+    decoder_input = torch.tensor(
+        hyps,
+        device=device,
+        dtype=torch.int64,
+    )  # (batch_size, context_size)
+
+    decoder_out = model.decoder(decoder_input, need_pad=False)
+    decoder_out = model.joiner.decoder_proj(decoder_out)
+    encoder_out = model.joiner.encoder_proj(encoder_out)
+
+    # decoder_out: (batch_size, 1, decoder_out_dim)
+    for t in range(T):
+        current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2)  # noqa
+        # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
+        logits = model.joiner(
+            current_encoder_out, decoder_out.unsqueeze(1), project_input=False
+        )
+        print(current_encoder_out)
+        print(decoder_out.unsqueeze(1))
+        print(logits)
+        exit()
+        # logits'shape (batch_size, 1, 1, vocab_size)
+        logits = logits.squeeze(1).squeeze(1)  # (batch_size, vocab_size)
+        assert logits.ndim == 2, logits.shape
+        y = logits.argmax(dim=1).tolist()
+        emitted = False
+        for i, v in enumerate(y):
+            if v not in (blank_id, unk_id):
+                hyps[i].append(v)
+                emitted = True
+        if emitted:
+            # update decoder output
+            decoder_input = [h[-context_size:] for h in hyps]
+            decoder_input = torch.tensor(
+                decoder_input,
+                device=device,
+                dtype=torch.int64,
+            )
+            decoder_out = model.decoder(decoder_input, need_pad=False)
+            decoder_out = model.joiner.decoder_proj(decoder_out)
+
+    ans = [h[context_size:] for h in hyps]
+    return ans
+
+
 def deprecated_greedy_search_batch_for_cross_attn(
     model: nn.Module,
     encoder_out: torch.Tensor,
@@ -822,23 +894,29 @@ def deprecated_greedy_search_batch_for_cross_attn(
 
     # decoder_out: (batch_size, 1, decoder_out_dim)
     for t in range(T):
-        # current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2)  # noqa
+        current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2)  # noqa
         # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
-        current_encoder_out = model.joiner.label_level_am_attention(
+        attn_encoder_out = model.joiner.label_level_am_attention(
             encoder_out.unsqueeze(2),
             decoder_out.unsqueeze(2),
-            # encoder_out_lens,
-            None,
+            encoder_out_lens,
         )
+        # print(encoder_out[:, t : t + 1, :].unsqueeze(2))
+        # current_encoder_out = torch.zeros_like(current_encoder_out)
         logits = model.joiner(
             current_encoder_out,
             decoder_out.unsqueeze(1),
+            torch.zeros_like(current_encoder_out),
             None,
             apply_attn=False,
             project_input=False,
         )
         # logits'shape (batch_size, 1, 1, vocab_size)
         logits = logits.squeeze(1).squeeze(1)  # (batch_size, vocab_size)
+        # print(current_encoder_out)
+        # print(decoder_out)
+        # print(logits)
+        # # exit()
         assert logits.ndim == 2, logits.shape
         y = logits.argmax(dim=1).tolist()
         emitted = False
diff --git a/egs/librispeech/ASR/zipformer_label_level_algn/joiner.py b/egs/librispeech/ASR/zipformer_label_level_algn/joiner.py
index 040f4f40e..205819544 100644
--- a/egs/librispeech/ASR/zipformer_label_level_algn/joiner.py
+++ b/egs/librispeech/ASR/zipformer_label_level_algn/joiner.py
@@ -39,6 +39,7 @@ class Joiner(nn.Module):
         self,
         encoder_out: torch.Tensor,
         decoder_out: torch.Tensor,
+        attn_encoder_out: torch.Tensor,
         lengths: torch.Tensor,
         apply_attn: bool = True,
         project_input: bool = True,
@@ -64,14 +65,14 @@ class Joiner(nn.Module):
             )
 
         if apply_attn and lengths is not None:
-            encoder_out = self.label_level_am_attention(
+            attn_encoder_out = self.label_level_am_attention(
                 encoder_out, decoder_out, lengths
             )
 
         if project_input:
             logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
         else:
-            logit = encoder_out + decoder_out
+            logit = encoder_out + decoder_out + attn_encoder_out
 
         logit = self.output_linear(torch.tanh(logit))
 
diff --git a/egs/librispeech/ASR/zipformer_label_level_algn/model.py b/egs/librispeech/ASR/zipformer_label_level_algn/model.py
index c502224f6..53abdd21a 100644
--- a/egs/librispeech/ASR/zipformer_label_level_algn/model.py
+++ b/egs/librispeech/ASR/zipformer_label_level_algn/model.py
@@ -265,7 +265,12 @@ class AsrModel(nn.Module):
         # project_input=False since we applied the decoder's input projections
         # prior to do_rnnt_pruning (this is an optimization for speed).
         logits = self.joiner(
-            am_pruned, lm_pruned, encoder_out_lens, apply_attn=True, project_input=False
+            am_pruned,
+            lm_pruned,
+            None,
+            encoder_out_lens,
+            apply_attn=True,
+            project_input=False,
        )
 
         with torch.cuda.amp.autocast(enabled=False):
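
Note (not part of the commit above): the patch widens the signature of Joiner.forward so every caller now passes an explicit attn_encoder_out tensor (or None) between decoder_out and lengths, which is why both beam_search.py and model.py call sites change. Below is a minimal sketch of that calling convention only; TinyJoiner, its dimensions, and the stubbed attention term are hypothetical stand-ins for the recipe's real Joiner and label_level_am_attention, not the actual icefall modules.

    from typing import Optional

    import torch
    import torch.nn as nn


    class TinyJoiner(nn.Module):
        """Toy joiner mirroring the patched argument order:
        (encoder_out, decoder_out, attn_encoder_out, lengths, ...)."""

        def __init__(self, dim: int, vocab_size: int):
            super().__init__()
            self.output_linear = nn.Linear(dim, vocab_size)

        def forward(
            self,
            encoder_out: torch.Tensor,
            decoder_out: torch.Tensor,
            attn_encoder_out: torch.Tensor,
            lengths: Optional[torch.Tensor],
            apply_attn: bool = True,
        ) -> torch.Tensor:
            if apply_attn and lengths is not None:
                # Placeholder for label_level_am_attention(...) in the real recipe.
                attn_encoder_out = encoder_out
            # With project_input=False in the patch, the attention term is simply
            # added to the joiner input before the output projection.
            logit = encoder_out + decoder_out + attn_encoder_out
            return self.output_linear(torch.tanh(logit))


    joiner = TinyJoiner(dim=512, vocab_size=500)
    enc = torch.randn(8, 1, 1, 512)  # (batch, 1, 1, joiner_dim), shapes illustrative
    dec = torch.randn(8, 1, 1, 512)
    # Mirrors the greedy-search call site: attention disabled, zero attention term.
    logits = joiner(enc, dec, torch.zeros_like(enc), None, apply_attn=False)
    print(logits.shape)  # torch.Size([8, 1, 1, 500])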