Use LinearWithAuxLoss in more places.

Daniel Povey 2022-11-26 12:25:22 +08:00
parent d9c7e4f216
commit 281b54e7bf


@@ -1506,23 +1506,20 @@ class ConvolutionModule(nn.Module):
"""
def __init__(
- self, channels: int, kernel_size: int, bias: bool = True
+ self, channels: int, kernel_size: int,
) -> None:
"""Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__()
# kernel_size should be an odd number for 'SAME' padding
assert (kernel_size - 1) % 2 == 0
- self.pointwise_conv1 = nn.Conv1d(
-     channels,
-     2 * channels,
-     kernel_size=1,
-     stride=1,
-     padding=0,
-     bias=bias,
+ self.in_proj = LinearWithAuxLoss(
+     channels, 2 * channels,
+     aux_grad_scale=ScheduledFloat((0.0, 0.2), (1000.0, 0.01))
)
- # after pointwise_conv1 we put x through a gated linear unit (nn.functional.glu).
+ # after in_proj we put x through a gated linear unit (nn.functional.glu).
# For most layers the normal rms value of channels of x seems to be in the range 1 to 4,
# but sometimes, for some reason, for layer 0 the rms ends up being very large,
# between 50 and 100 for different channels. This will cause very peaky and
@@ -1536,8 +1533,8 @@ class ConvolutionModule(nn.Module):
# it will be in a better position to start learning something, i.e. to latch onto
# the correct range.
self.deriv_balancer1 = ActivationBalancer(
- 2 * channels,
- channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
+ 2 * channels, channel_dim=-1,
+ max_abs=10.0, min_positive=0.05, max_positive=1.0
)
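The motivation in the comment above can be checked numerically. The snippet below is a self-contained illustration in plain PyTorch (it is not code from this diff and uses only stock torch calls): when the gating half of the pre-GLU output has RMS around 1, the sigmoid gate still passes useful gradient, but once the RMS reaches the 50-100 range described in the comment, the gate saturates and almost no gradient flows, which is what the balancer's max_abs=10.0 guards against.

    import torch

    # Mean |d sigmoid(g)/dg| for gate pre-activations of increasing RMS.
    # Large-RMS gates are saturated, so the GLU passes almost no gradient.
    for rms in (1.0, 10.0, 100.0):
        gate = (torch.randn(4000) * rms).requires_grad_(True)
        torch.sigmoid(gate).sum().backward()
        print(f"rms = {rms:6.1f}   mean |grad| = {gate.grad.abs().mean():.4f}")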
self.depthwise_conv = nn.Conv1d(
@@ -1547,7 +1544,7 @@ class ConvolutionModule(nn.Module):
stride=1,
padding=(kernel_size - 1) // 2,
groups=channels,
- bias=bias,
+ bias=True,
)
self.deriv_balancer2 = ActivationBalancer(
@@ -1563,13 +1560,9 @@ class ConvolutionModule(nn.Module):
prob=(0.025, 0.25),
grad_scale=0.01)
- self.pointwise_conv2 = ScaledConv1d(
-     channels,
-     channels,
-     kernel_size=1,
-     stride=1,
-     padding=0,
-     bias=bias,
+ self.out_proj = LinearWithAuxLoss(
+     channels, channels,
+     aux_grad_scale=ScheduledFloat((0.0, 0.2), (1000.0, 0.01)),
initial_scale=0.05,
)
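LinearWithAuxLoss and ScheduledFloat are defined elsewhere in this repository; the diff only shows their call sites. The sketch below is a hypothetical stand-in inferred from those call sites, for readers who want a concrete mental model: a schedule given as (batch_count, value) pairs whose value presumably decays over training, and a linear layer whose activations contribute a small auxiliary loss scaled by that schedule. The name ToyLinearWithAux and the L2 auxiliary term are assumptions, not the actual implementation.

    import torch
    import torch.nn as nn

    class ToyLinearWithAux(nn.Module):
        """Hypothetical stand-in for LinearWithAuxLoss (NOT the actual class,
        whose definition is not part of this diff).  `aux_grad_scale` is a
        callable batch_count -> float, playing the role of
        ScheduledFloat((0.0, 0.2), (1000.0, 0.01))."""

        def __init__(self, in_features, out_features, aux_grad_scale,
                     initial_scale=1.0):
            super().__init__()
            self.linear = nn.Linear(in_features, out_features)
            with torch.no_grad():
                # Mirrors the initial_scale=0.05 argument used for out_proj above.
                self.linear.weight *= initial_scale
                self.linear.bias *= initial_scale
            self.aux_grad_scale = aux_grad_scale
            self.batch_count = 0   # assumed to be advanced by the training loop
            self.aux_loss = None

        def forward(self, x):
            y = self.linear(x)
            if self.training:
                # Toy auxiliary term (an assumption): gently penalize large
                # activations, scaled by the scheduled value; the training
                # loop would add self.aux_loss to the main loss.
                self.aux_loss = self.aux_grad_scale(self.batch_count) * y.pow(2).mean()
            return y

How the real LinearWithAuxLoss couples its auxiliary term into the backward pass, and what exactly it penalizes, is not visible in this diff; the sketch only shows one plausible wiring.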
@@ -1589,15 +1582,14 @@ class ConvolutionModule(nn.Module):
Tensor: Output tensor (#time, batch, channels).
"""
+ x = self.in_proj(x) # (time, batch, 2*channels)
+ x = self.deriv_balancer1(x)
+ x = nn.functional.glu(x, dim=-1) # (time, batch, channels)
# exchange the temporal dimension and the feature dimension
x = x.permute(1, 2, 0) # (#batch, channels, time).
- # GLU mechanism
- x = self.pointwise_conv1(x) # (batch, 2*channels, time)
- x = self.deriv_balancer1(x)
- x = nn.functional.glu(x, dim=1) # (batch, channels, time)
if src_key_padding_mask is not None:
x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
@@ -1605,15 +1597,12 @@ class ConvolutionModule(nn.Module):
x = self.depthwise_conv(x)
x = self.deriv_balancer2(x)
+ x = x.permute(2, 0, 1) # (time, batch, channels)
x = self.activation(x)
+ x = self.whiten(x) # (time, batch, channels)
+ x = self.out_proj(x) # (time, batch, channels)
- x = x.transpose(1, 2)
- x = self.whiten(x) # (batch, time, channel)
- x = x.transpose(1, 2)
- x = self.pointwise_conv2(x) # (batch, channel, time)
- x = x.permute(2, 0, 1) # (time, batch, channel)
return x
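Taken together, the two hunks above rearrange the forward pass so that everything except the depthwise convolution runs in (time, batch, channels) layout, with one permute into and one permute out of the (batch, channels, time) layout that Conv1d expects. The following self-contained shape walkthrough uses plain-PyTorch stand-ins (nn.Linear for the projections, nn.SiLU for the activation; the balancers, Whiten, and the padding mask are omitted) and reproduces only the tensor shapes, not the custom modules:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    seq_len, batch, channels, kernel_size = 100, 8, 256, 31

    in_proj = nn.Linear(channels, 2 * channels)      # stand-in for self.in_proj
    depthwise_conv = nn.Conv1d(channels, channels, kernel_size,
                               padding=(kernel_size - 1) // 2, groups=channels)
    activation = nn.SiLU()                           # stand-in for self.activation
    out_proj = nn.Linear(channels, channels)         # stand-in for self.out_proj

    x = torch.randn(seq_len, batch, channels)        # (time, batch, channels)
    x = in_proj(x)                                   # (time, batch, 2*channels)
    x = F.glu(x, dim=-1)                             # (time, batch, channels)
    x = x.permute(1, 2, 0)                           # (batch, channels, time) for Conv1d
    x = depthwise_conv(x)                            # (batch, channels, time)
    x = x.permute(2, 0, 1)                           # back to (time, batch, channels)
    x = activation(x)
    x = out_proj(x)                                  # (time, batch, channels)
    print(x.shape)                                   # torch.Size([100, 8, 256])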
@@ -1732,7 +1721,9 @@ class Conv2dSubsampling(nn.Module):
self.squeeze_excite = SqueezeExcite1d(out_height * layer3_channels,
bottleneck_channels)
- self.out = ScaledLinear(out_height * layer3_channels, out_channels)
+ self.out = LinearWithAuxLoss(out_height * layer3_channels, out_channels,
+     aux_grad_scale=ScheduledFloat((0.0, 0.2), (1000.0, 0.01)))
self.dropout = nn.Dropout(dropout)
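All three new call sites in this commit pass the same schedule, aux_grad_scale=ScheduledFloat((0.0, 0.2), (1000.0, 0.01)). Assuming the schedule interpolates linearly in batch count between the given (batch_count, value) points and is clamped outside them (an assumption about ScheduledFloat, not something this diff shows), the auxiliary gradient scale starts at 0.2 and has decayed to 0.01 by batch 1000, so the auxiliary loss mainly shapes the very start of training:

    def sched(batch_count, points=((0.0, 0.2), (1000.0, 0.01))):
        # Assumed model of ScheduledFloat: linear interpolation between the
        # two (batch_count, value) points, clamped at both ends.
        (x0, y0), (x1, y1) = points
        t = min(max((batch_count - x0) / (x1 - x0), 0.0), 1.0)
        return y0 + t * (y1 - y0)

    for b in (0, 250, 500, 1000, 20000):
        print(b, round(sched(b), 4))   # 0.2, 0.1525, 0.105, 0.01, 0.01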