diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index ce8b04afd..b89cd7638 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -279,7 +279,7 @@ def greedy_search( device = next(model.parameters()).device decoder_input = torch.tensor( - [blank_id] * context_size, device=device, dtype=torch.int64 + [-1] * (context_size - 1) + [blank_id], device=device, dtype=torch.int64 ).reshape(1, context_size) decoder_out = model.decoder(decoder_input, need_pad=False) @@ -373,7 +373,7 @@ def greedy_search_batch( assert torch.all(encoder_out_lens > 0), encoder_out_lens assert N == batch_size_list[0], (N, batch_size_list) - hyps = [[blank_id] * context_size for _ in range(N)] + hyps = [[-1] * (context_size - 1) + [blank_id] for _ in range(N)] decoder_input = torch.tensor( hyps, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py index a1c755d73..3875ab9d6 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py @@ -85,7 +85,9 @@ class Decoder(nn.Module): Return a tensor of shape (N, U, decoder_dim). """ y = y.to(torch.int64) - embedding_out = self.embedding(y) + # this stuff about clamp() is a temporary fix for a mismatch + # at utterance start, we use negative ids in beam_search.py + embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1) if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: