Mirror of https://github.com/k2-fsa/icefall.git

Commit 281b54e7bf (parent d9c7e4f216)
Use LinearWithAuxLoss in more places.
@@ -1506,23 +1506,20 @@ class ConvolutionModule(nn.Module):
     """

     def __init__(
-        self, channels: int, kernel_size: int, bias: bool = True
+        self, channels: int, kernel_size: int,
     ) -> None:
         """Construct an ConvolutionModule object."""
         super(ConvolutionModule, self).__init__()
         # kernerl_size should be a odd number for 'SAME' padding
         assert (kernel_size - 1) % 2 == 0

-        self.pointwise_conv1 = nn.Conv1d(
-            channels,
-            2 * channels,
-            kernel_size=1,
-            stride=1,
-            padding=0,
-            bias=bias,
+        self.in_proj = LinearWithAuxLoss(
+            channels, 2 * channels,
+            aux_grad_scale=ScheduledFloat((0.0, 0.2), (1000.0, 0.01))
         )

-        # after pointwise_conv1 we put x through a gated linear unit (nn.functional.glu).
+        # after in_proj we put x through a gated linear unit (nn.functional.glu).
         # For most layers the normal rms value of channels of x seems to be in the range 1 to 4,
         # but sometimes, for some reason, for layer 0 the rms ends up being very large,
         # between 50 and 100 for different channels. This will cause very peaky and
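Note: LinearWithAuxLoss and ScheduledFloat come from icefall's scaling utilities and their definitions are not part of this diff. The toy stand-ins below are only a sketch of the behaviour assumed here: a float scheduled against the training batch count (so aux_grad_scale would decay from 0.2 at batch 0 to 0.01 by batch 1000), and a linear layer that exposes a small auxiliary loss on its output. The real classes differ, in particular in how the auxiliary gradient is actually applied.

    import torch
    import torch.nn as nn


    class ToyScheduledFloat:
        """Hypothetical stand-in: a value interpolated linearly between
        (batch_count, value) breakpoints, e.g. (0.0, 0.2), (1000.0, 0.01)."""

        def __init__(self, *points):
            self.points = sorted(points)
            self.batch_count = 0.0  # the training loop would update this

        def value(self) -> float:
            t, pts = self.batch_count, self.points
            if t <= pts[0][0]:
                return pts[0][1]
            if t >= pts[-1][0]:
                return pts[-1][1]
            for (x0, y0), (x1, y1) in zip(pts, pts[1:]):
                if x0 <= t <= x1:
                    # linear interpolation between the two breakpoints
                    return y0 + (y1 - y0) * (t - x0) / (x1 - x0)


    class ToyLinearWithAuxLoss(nn.Module):
        """Hypothetical stand-in: nn.Linear plus a small auxiliary penalty on
        its output, weighted by a scheduled aux_grad_scale; initial_scale
        shrinks the initial weights (cf. initial_scale=0.05 for out_proj)."""

        def __init__(self, in_features, out_features,
                     aux_grad_scale: ToyScheduledFloat, initial_scale: float = 1.0):
            super().__init__()
            self.linear = nn.Linear(in_features, out_features)
            with torch.no_grad():
                self.linear.weight.mul_(initial_scale)
                self.linear.bias.mul_(initial_scale)
            self.aux_grad_scale = aux_grad_scale
            self.aux_loss = torch.zeros(())  # training loop may add this to the loss

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            y = self.linear(x)
            if self.training:
                # toy auxiliary term: penalize large activations, with a scheduled weight
                self.aux_loss = self.aux_grad_scale.value() * (y ** 2).mean()
            return y

With these stand-ins, ToyLinearWithAuxLoss(channels, 2 * channels, ToyScheduledFloat((0.0, 0.2), (1000.0, 0.01))) mirrors the in_proj call in the hunk above.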
@@ -1536,8 +1533,8 @@ class ConvolutionModule(nn.Module):
         # it will be in a better position to start learning something, i.e. to latch onto
         # the correct range.
         self.deriv_balancer1 = ActivationBalancer(
-            2 * channels,
-            channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
+            2 * channels, channel_dim=-1,
+            max_abs=10.0, min_positive=0.05, max_positive=1.0
         )

         self.depthwise_conv = nn.Conv1d(
@@ -1547,7 +1544,7 @@ class ConvolutionModule(nn.Module):
             stride=1,
             padding=(kernel_size - 1) // 2,
             groups=channels,
-            bias=bias,
+            bias=True,
         )

         self.deriv_balancer2 = ActivationBalancer(
@@ -1563,13 +1560,9 @@ class ConvolutionModule(nn.Module):
             prob=(0.025, 0.25),
             grad_scale=0.01)

-        self.pointwise_conv2 = ScaledConv1d(
-            channels,
-            channels,
-            kernel_size=1,
-            stride=1,
-            padding=0,
-            bias=bias,
+        self.out_proj = LinearWithAuxLoss(
+            channels, channels,
+            aux_grad_scale=ScheduledFloat((0.0, 0.2), (1000.0, 0.01)),
+            initial_scale=0.05,
         )
@@ -1589,15 +1582,14 @@ class ConvolutionModule(nn.Module):
             Tensor: Output tensor (#time, batch, channels).

         """
+        x = self.in_proj(x)  # (time, batch, 2*channels)
+        x = self.deriv_balancer1(x)
+        x = nn.functional.glu(x, dim=-1)  # (time, batch, channels)
+
         # exchange the temporal dimension and the feature dimension
         x = x.permute(1, 2, 0)  # (#batch, channels, time).

-        # GLU mechanism
-        x = self.pointwise_conv1(x)  # (batch, 2*channels, time)
-
-        x = self.deriv_balancer1(x)
-        x = nn.functional.glu(x, dim=1)  # (batch, channels, time)
-
         if src_key_padding_mask is not None:
             x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
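With the GLU now applied while the tensor is still laid out as (time, batch, channels), dim=-1 plays the role that dim=1 played in the old (batch, 2*channels, time) layout. A quick shape check with plain PyTorch (no icefall classes involved) showing that nn.functional.glu halves whichever dimension it is given and that the two orderings agree up to the permute:

    import torch
    import torch.nn.functional as F

    time, batch, channels = 50, 8, 256

    # New layout: GLU over the last (channel) dim of a (time, batch, 2*channels) tensor.
    x = torch.randn(time, batch, 2 * channels)
    y_new = F.glu(x, dim=-1)
    assert y_new.shape == (time, batch, channels)

    # Old layout: the same halving, but over dim=1 of a (batch, 2*channels, time) tensor.
    x_old = x.permute(1, 2, 0)          # (batch, 2*channels, time)
    y_old = F.glu(x_old, dim=1)
    assert y_old.shape == (batch, channels, time)

    # Equivalent up to the permute:
    assert torch.allclose(y_new.permute(1, 2, 0), y_old)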
@@ -1605,15 +1597,12 @@ class ConvolutionModule(nn.Module):
         x = self.depthwise_conv(x)

         x = self.deriv_balancer2(x)
+        x = x.permute(2, 0, 1)  # (time, batch, channels)
+
         x = self.activation(x)
+        x = self.whiten(x)  # (time, batch, channels)
+        x = self.out_proj(x)  # (time, batch, channels)

-        x = x.transpose(1, 2)
-        x = self.whiten(x)  # (batch, time, channel)
-        x = x.transpose(1, 2)
-
-        x = self.pointwise_conv2(x)  # (batch, channel, time)
-
-        x = x.permute(2, 0, 1)  # (time, batch, channel)
         return x
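After this hunk, only the depthwise convolution still needs the (batch, channels, time) layout; the input projection, GLU and output projection all run in (time, batch, channels). A rough shape-flow sketch with ordinary nn.Linear / nn.Conv1d layers standing in for the icefall-specific modules (the balancers and whitening are omitted, and SiLU stands in for the module's activation):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    time, batch, channels, kernel_size = 50, 8, 256, 31

    in_proj = nn.Linear(channels, 2 * channels)   # stand-in for LinearWithAuxLoss
    depthwise = nn.Conv1d(channels, channels, kernel_size,
                          padding=(kernel_size - 1) // 2, groups=channels)
    out_proj = nn.Linear(channels, channels)      # stand-in for LinearWithAuxLoss

    x = torch.randn(time, batch, channels)
    x = in_proj(x)                                # (time, batch, 2*channels)
    x = F.glu(x, dim=-1)                          # (time, batch, channels)

    x = x.permute(1, 2, 0)                        # (batch, channels, time) for the conv
    x = depthwise(x)                              # (batch, channels, time)
    x = x.permute(2, 0, 1)                        # back to (time, batch, channels)

    x = F.silu(x)                                 # stand-in for self.activation
    x = out_proj(x)                               # (time, batch, channels)
    assert x.shape == (time, batch, channels)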
@@ -1732,7 +1721,9 @@ class Conv2dSubsampling(nn.Module):
         self.squeeze_excite = SqueezeExcite1d(out_height * layer3_channels,
                                               bottleneck_channels)

-        self.out = ScaledLinear(out_height * layer3_channels, out_channels)
+        self.out = LinearWithAuxLoss(out_height * layer3_channels, out_channels,
+                                     aux_grad_scale=ScheduledFloat((0.0, 0.2), (1000.0, 0.01)))

         self.dropout = nn.Dropout(dropout)