From 18b8c670dd6dbe31d8fa54791d7c14b32638b581 Mon Sep 17 00:00:00 2001 From: luomingshuang <739314837@qq.com> Date: Thu, 10 Mar 2022 10:24:23 +0800 Subject: [PATCH] do some changes --- .../ASR/pruned_transducer_stateless/decoder.py | 15 +++++++-------- .../ASR/pruned_transducer_stateless/train.py | 2 +- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py index 06751ae06..62ecb16d0 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py @@ -38,7 +38,7 @@ class Decoder(nn.Module): embedding_dim: int, blank_id: int, context_size: int, - modify_embedding_layer: bool, + modify_embedding_layer: bool = False, ): """ Args: @@ -52,7 +52,7 @@ class Decoder(nn.Module): Number of previous words to use to predict the next word. 1 means bigram; 2 means trigram. n means (n+1)-gram. modify_embedding_layer: - Do or not do modify the decoder embedding layer. + Decide to modify the decoder embedding layer or not. """ super().__init__() self.embedding = nn.Embedding( @@ -61,13 +61,12 @@ class Decoder(nn.Module): padding_idx=blank_id, ) self.modify_embedding_layer = modify_embedding_layer + self.embedding_dim_inv_sqrt = embedding_dim ** 0.5 if self.modify_embedding_layer: - self.embedding.weight = nn.Parameter( - ( - torch.randn(vocab_size, embedding_dim) - * (embedding_dim ** -0.5) - ).detach() + embedding_weight = torch.randn(vocab_size, embedding_dim) * ( + 1 / self.embedding_dim_inv_sqrt ) + self.embedding.weight = nn.Parameter(embedding_weight) self.embedding_dim = embedding_dim self.blank_id = blank_id @@ -98,7 +97,7 @@ class Decoder(nn.Module): """ embedding_out = self.embedding(y) if self.modify_embedding_layer: - embedding_out = self.embedding(y) * (self.embedding_dim ** 0.5) + embedding_out = self.embedding(y) * self.embedding_dim_inv_sqrt if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py index 06eb58eb4..4b18eefc1 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py @@ -182,7 +182,7 @@ def get_parser(): parser.add_argument( "--modify-embedding", type=bool, - default=True, + default=False, help="When True, we modify the decoder embedding layer." "When False, we don't modify the decoder embedding layer.", )