From afec6b6cae7fab2bd5638923b477df15c9892853 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 21 Dec 2021 11:37:56 +0800 Subject: [PATCH] Update greedy search for modified decoder. --- .../ASR/transducer_stateless/beam_search.py | 25 +++++++++++-------- .../ASR/transducer_stateless/decoder.py | 4 +-- .../ASR/transducer_stateless/test_decoder.py | 3 +-- 3 files changed, 18 insertions(+), 14 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py index 88f23e922..056e7c372 100644 --- a/egs/librispeech/ASR/transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py @@ -36,13 +36,17 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]: # support only batch_size == 1 for now assert encoder_out.size(0) == 1, encoder_out.size(0) blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = model.device - sos = torch.tensor([blank_id], device=device).reshape(1, 1) - decoder_out = model.decoder(sos) + sos = torch.tensor([blank_id] * context_size, device=device).reshape( + 1, context_size + ) + decoder_out = model.decoder(sos, need_pad=False) T = encoder_out.size(1) t = 0 - hyp = [] + hyp = [blank_id] * context_size sym_per_frame = 0 sym_per_utt = 0 @@ -57,14 +61,14 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]: logits = model.joiner(current_encoder_out, decoder_out) # logits is (1, 1, 1, vocab_size) - log_prob = logits.log_softmax(dim=-1) - # log_prob is (1, 1, 1, vocab_size) - # TODO: Use logits.argmax() - y = log_prob.argmax() + y = logits.argmax().item() if y != blank_id: - hyp.append(y.item()) - y = y.reshape(1, 1) - decoder_out = model.decoder(y) + hyp.append(y) + decoder_input = torch.tensor( + [hyp[-context_size:]], device=device + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) sym_per_utt += 1 sym_per_frame += 1 @@ -72,6 +76,7 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]: if y == blank_id or sym_per_frame > max_sym_per_frame: sym_per_frame = 0 t += 1 + hyp = hyp[context_size:] # remove blanks return hyp diff --git a/egs/librispeech/ASR/transducer_stateless/decoder.py b/egs/librispeech/ASR/transducer_stateless/decoder.py index 0773ce37b..4b1ec6ee6 100644 --- a/egs/librispeech/ASR/transducer_stateless/decoder.py +++ b/egs/librispeech/ASR/transducer_stateless/decoder.py @@ -70,7 +70,7 @@ class Decoder(nn.Module): bias=False, ) - def forward(self, y: torch.Tensor) -> torch.Tensor: + def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: """ Args: y: @@ -81,7 +81,7 @@ class Decoder(nn.Module): embeding_out = self.embedding(y) if self.context_size > 1: embeding_out = embeding_out.permute(0, 2, 1) - if self.training is True: + if need_pad is True: embeding_out = F.pad( embeding_out, pad=(self.context_size - 1, 0) ) diff --git a/egs/librispeech/ASR/transducer_stateless/test_decoder.py b/egs/librispeech/ASR/transducer_stateless/test_decoder.py index 532aaf776..fa6632eb7 100755 --- a/egs/librispeech/ASR/transducer_stateless/test_decoder.py +++ b/egs/librispeech/ASR/transducer_stateless/test_decoder.py @@ -47,9 +47,8 @@ def test_decoder(): assert y.shape == (N, U, embedding_dim) # for inference - decoder.eval() x = torch.randint(low=0, high=vocab_size, size=(N, context_size)) - y = decoder(x) + y = decoder(x, need_pad=False) assert y.shape == (N, 1, embedding_dim)