From c82db4184a395177f4c2a79f1f20d7d3508777b2 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Wed, 16 Mar 2022 15:50:11 +0800
Subject: [PATCH] Remove xscale from pos_embedding

---
 egs/librispeech/ASR/conformer_ctc/subsampling.py      | 6 +++---
 egs/librispeech/ASR/transducer_stateless/conformer.py | 2 --
 2 files changed, 3 insertions(+), 5 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py
index 7c7d0ee6c..867ababf2 100644
--- a/egs/librispeech/ASR/conformer_ctc/subsampling.py
+++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py
@@ -449,7 +449,7 @@ class ScaledLinear(nn.Linear):
         fan_in = self.weight.shape[1] * self.weight[0][0].numel()
         scale = fan_in ** -0.5  # 1/sqrt(fan_in)
         with torch.no_grad():
-            self.weight_scale += (torch.tensor(scale / 0.05).log() / self.scale_speed)
+            self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed)
 
     def get_weight(self):
         return self.weight * (self.weight_scale * self.scale_speed).exp()
@@ -485,7 +485,7 @@ class ScaledConv1d(nn.Conv1d):
         fan_in = self.weight.shape[1] * self.weight[0][0].numel()
         scale = fan_in ** -0.5  # 1/sqrt(fan_in)
         with torch.no_grad():
-            self.weight_scale += (torch.tensor(scale / 0.05).log() / self.scale_speed)
+            self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed)
 
     def get_weight(self):
@@ -527,7 +527,7 @@ class ScaledConv2d(nn.Conv2d):
         fan_in = self.weight.shape[1] * self.weight[0][0].numel()
         scale = fan_in ** -0.5  # 1/sqrt(fan_in)
         with torch.no_grad():
-            self.weight_scale += (torch.tensor(scale / 0.05).log() / self.scale_speed)
+            self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed)
 
     def get_weight(self):
diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py
index 0832d9385..b14e83780 100644
--- a/egs/librispeech/ASR/transducer_stateless/conformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/conformer.py
@@ -327,7 +327,6 @@ class RelPositionalEncoding(torch.nn.Module):
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
-        self.xscale = math.sqrt(self.d_model)
         self.dropout = torch.nn.Dropout(p=dropout_rate)
         self.pe = None
         self.extend_pe(torch.tensor(0.0).expand(1, max_len))
@@ -379,7 +378,6 @@ class RelPositionalEncoding(torch.nn.Module):
         """
         self.extend_pe(x)
-        x = x * self.xscale
         pos_emb = self.pe[
             :,
             self.pe.size(1) // 2
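
For context, a minimal, self-contained sketch (not the icefall source) of the reparameterization the subsampling.py hunks fix. The module stores a log-domain `weight_scale` alongside the raw weight, and `_reset_parameters` must divide by the actual init stddev `std` rather than the hard-coded `0.05`, otherwise the effective weight magnitude is wrong whenever a caller initializes with a different stddev. The `scale_speed=5.0` and `std=0.1` defaults below are illustrative assumptions, not the library's values.

import torch
import torch.nn as nn


class ScaledLinear(nn.Linear):
    """Sketch: effective weight is weight * exp(weight_scale * scale_speed)."""

    def __init__(self, *args, scale_speed: float = 5.0, std: float = 0.1, **kwargs):
        super().__init__(*args, **kwargs)
        self.scale_speed = scale_speed
        self.weight_scale = nn.Parameter(torch.zeros(()))
        self._reset_parameters(std)

    def _reset_parameters(self, std: float) -> None:
        nn.init.normal_(self.weight, mean=0.0, std=std)
        # fan_in computed as in the patch; for nn.Linear this is in_features.
        fan_in = self.weight.shape[1] * self.weight[0][0].numel()
        scale = fan_in ** -0.5  # target effective stddev: 1/sqrt(fan_in)
        with torch.no_grad():
            # The fix: divide by the actual init stddev `std`, not 0.05, so
            # std * exp(weight_scale * scale_speed) == 1/sqrt(fan_in)
            # holds for any `std`.
            self.weight_scale += torch.tensor(scale / std).log() / self.scale_speed

    def get_weight(self) -> torch.Tensor:
        return self.weight * (self.weight_scale * self.scale_speed).exp()


layer = ScaledLinear(256, 256)
print(layer.get_weight().std())  # ~1/sqrt(256) = 0.0625, independent of std

The conformer.py change is independent: RelPositionalEncoding stops multiplying its input by sqrt(d_model) (the removed `self.xscale`), so any such scaling is left to the rest of the model rather than the positional-encoding module.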