diff --git a/egs/librispeech/ASR/conformer_ctc/conformer.py b/egs/librispeech/ASR/conformer_ctc/conformer.py
index 4a68355e2..75beeb529 100644
--- a/egs/librispeech/ASR/conformer_ctc/conformer.py
+++ b/egs/librispeech/ASR/conformer_ctc/conformer.py
@@ -199,15 +199,16 @@ class ConformerEncoderLayer(nn.Module):

         self.kernel_size = 31
         self.padding = int((self.kernel_size - 1) / 2)
-        self.conv1d_channels = 768
-        self.linear1 = nn.Linear(512, self.conv1d_channels)
+        self.in_conv1d_channels = 768
+        self.out_conv1d_channels = 768
+        self.linear1 = nn.Linear(512, self.in_conv1d_channels)
         self.conv1d_abs = Conv1dAbs(
-            self.conv1d_channels,
-            self.conv1d_channels,
+            self.in_conv1d_channels,
+            self.out_conv1d_channels,
             kernel_size=self.kernel_size,
             padding=self.padding,
         )
-        self.linear2 = nn.Linear(self.conv1d_channels, 512)
+        self.linear2 = nn.Linear(self.out_conv1d_channels, 512)

     def forward(
         self,
@@ -260,27 +261,20 @@ class ConformerEncoderLayer(nn.Module):
             src = self.norm_mha(src)

         # conv1dabs modified attention module
-        inf = torch.tensor(float("inf"), device=src.device)
-        def check_inf(x):
-            if x.max() == inf:
-                print("Error: inf found: ", x)
-                assert 0
         residual = src
         if self.normalize_before:
             src = self.norm_conv_abs(src)
-        src = self.linear1(src * 0.25)
+
+        # src = self.linear1(src * 0.25)
+        src = 0.01 * self.linear1(src * 0.25)
         src = torch.exp(src.clamp(min=-75, max=75))
-        check_inf(src)
         src = src.permute(1, 2, 0)
         src = self.conv1d_abs(src) / self.kernel_size
-        check_inf(src)
         src = src.permute(2, 0, 1)
         src = torch.log(src.clamp(min=1e-20))
-        check_inf(src)
         src = self.linear2(src)
-        src = self.layernorm(src)
-        # multipy the output by 0.5 later.
-        # do a comparison.
+        src = 0.25 * self.layernorm(src)
+
         src = residual + self.dropout(src)
         if not self.normalize_before:
             src = self.norm_conv_abs(src)
@@ -871,6 +865,7 @@ class RelPositionMultiheadAttention(nn.Module):
         else:
             return attn_output, None

+
 class ConvolutionModule(nn.Module):
     """ConvolutionModule in Conformer model.
     Modified from
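
Note (not part of the patch): the sketch below is a minimal, self-contained approximation of what the rewritten conv1d-abs branch computes after this change, i.e. a scaled linear projection, an exp, a positive 1-D convolution over time, and a log, which amounts to a smoothed local log-sum-exp pooling followed by a damped layernorm and residual add. Conv1dAbs itself does not appear in this diff; the Conv1dAbsSketch stand-in here (an nn.Conv1d whose weights are used via their absolute values, with no bias) and the default hyperparameters are assumptions for illustration, not the repository's implementation.

# Standalone sketch of the modified conv1d-abs branch (illustration only).
# Assumption: Conv1dAbs is approximated by convolving with |weight| and no bias,
# so a strictly positive input stays strictly positive and the log() is safe.

import torch
import torch.nn as nn
import torch.nn.functional as F


class Conv1dAbsSketch(nn.Module):
    """Hypothetical stand-in for Conv1dAbs: a Conv1d applied with the absolute
    value of its weights and no bias, preserving positivity of the input."""

    def __init__(self, in_channels, out_channels, kernel_size, padding):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels, out_channels, kernel_size, padding=padding, bias=False
        )

    def forward(self, x):
        # Convolve with non-negative weights so positive inputs remain positive.
        return F.conv1d(x, self.conv.weight.abs(), padding=self.conv.padding[0])


class ConvAbsAttentionSketch(nn.Module):
    """Mirrors the post-patch branch: scaled linear -> exp -> positive conv over
    time -> log -> linear -> scaled layernorm -> residual, i.e. a smoothed local
    log-sum-exp pool. Defaults (512, 768, 31) follow the values in the diff."""

    def __init__(self, d_model=512, channels=768, kernel_size=31, dropout=0.1):
        super().__init__()
        padding = (kernel_size - 1) // 2
        self.kernel_size = kernel_size
        self.linear1 = nn.Linear(d_model, channels)
        self.conv1d_abs = Conv1dAbsSketch(channels, channels, kernel_size, padding)
        self.linear2 = nn.Linear(channels, d_model)
        self.layernorm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):  # src: (time, batch, d_model)
        residual = src
        x = 0.01 * self.linear1(src * 0.25)        # small scale keeps exp() in range
        x = torch.exp(x.clamp(min=-75, max=75))    # strictly positive activations
        x = x.permute(1, 2, 0)                     # (batch, channels, time) for conv
        x = self.conv1d_abs(x) / self.kernel_size  # positive-weight local averaging
        x = x.permute(2, 0, 1)                     # back to (time, batch, channels)
        x = torch.log(x.clamp(min=1e-20))          # log of a positive mixture
        x = self.linear2(x)
        x = 0.25 * self.layernorm(x)               # damped output, as in the patch
        return residual + self.dropout(x)


if __name__ == "__main__":
    m = ConvAbsAttentionSketch()
    out = m(torch.randn(100, 4, 512))  # (time=100, batch=4, d_model=512)
    print(out.shape)                   # torch.Size([100, 4, 512])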