Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 10:16:14 +00:00)

Commit 7df98eb000 (parent cbc6f01861): pad 1 with torch.nn.functional.pad
@@ -207,7 +207,7 @@ class ConformerEncoderLayer(nn.Module):
         # kernel size=21, self.conv1d_channels=768
         # kernel size=5, self.conv1d_channels=1024
         self.linear1 = nn.Linear(512, self.in_conv1d_channels)
-        self.conv1d_abs = Conv1dAbs(self.in_conv1d_channels, self.out_conv1d_channels, kernel_size=self.kernel_size, padding=self.padding, padding_mode="replicate")
+        self.conv1d_abs = Conv1dAbs(self.in_conv1d_channels, self.out_conv1d_channels, kernel_size=self.kernel_size, padding=self.padding, padding_mode="constant")
         self.linear2 = nn.Linear(self.out_conv1d_channels, 512)

     def forward(
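Note: the padding_mode switch above ("replicate" to "constant") is what makes padding with ones possible. torch.nn.functional.pad only honors its value argument in "constant" mode and rejects a nonzero fill for "replicate". A minimal standalone check, not part of the commit:

import torch
import torch.nn.functional as F

x = torch.randn(2, 10, 4)  # (B, T, D)

# "constant" mode honors the fill value.
ok = F.pad(x, (0, 0, 2, 2), mode="constant", value=1.0)
print(ok.shape)  # torch.Size([2, 14, 4])
print(ok[0, 0])  # tensor([1., 1., 1., 1.])

# "replicate" (the old padding_mode) rejects a nonzero fill value.
try:
    F.pad(x.permute(0, 2, 1), (2, 2), mode="replicate", value=1.0)
except (AssertionError, RuntimeError) as e:
    print("replicate cannot pad with 1:", e)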
@@ -333,7 +333,6 @@ class ConformerEncoder(nn.TransformerEncoder):
         src_key_padding_mask: Optional[Tensor] = None,
     ) -> Tensor:
         r"""Pass the input through the encoder layers in turn.

         Args:
             src: the sequence to the encoder (required).
             pos_emb: Positional embedding tensor (required).
@@ -153,10 +153,6 @@ class _ConvNd(Module):
         if not hasattr(self, "padding_mode"):
            self.padding_mode = "zeros"

-import torch
-import torch.nn as nn
-m = nn.Tanh()
-
 def padding(input, padding_length):
     # input shape : (B, T, D)
     device = input.device
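The removed import and nn.Tanh() lines above were apparently unused scratch code. The body of the padding helper is cut off at the hunk boundary; a hypothetical sketch (assumed, not the author's code) of a pad-with-ones helper consistent with the call site in Conv1dAbs.forward:

import torch

def padding(input, padding_length):
    # Hypothetical reconstruction; the real body is truncated in the diff.
    # input shape : (B, T, D). Pad the time axis with ones on both sides
    # and return (B, D, T + 2 * padding_length), ready for F.conv1d.
    device = input.device
    B, T, D = input.shape
    ones = torch.ones(B, padding_length, D, device=device, dtype=input.dtype)
    padded = torch.cat([ones, input, ones], dim=1)  # (B, T + 2p, D)
    return padded.permute(0, 2, 1)                  # (B, D, T + 2p)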
@@ -201,12 +197,13 @@ class Conv1dAbs(_ConvNd):
     def forward(self, input: Tensor) -> Tensor:
         if self.padding_mode != "zeros":
             return F.conv1d(
-                # F.pad(
-                #     input,
-                #     self._reversed_padding_repeated_twice,
-                #     mode=self.padding_mode,
-                # ),
-                padding(input, self.padding),
+                F.pad(
+                    input,
+                    (0, 0, self.padding, self.padding),
+                    self.padding_mode,
+                    1,
+                ).permute(0, 2, 1),
+                #padding(input, self.padding),
                 torch.exp(self.weight),
                 torch.exp(self.bias),
                 self.stride,
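What the new expression computes, sketched standalone with illustrative sizes (the real layer pairs e.g. kernel size 21 with 768 channels): pad=(0, 0, p, p) leaves the last (feature) dim untouched and pads the time dim by p on each side, the positional arguments "constant" and 1 fill those frames with ones, and permute(0, 2, 1) moves features to the channel axis so F.conv1d sees (B, C, T).

import torch
import torch.nn.functional as F

# Illustrative sizes only, not the commit's configuration.
B, T, D, K = 2, 10, 8, 5
p = (K - 1) // 2  # padding that preserves the sequence length

x = torch.randn(B, T, D)

# Last dim (D) unpadded; time dim gets p frames of ones on each side.
padded = F.pad(x, (0, 0, p, p), "constant", 1).permute(0, 2, 1)
print(padded.shape)     # torch.Size([2, 8, 14]), i.e. (B, D, T + 2p)
print(padded[0, :, 0])  # all ones: a padding frame

# Conv1dAbs convolves with exponentiated, hence strictly positive, weights.
weight = torch.randn(16, D, K)  # (out_channels, in_channels, kernel_size)
bias = torch.randn(16)
y = F.conv1d(padded, torch.exp(weight), torch.exp(bias), 1)
print(y.shape)          # torch.Size([2, 16, 10]), sequence length preserved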