mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Integrate LinearWithAuxLoss into SqueezeExcite1d
This commit is contained in:
commit
8f1ef60951
@ -1613,6 +1613,49 @@ class ConvolutionModule(nn.Module):
|
|||||||
x = x.permute(2, 0, 1) # (time, batch, channel)
|
x = x.permute(2, 0, 1) # (time, batch, channel)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SqueezeExcite1d(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
channels: int,
|
||||||
|
bottleneck_channels: int):
|
||||||
|
super().__init__()
|
||||||
|
self.to_bottleneck_proj = LinearWithAuxLoss(channels,
|
||||||
|
bottleneck_channels)
|
||||||
|
|
||||||
|
self.bottleneck_activation = TanSwish()
|
||||||
|
self.from_bottleneck_proj = nn.Linear(bottleneck_channels,
|
||||||
|
channels)
|
||||||
|
|
||||||
|
self.balancer = ActivationBalancer(
|
||||||
|
channels, channel_dim=-1,
|
||||||
|
min_abs=0.05,
|
||||||
|
max_abs=ScheduledFloat((0.0, 0.2),
|
||||||
|
(4000.0, 2.0),
|
||||||
|
(10000.0, 10.0),
|
||||||
|
default=1.0),
|
||||||
|
max_factor=0.02,
|
||||||
|
min_prob=0.1,
|
||||||
|
)
|
||||||
|
self.activation = nn.Sigmoid()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, x: Tensor):
|
||||||
|
"""
|
||||||
|
x: a Tensor of shape (batch_size, T, channels).
|
||||||
|
Returns: something with the same shape as x.
|
||||||
|
"""
|
||||||
|
# project before mean, needed for LinearWithAuxLoss (or, at least, better)
|
||||||
|
bottleneck = self.to_bottleneck_proj(x)
|
||||||
|
# would replace this mean with cumsum for a causal model.
|
||||||
|
bottleneck = bottleneck.mean(dim=1, keepdim=True)
|
||||||
|
bottleneck = self.bottleneck_activation(bottleneck)
|
||||||
|
scale = self.from_bottleneck_proj(bottleneck)
|
||||||
|
scale = self.balancer(scale)
|
||||||
|
scale = self.activation(scale)
|
||||||
|
return x * scale
|
||||||
|
|
||||||
|
|
||||||
class Conv2dSubsampling(nn.Module):
|
class Conv2dSubsampling(nn.Module):
|
||||||
"""Convolutional 2D subsampling (to 1/2 length).
|
"""Convolutional 2D subsampling (to 1/2 length).
|
||||||
|
|
||||||
@ -1631,6 +1674,7 @@ class Conv2dSubsampling(nn.Module):
|
|||||||
layer1_channels: int = 8,
|
layer1_channels: int = 8,
|
||||||
layer2_channels: int = 32,
|
layer2_channels: int = 32,
|
||||||
layer3_channels: int = 128,
|
layer3_channels: int = 128,
|
||||||
|
bottleneck_channels: int = 64,
|
||||||
dropout: float = 0.1,
|
dropout: float = 0.1,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
@ -1644,6 +1688,8 @@ class Conv2dSubsampling(nn.Module):
|
|||||||
Number of channels in layer1
|
Number of channels in layer1
|
||||||
layer1_channels:
|
layer1_channels:
|
||||||
Number of channels in layer2
|
Number of channels in layer2
|
||||||
|
bottleneck:
|
||||||
|
bottleneck dimension for 1d squeeze-excite
|
||||||
"""
|
"""
|
||||||
assert in_channels >= 7
|
assert in_channels >= 7
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -1679,6 +1725,10 @@ class Conv2dSubsampling(nn.Module):
|
|||||||
DoubleSwish(),
|
DoubleSwish(),
|
||||||
)
|
)
|
||||||
out_height = (((in_channels - 1) // 2) - 1) // 2
|
out_height = (((in_channels - 1) // 2) - 1) // 2
|
||||||
|
|
||||||
|
self.squeeze_excite = SqueezeExcite1d(out_height * layer3_channels,
|
||||||
|
bottleneck_channels)
|
||||||
|
|
||||||
self.out = ScaledLinear(out_height * layer3_channels, out_channels)
|
self.out = ScaledLinear(out_height * layer3_channels, out_channels)
|
||||||
self.dropout = nn.Dropout(dropout)
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
@ -1698,7 +1748,11 @@ class Conv2dSubsampling(nn.Module):
|
|||||||
x = self.conv(x)
|
x = self.conv(x)
|
||||||
# Now x is of shape (N, odim, ((T-3)//2 - 1)//2, ((idim-1)//2 - 1)//2)
|
# Now x is of shape (N, odim, ((T-3)//2 - 1)//2, ((idim-1)//2 - 1)//2)
|
||||||
b, c, t, f = x.size()
|
b, c, t, f = x.size()
|
||||||
x = self.out(x.transpose(1, 2).reshape(b, t, c * f))
|
|
||||||
|
x = x.transpose(1, 2).reshape(b, t, c * f)
|
||||||
|
# now x: (N, ((T-1)//2 - 1))//2, out_height * layer3_channels))
|
||||||
|
x = self.squeeze_excite(x)
|
||||||
|
x = self.out(x)
|
||||||
# Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
|
# Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
|
||||||
x = self.dropout(x)
|
x = self.dropout(x)
|
||||||
return x
|
return x
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user