diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py
index 3d4e69a4b..06751ae06 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py
@@ -38,6 +38,7 @@ class Decoder(nn.Module):
         embedding_dim: int,
         blank_id: int,
         context_size: int,
+        modify_embedding_layer: bool,
     ):
         """
         Args:
@@ -50,6 +51,8 @@ class Decoder(nn.Module):
           context_size:
             Number of previous words to use to predict the next word.
             1 means bigram; 2 means trigram. n means (n+1)-gram.
+          modify_embedding_layer:
+            Whether to modify the decoder embedding layer.
         """
         super().__init__()
         self.embedding = nn.Embedding(
@@ -57,6 +60,19 @@ class Decoder(nn.Module):
             embedding_dim=embedding_dim,
             padding_idx=blank_id,
         )
+        self.modify_embedding_layer = modify_embedding_layer
+        if self.modify_embedding_layer:
+            # Re-initialize the embedding weights with std
+            # 1/sqrt(embedding_dim); forward() scales the output back
+            # up by sqrt(embedding_dim), as in the Transformer paper.
+            self.embedding.weight = nn.Parameter(
+                (
+                    torch.randn(vocab_size, embedding_dim)
+                    * (embedding_dim ** -0.5)
+                ).detach()
+            )
+
+        self.embedding_dim = embedding_dim
         self.blank_id = blank_id
 
         assert context_size >= 1, context_size
@@ -84,6 +100,11 @@ class Decoder(nn.Module):
           Return a tensor of shape (N, U, embedding_dim).
         """
         embedding_out = self.embedding(y)
+        if self.modify_embedding_layer:
+            # Undo the down-scaled initialization so the output
+            # variance stays close to 1.
+            embedding_out = embedding_out * (self.embedding_dim ** 0.5)
+
         if self.context_size > 1:
             embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py
index f0ea2ccaa..06eb58eb4 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py
@@ -179,6 +179,14 @@ def get_parser():
         "with this parameter before adding to the final loss.",
     )
 
+    parser.add_argument(
+        "--modify-embedding",
+        type=str2bool,
+        default=True,
+        help="When True, we modify the decoder embedding layer. "
+        "When False, we don't modify the decoder embedding layer.",
+    )
+
     parser.add_argument(
         "--seed",
         type=int,
@@ -284,6 +292,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
         embedding_dim=params.embedding_dim,
         blank_id=params.blank_id,
         context_size=params.context_size,
+        modify_embedding_layer=params.modify_embedding,
     )
     return decoder
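
For context, below is a standalone sketch of the scaling scheme this patch implements. The class name ScaledEmbedding and the sizes used are hypothetical and not part of icefall: the weights are initialized with std embedding_dim ** -0.5 and the lookup result is multiplied by embedding_dim ** 0.5 (the embedding scaling used in "Attention Is All You Need"), so the stored weights stay small while the output fed to the rest of the network keeps variance near 1.

# Standalone sketch (hypothetical; not part of icefall).
import torch
import torch.nn as nn


class ScaledEmbedding(nn.Module):
    """Embedding whose weights are initialized with std
    embedding_dim ** -0.5 and whose output is scaled by
    embedding_dim ** 0.5, keeping the output variance near 1."""

    def __init__(self, vocab_size: int, embedding_dim: int):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # Down-scale the initial weights ...
        nn.init.normal_(
            self.embedding.weight, mean=0.0, std=embedding_dim ** -0.5
        )

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # ... and scale the lookup result back up.
        return self.embedding(y) * (self.embedding_dim ** 0.5)


if __name__ == "__main__":
    emb = ScaledEmbedding(vocab_size=500, embedding_dim=512)
    y = torch.randint(0, 500, (2, 10))
    out = emb(y)
    print(out.shape)         # torch.Size([2, 10, 512])
    print(out.std().item())  # close to 1.0

The two halves only make sense together, which is why the patch guards both the re-initialization in __init__ and the multiplication in forward() with the same modify_embedding_layer flag.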