pad 1 with torch.nn.functional.pad

This commit is contained in:
Mingshuang Luo 2022-01-26 13:00:18 +08:00
parent cbc6f01861
commit 7df98eb000
2 changed files with 8 additions and 12 deletions

View File

@ -207,7 +207,7 @@ class ConformerEncoderLayer(nn.Module):
# kernel size=21, self.conv1d_channels=768
# kernel size=5, self.conv1d_channels=1024
self.linear1 = nn.Linear(512, self.in_conv1d_channels)
self.conv1d_abs = Conv1dAbs(self.in_conv1d_channels, self.out_conv1d_channels, kernel_size=self.kernel_size, padding=self.padding, padding_mode="replicate")
self.conv1d_abs = Conv1dAbs(self.in_conv1d_channels, self.out_conv1d_channels, kernel_size=self.kernel_size, padding=self.padding, padding_mode="constant")
self.linear2 = nn.Linear(self.out_conv1d_channels, 512)
def forward(
@ -333,7 +333,6 @@ class ConformerEncoder(nn.TransformerEncoder):
src_key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
r"""Pass the input through the encoder layers in turn.
Args:
src: the sequence to the encoder (required).
pos_emb: Positional embedding tensor (required).

View File

@ -153,10 +153,6 @@ class _ConvNd(Module):
if not hasattr(self, "padding_mode"):
self.padding_mode = "zeros"
import torch
import torch.nn as nn
m = nn.Tanh()
def padding(input, padding_length):
# input shape : (B, T, D)
device = input.device
@ -201,12 +197,13 @@ class Conv1dAbs(_ConvNd):
def forward(self, input: Tensor) -> Tensor:
if self.padding_mode != "zeros":
return F.conv1d(
# F.pad(
# input,
# self._reversed_padding_repeated_twice,
# mode=self.padding_mode,
# ),
padding(input, self.padding),
F.pad(
input,
(0, 0, self.padding, self.padding),
self.padding_mode,
1,
).permute(0, 2, 1),
#padding(input, self.padding),
torch.exp(self.weight),
torch.exp(self.bias),
self.stride,