mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Try to make SmallConvolutionModule more efficient
This commit is contained in:
parent
167b58baa0
commit
ec8804283c
@ -1008,18 +1008,13 @@ class SmallConvolutionModule(nn.Module):
|
|||||||
kernel_size=kernel_size,
|
kernel_size=kernel_size,
|
||||||
padding=kernel_size // 2)
|
padding=kernel_size // 2)
|
||||||
|
|
||||||
self.pointwise_conv1 = nn.Conv1d(
|
self.linear1 = nn.Linear(
|
||||||
channels,
|
channels, hidden_dim)
|
||||||
hidden_dim,
|
|
||||||
kernel_size=1,
|
|
||||||
stride=1,
|
|
||||||
padding=0,
|
|
||||||
bias=True,
|
|
||||||
)
|
|
||||||
# balancer and activation as tuned for ConvolutionModule.
|
# balancer and activation as tuned for ConvolutionModule.
|
||||||
|
|
||||||
self.balancer = Balancer(
|
self.balancer = Balancer(
|
||||||
hidden_dim, channel_dim=1,
|
hidden_dim, channel_dim=-1,
|
||||||
min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
|
min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
|
||||||
max_positive=1.0,
|
max_positive=1.0,
|
||||||
min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)),
|
min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)),
|
||||||
@ -1028,15 +1023,8 @@ class SmallConvolutionModule(nn.Module):
|
|||||||
|
|
||||||
self.activation = SwooshR()
|
self.activation = SwooshR()
|
||||||
|
|
||||||
self.pointwise_conv2 = ScaledConv1d(
|
self.linear2 = ScaledLinear(hidden_dim, channels,
|
||||||
hidden_dim,
|
initial_scale=0.05)
|
||||||
channels,
|
|
||||||
kernel_size=1,
|
|
||||||
stride=1,
|
|
||||||
padding=0,
|
|
||||||
bias=True,
|
|
||||||
initial_scale=0.05,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self,
|
def forward(self,
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
@ -1057,16 +1045,13 @@ class SmallConvolutionModule(nn.Module):
|
|||||||
|
|
||||||
if src_key_padding_mask is not None:
|
if src_key_padding_mask is not None:
|
||||||
x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
|
x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
|
||||||
|
|
||||||
x = self.depthwise_conv(x) # (batch, channels, time)
|
x = self.depthwise_conv(x) # (batch, channels, time)
|
||||||
x = self.pointwise_conv1(x) # (batch, hidden_dim, time)
|
x = x.permute(2, 0, 1) # (time, batch, channels)
|
||||||
|
x = self.linear1(x) # (time, batch, hidden_dim)
|
||||||
x = self.balancer(x)
|
x = self.balancer(x)
|
||||||
x = self.activation(x)
|
x = self.activation(x)
|
||||||
x = self.pointwise_conv2(x)
|
x = self.linear2(x) # (time, batch, channels)
|
||||||
|
return x
|
||||||
return x.permute(2, 0, 1)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CompactRelPositionalEncoding(torch.nn.Module):
|
class CompactRelPositionalEncoding(torch.nn.Module):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user