diff --git a/egs/librispeech/ASR/conformer_ctc/conformer.py b/egs/librispeech/ASR/conformer_ctc/conformer.py
index 5a4dfa686..90552545b 100644
--- a/egs/librispeech/ASR/conformer_ctc/conformer.py
+++ b/egs/librispeech/ASR/conformer_ctc/conformer.py
@@ -181,6 +181,7 @@ class ConformerEncoderLayer(nn.Module):
 
         # define layernorm for conv1d_abs
         self.norm_conv_abs = nn.LayerNorm(d_model)
+        self.layernorm = nn.LayerNorm(d_model)
 
         self.ff_scale = 0.5
 
@@ -193,14 +194,17 @@ class ConformerEncoderLayer(nn.Module):
 
         self.normalize_before = normalize_before
 
-        self.in_channel = 1024
-        self.out_channel = 64
         self.kernel_size = 21
-        self.padding = 10
-        self.linear1 = nn.Linear(512, self.in_channel)
-        self.conv1d_abs = Conv1dAbs(self.in_channel, self.out_channel, kernel_size=self.kernel_size, padding=self.padding)
-        self.activation = nn.ReLU()
-        self.linear2 = nn.Linear(self.out_channel, 512)
+        self.padding = int((self.kernel_size - 1) / 2)
+        self.conv1d_channels = 768
+        self.linear1 = nn.Linear(512, self.conv1d_channels)
+        self.conv1d_abs = Conv1dAbs(
+            self.conv1d_channels,
+            self.conv1d_channels,
+            kernel_size=self.kernel_size,
+            padding=self.padding,
+        )
+        self.linear2 = nn.Linear(self.conv1d_channels, 512)
 
     def forward(
         self,
@@ -236,17 +240,33 @@ class ConformerEncoderLayer(nn.Module):
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
-        # modified-attention module
+        inf = torch.tensor(float("inf"), device=src.device)
+
+        def check_inf(x):
+            if x.max() == inf:
+                print("Error: inf found: ", x)
+                assert 0
+
+        # modified-attention module
+
         residual = src
         if self.normalize_before:
             src = self.norm_conv_abs(src)
-        src = self.linear1(src)
-        src = torch.exp(src.clamp(max=75))
+
+        src = self.linear1(src * 0.25)
+        src = torch.exp(src.clamp(min=-75, max=75))
+        check_inf(src)
         src = src.permute(1, 2, 0)
         src = self.conv1d_abs(src) / self.kernel_size
-        src = self.activation(src).permute(2, 0, 1)
-        src = torch.log(src)
+        check_inf(src)
+        src = src.permute(2, 0, 1)
+        src = torch.log(src.clamp(min=1e-20))
+        check_inf(src)
         src = self.linear2(src)
+        src = self.layernorm(src)
+        # multiply the output by 0.5 later.
+        # do a comparison.
+
         src = residual + self.dropout(src)
         if not self.normalize_before:
             src = self.norm_conv_abs(src)
@@ -387,8 +407,8 @@ class RelPositionalEncoding(torch.nn.Module):
         pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
 
         # Reserve the order of positive indices and concat both positive and
-        # negative indices. This is used to support the shifting trick
-        # as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+        # negative indices. This is used to support the shifting trick as in
+        # "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
         pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
         pe_negative = pe_negative[1:].unsqueeze(0)
         pe = torch.cat([pe_positive, pe_negative], dim=1)
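Note: `Conv1dAbs` is defined elsewhere in this file and is not shown in the diff. As a rough, hedged sketch of what the modified-attention path above computes, the snippet below approximates it end to end, assuming `Conv1dAbs` behaves like `nn.Conv1d` evaluated with the absolute value of its weights (so strictly positive inputs stay positive). The names `Conv1dAbsSketch` and `modified_attention_sketch` are illustrative only; dropout and the `check_inf` debugging calls are omitted.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Conv1dAbsSketch(nn.Conv1d):
    """Hypothetical stand-in for Conv1dAbs: a 1-D convolution applied with the
    absolute value of its weights, so strictly positive inputs stay positive."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.conv1d(
            x,
            self.weight.abs(),
            self.bias.abs() if self.bias is not None else None,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )


def modified_attention_sketch(src: torch.Tensor) -> torch.Tensor:
    """Approximate the exp -> positive conv -> log path added in the diff.

    src: (T, N, 512), as in ConformerEncoderLayer.forward; same shape returned.
    Layers are created ad hoc here purely for illustration.
    """
    d_model, channels, kernel_size = 512, 768, 21
    linear1 = nn.Linear(d_model, channels)
    conv1d_abs = Conv1dAbsSketch(
        channels, channels, kernel_size, padding=(kernel_size - 1) // 2
    )
    linear2 = nn.Linear(channels, d_model)
    layernorm = nn.LayerNorm(d_model)

    x = linear1(src * 0.25)
    x = torch.exp(x.clamp(min=-75, max=75))  # keep exp() finite
    x = x.permute(1, 2, 0)                   # (T, N, C) -> (N, C, T) for conv1d
    x = conv1d_abs(x) / kernel_size          # non-negative local averaging over time
    x = x.permute(2, 0, 1)                   # back to (T, N, C)
    x = torch.log(x.clamp(min=1e-20))        # avoid log(0)
    x = layernorm(linear2(x))
    return src + x                           # residual connection (dropout omitted)


if __name__ == "__main__":
    out = modified_attention_sketch(torch.randn(100, 2, 512))
    print(out.shape)  # torch.Size([100, 2, 512])
```

Read this way, the block is a log-domain smoothing over a 21-frame window: exponentiation makes the features positive, the absolute-weight convolution averages them without sign cancellation, and the log maps them back, which is why the diff clamps both the exp input and the log input for numerical stability.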