diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py index 62ecb16d0..1074357cd 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py @@ -52,7 +52,7 @@ class Decoder(nn.Module): Number of previous words to use to predict the next word. 1 means bigram; 2 means trigram. n means (n+1)-gram. modify_embedding_layer: - Decide to modify the decoder embedding layer or not. + Whether to modify the decoder embedding layer or not. """ super().__init__() self.embedding = nn.Embedding( @@ -61,12 +61,18 @@ class Decoder(nn.Module): padding_idx=blank_id, ) self.modify_embedding_layer = modify_embedding_layer - self.embedding_dim_inv_sqrt = embedding_dim ** 0.5 + self.embedding_dim_sqrt = embedding_dim ** 0.5 + embedding_weight = torch.randn(vocab_size, embedding_dim) * ( + 1 / self.embedding_dim_sqrt + ) if self.modify_embedding_layer: - embedding_weight = torch.randn(vocab_size, embedding_dim) * ( - 1 / self.embedding_dim_inv_sqrt + self.embedding = nn.Embedding( + num_embeddings=vocab_size, + embedding_dim=embedding_dim, + padding_idx=blank_id, + _weight=embedding_weight, ) - self.embedding.weight = nn.Parameter(embedding_weight) + self.embedding._fill_padding_idx_with_zero() self.embedding_dim = embedding_dim self.blank_id = blank_id @@ -97,7 +103,7 @@ class Decoder(nn.Module): """ embedding_out = self.embedding(y) if self.modify_embedding_layer: - embedding_out = self.embedding(y) * self.embedding_dim_inv_sqrt + embedding_out = self.embedding(y) * self.embedding_dim_sqrt if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1)