diff --git a/egs/librispeech/ASR/conformer_ctc/conformer.py b/egs/librispeech/ASR/conformer_ctc/conformer.py index 91f8cf694..5a4dfa686 100644 --- a/egs/librispeech/ASR/conformer_ctc/conformer.py +++ b/egs/librispeech/ASR/conformer_ctc/conformer.py @@ -193,10 +193,14 @@ class ConformerEncoderLayer(nn.Module): self.normalize_before = normalize_before - self.linear1 = nn.Linear(512, 1024) - self.conv1d_abs = Conv1dAbs(1024, 64, kernel_size=21, padding=10) + self.in_channel = 1024 + self.out_channel = 64 + self.kernel_size = 21 + self.padding = 10 + self.linear1 = nn.Linear(512, self.in_channel) + self.conv1d_abs = Conv1dAbs(self.in_channel, self.out_channel, kernel_size=self.kernel_size, padding=self.padding) self.activation = nn.ReLU() - self.linear2 = nn.Linear(64, 512) + self.linear2 = nn.Linear(self.out_channel, 512) def forward( self, @@ -239,7 +243,7 @@ class ConformerEncoderLayer(nn.Module): src = self.linear1(src) src = torch.exp(src.clamp(max=75)) src = src.permute(1, 2, 0) - src = self.conv1d_abs(src) + src = self.conv1d_abs(src) / self.kernel_size src = self.activation(src).permute(2, 0, 1) src = torch.log(src) src = self.linear2(src)