diff --git a/egs/librispeech/ASR/conformer_ctc/conformer.py b/egs/librispeech/ASR/conformer_ctc/conformer.py
index 50a46ae28..2d070d937 100644
--- a/egs/librispeech/ASR/conformer_ctc/conformer.py
+++ b/egs/librispeech/ASR/conformer_ctc/conformer.py
@@ -207,7 +207,7 @@ class ConformerEncoderLayer(nn.Module):
         # kernel size=21, self.conv1d_channels=768
         # kernel size=5, self.conv1d_channels=1024
         self.linear1 = nn.Linear(512, self.in_conv1d_channels)
-        self.conv1d_abs = Conv1dAbs(self.in_conv1d_channels, self.out_conv1d_channels, kernel_size=self.kernel_size, padding=self.padding, padding_mode="replicate")
+        self.conv1d_abs = Conv1dAbs(self.in_conv1d_channels, self.out_conv1d_channels, kernel_size=self.kernel_size, padding=self.padding, padding_mode="constant")
         self.linear2 = nn.Linear(self.out_conv1d_channels, 512)
 
     def forward(
@@ -333,7 +333,6 @@ class ConformerEncoder(nn.TransformerEncoder):
         src_key_padding_mask: Optional[Tensor] = None,
     ) -> Tensor:
         r"""Pass the input through the encoder layers in turn.
-
         Args:
             src: the sequence to the encoder (required).
             pos_emb: Positional embedding tensor (required).
diff --git a/egs/librispeech/ASR/conformer_ctc/conv1d_abs_attention.py b/egs/librispeech/ASR/conformer_ctc/conv1d_abs_attention.py
index 84ff0e9fb..e3a35779c 100644
--- a/egs/librispeech/ASR/conformer_ctc/conv1d_abs_attention.py
+++ b/egs/librispeech/ASR/conformer_ctc/conv1d_abs_attention.py
@@ -153,10 +153,6 @@ class _ConvNd(Module):
         if not hasattr(self, "padding_mode"):
             self.padding_mode = "zeros"
 
-import torch
-import torch.nn as nn
-m = nn.Tanh()
-
 def padding(input, padding_length):
     # input shape : (B, T, D)
     device = input.device
@@ -201,12 +197,13 @@ class Conv1dAbs(_ConvNd):
     def forward(self, input: Tensor) -> Tensor:
         if self.padding_mode != "zeros":
             return F.conv1d(
-                # F.pad(
-                #     input,
-                #     self._reversed_padding_repeated_twice,
-                #     mode=self.padding_mode,
-                # ),
-                padding(input, self.padding),
+                F.pad(
+                    input,
+                    (0, 0, self.padding, self.padding),
+                    self.padding_mode,
+                    1,
+                ).permute(0, 2, 1),
+                # padding(input, self.padding),
                 torch.exp(self.weight),
                 torch.exp(self.bias),
                 self.stride,
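
For reference (not part of the patch), here is a minimal, self-contained sketch of what the new padding path in Conv1dAbs.forward produces. The sizes B, T, D and pad below are hypothetical, chosen only for illustration; the pad value of 1 is taken directly from the patch:

    import torch
    import torch.nn.functional as F

    B, T, D, pad = 2, 10, 16, 2  # hypothetical sizes for illustration

    x = torch.randn(B, T, D)  # (batch, time, channels), as in the encoder

    # (0, 0, pad, pad) pads the last two dims of a 3D tensor: the channel
    # axis is left untouched and the time axis gains `pad` frames of the
    # constant value 1 on each side; permute then yields the
    # (batch, channels, time) layout that F.conv1d expects.
    y = F.pad(x, (0, 0, pad, pad), "constant", 1).permute(0, 2, 1)

    print(y.shape)  # torch.Size([2, 16, 14])

This replaces the custom padding() helper with a single F.pad call; note that the padded frames hold the value 1 rather than the default 0.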