mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-26 18:24:18 +00:00
pad 1 with torch.nn.functional.pad
This commit is contained in: parent cbc6f01861, commit 7df98eb000
@@ -207,7 +207,7 @@ class ConformerEncoderLayer(nn.Module):
         # kernel size=21, self.conv1d_channels=768
         # kernel size=5, self.conv1d_channels=1024
         self.linear1 = nn.Linear(512, self.in_conv1d_channels)
-        self.conv1d_abs = Conv1dAbs(self.in_conv1d_channels, self.out_conv1d_channels, kernel_size=self.kernel_size, padding=self.padding, padding_mode="replicate")
+        self.conv1d_abs = Conv1dAbs(self.in_conv1d_channels, self.out_conv1d_channels, kernel_size=self.kernel_size, padding=self.padding, padding_mode="constant")
         self.linear2 = nn.Linear(self.out_conv1d_channels, 512)

     def forward(
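For context on this change: "replicate" padding repeats the tensor's edge values and accepts no fill value, while "constant" fills with an explicit value, which is what lets the rewritten forward below pad with 1. A minimal illustration of the two modes (the input tensor is made up for the example):

import torch
import torch.nn.functional as F

x = torch.arange(1.0, 4.0).view(1, 1, 3)            # (N, C, W) = (1, 1, 3): [[[1., 2., 3.]]]
print(F.pad(x, (2, 2), mode="replicate"))           # [[[1., 1., 1., 2., 3., 3., 3.]]]
print(F.pad(x, (2, 2), mode="constant", value=1.0)) # [[[1., 1., 1., 2., 3., 1., 1.]]]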
@@ -333,7 +333,6 @@ class ConformerEncoder(nn.TransformerEncoder):
         src_key_padding_mask: Optional[Tensor] = None,
     ) -> Tensor:
         r"""Pass the input through the encoder layers in turn.

         Args:
             src: the sequence to the encoder (required).
             pos_emb: Positional embedding tensor (required).
@@ -153,10 +153,6 @@ class _ConvNd(Module):
         if not hasattr(self, "padding_mode"):
             self.padding_mode = "zeros"

-import torch
-import torch.nn as nn
-m = nn.Tanh()
-
 def padding(input, padding_length):
     # input shape : (B, T, D)
     device = input.device
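The padding() helper kept as context here is what the next hunk comments out in favor of F.pad; its body is truncated in this view. A plausible reconstruction of what it computed, assuming it padded the time axis with ones and permuted the result for conv1d (everything past `device = input.device` is a guess, not the original code):

import torch

def padding(input, padding_length):
    # input shape : (B, T, D)
    device = input.device
    B, _, D = input.shape
    # Hypothetical body: prepend and append `padding_length` frames of ones,
    # then move channels in front for F.conv1d. Under this reading it matches
    # F.pad(input, (0, 0, padding_length, padding_length), "constant", 1).permute(0, 2, 1).
    ones = torch.ones(B, padding_length, D, device=device)
    return torch.cat([ones, input, ones], dim=1).permute(0, 2, 1)  # (B, D, T + 2*p)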
@@ -201,12 +197,13 @@ class Conv1dAbs(_ConvNd):
     def forward(self, input: Tensor) -> Tensor:
         if self.padding_mode != "zeros":
             return F.conv1d(
-                # F.pad(
-                #     input,
-                #     self._reversed_padding_repeated_twice,
-                #     mode=self.padding_mode,
-                # ),
-                padding(input, self.padding),
+                F.pad(
+                    input,
+                    (0, 0, self.padding, self.padding),
+                    self.padding_mode,
+                    1,
+                ).permute(0, 2, 1),
+                # padding(input, self.padding),
                 torch.exp(self.weight),
                 torch.exp(self.bias),
                 self.stride,
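Putting the pieces together, a self-contained sketch of what the rewritten non-"zeros" branch computes; the sizes and random weights are illustrative only (kernel_size=5 and padding=2 mirror the "kernel size=5" comment above):

import torch
import torch.nn.functional as F

B, T, D = 2, 10, 8
pad, kernel_size = 2, 5

x = torch.randn(B, T, D)

# F.pad consumes the pad tuple from the last dimension backwards:
# (0, 0) leaves D untouched, (pad, pad) extends T on both sides;
# mode "constant" with value 1 pads with ones -- "pad 1", per the commit title.
padded = F.pad(x, (0, 0, pad, pad), "constant", 1)  # (B, T + 2*pad, D)

out = F.conv1d(
    padded.permute(0, 2, 1),                    # conv1d wants (B, channels, length)
    torch.exp(torch.randn(D, D, kernel_size)),  # exp keeps the weights positive
    torch.exp(torch.randn(D)),                  # exp keeps the bias positive
)
print(out.shape)  # torch.Size([2, 8, 10]): length T is recovered since pad = (kernel_size - 1) // 2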