pad 1 with torch.nn.functional.pad

This commit is contained in:
Mingshuang Luo 2022-01-26 13:00:18 +08:00
parent cbc6f01861
commit 7df98eb000
2 changed files with 8 additions and 12 deletions

View File

@ -207,7 +207,7 @@ class ConformerEncoderLayer(nn.Module):
# kernel size=21, self.conv1d_channels=768 # kernel size=21, self.conv1d_channels=768
# kernel size=5, self.conv1d_channels=1024 # kernel size=5, self.conv1d_channels=1024
self.linear1 = nn.Linear(512, self.in_conv1d_channels) self.linear1 = nn.Linear(512, self.in_conv1d_channels)
self.conv1d_abs = Conv1dAbs(self.in_conv1d_channels, self.out_conv1d_channels, kernel_size=self.kernel_size, padding=self.padding, padding_mode="replicate") self.conv1d_abs = Conv1dAbs(self.in_conv1d_channels, self.out_conv1d_channels, kernel_size=self.kernel_size, padding=self.padding, padding_mode="constant")
self.linear2 = nn.Linear(self.out_conv1d_channels, 512) self.linear2 = nn.Linear(self.out_conv1d_channels, 512)
def forward( def forward(
@ -333,7 +333,6 @@ class ConformerEncoder(nn.TransformerEncoder):
src_key_padding_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None,
) -> Tensor: ) -> Tensor:
r"""Pass the input through the encoder layers in turn. r"""Pass the input through the encoder layers in turn.
Args: Args:
src: the sequence to the encoder (required). src: the sequence to the encoder (required).
pos_emb: Positional embedding tensor (required). pos_emb: Positional embedding tensor (required).

View File

@ -153,10 +153,6 @@ class _ConvNd(Module):
if not hasattr(self, "padding_mode"): if not hasattr(self, "padding_mode"):
self.padding_mode = "zeros" self.padding_mode = "zeros"
import torch
import torch.nn as nn
m = nn.Tanh()
def padding(input, padding_length): def padding(input, padding_length):
# input shape : (B, T, D) # input shape : (B, T, D)
device = input.device device = input.device
@ -201,12 +197,13 @@ class Conv1dAbs(_ConvNd):
def forward(self, input: Tensor) -> Tensor: def forward(self, input: Tensor) -> Tensor:
if self.padding_mode != "zeros": if self.padding_mode != "zeros":
return F.conv1d( return F.conv1d(
# F.pad( F.pad(
# input, input,
# self._reversed_padding_repeated_twice, (0, 0, self.padding, self.padding),
# mode=self.padding_mode, self.padding_mode,
# ), 1,
padding(input, self.padding), ).permute(0, 2, 1),
#padding(input, self.padding),
torch.exp(self.weight), torch.exp(self.weight),
torch.exp(self.bias), torch.exp(self.bias),
self.stride, self.stride,