diff --git a/egs/librispeech/ASR/pruned_transducer_stateless9/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless9/decoder.py index bb460a85e..b89173158 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless9/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless9/decoder.py @@ -79,7 +79,6 @@ class Decoder(nn.Module): self, embedding_out: torch.Tensor, k: torch.Tensor, - is_training: bool = True, ) -> torch.Tensor: """ Add the repeat parameter to the embedding_out tensor. @@ -89,16 +88,13 @@ class Decoder(nn.Module): A tensor of shape (N, U, decoder_dim). k: A tensor of shape (N, U). - Should be (N, S) during training. + Should be (N, S + 1) during training. Should be (N, 1) during inference. is_training: Whether it is training. Returns: Return a tensor of shape (N, U, decoder_dim). """ - if is_training: - k = F.pad(k, (1, 0), mode="constant", value=self.blank_id) - return embedding_out + (k / (1 + k)).unsqueeze(2) * self.repeat_param def forward( @@ -140,7 +136,6 @@ class Decoder(nn.Module): embedding_out = self._add_repeat_param( embedding_out=embedding_out, k=k, - is_training=need_pad, ) embedding_out = F.relu(embedding_out) return embedding_out