Cosmetic improvements to convolution module; enable more stats.

Author: Daniel Povey
Date:   2022-12-08 18:27:01 +08:00
parent 6845da4351
commit f4f3d057e7


@@ -1579,8 +1579,11 @@ class ConvolutionModule(nn.Module):
         # kernel_size should be an odd number for 'SAME' padding
         assert (kernel_size - 1) % 2 == 0
+        bottleneck_dim = channels
         self.in_proj = LinearWithAuxLoss(
-            channels, 2 * channels,
+            channels, 2 * bottleneck_dim,
             aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_in()
         )
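
Note on this hunk: bottleneck_dim is set equal to channels, so no shapes change yet; the new name just gives the module's inner width its own identity so it can later differ from the model dimension. A quick shape check (a sketch only; nn.Linear stands in for LinearWithAuxLoss, which is assumed to be shape-compatible with it):

    import torch
    import torch.nn as nn

    channels = 256
    bottleneck_dim = channels  # as in this commit: a pure rename for now
    in_proj = nn.Linear(channels, 2 * bottleneck_dim)  # stand-in for LinearWithAuxLoss
    x = torch.randn(100, 8, channels)  # (time, batch, channels)
    print(in_proj(x).shape)  # torch.Size([100, 8, 512]), same as before the change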
@@ -1599,35 +1602,38 @@ class ConvolutionModule(nn.Module):
         # it will be in a better position to start learning something, i.e. to latch onto
         # the correct range.
         self.balancer1 = ActivationBalancer(
-            2 * channels, channel_dim=-1,
+            bottleneck_dim, channel_dim=-1,
             min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
             max_positive=1.0,
             min_abs=1.5,
             max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0),
         )
-        self.pre_sigmoid = Identity()  # before sigmoid; for diagnostics.
+        self.activation1 = Identity()  # for diagnostics
         self.sigmoid = nn.Sigmoid()
+        self.activation2 = Identity()  # for diagnostics
         self.depthwise_conv = nn.Conv1d(
-            channels,
-            channels,
+            bottleneck_dim,
+            bottleneck_dim,
             kernel_size,
             stride=1,
             padding=(kernel_size - 1) // 2,
-            groups=channels,
+            groups=bottleneck_dim,
             bias=True,
         )
         self.balancer2 = ActivationBalancer(
-            channels, channel_dim=1,
+            bottleneck_dim, channel_dim=1,
             min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
             max_positive=1.0,
             min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 1.0)),
             max_abs=10.0,
         )
-        self.activation = SwooshR()
+        self.activation3 = SwooshR()
         self.whiten = Whiten(num_groups=1,
                              whitening_limit=_whitening_schedule(7.5),
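
Renaming pre_sigmoid to activation1 and adding activation2 is what "enable more stats" refers to: a no-op Identity() gives diagnostics code a named module to attach to, so tensor statistics can be collected at exactly that point in the graph. A minimal sketch of the idea using a standard PyTorch forward hook (the printed stats are illustrative, not the project's exact diagnostics). Note also that groups=bottleneck_dim keeps the convolution depthwise: one filter per channel, no cross-channel mixing.

    import torch
    import torch.nn as nn

    activation2 = nn.Identity()  # no-op in the forward pass, but a named, hookable module

    def log_stats(module, inputs, output):
        # Print simple per-tensor statistics as values flow through.
        print(f"mean={output.mean():.4f}  rms={output.pow(2).mean().sqrt():.4f}")

    activation2.register_forward_hook(log_stats)
    _ = activation2(torch.randn(100, 8, 256))  # prints stats; output equals input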
@@ -1635,7 +1641,7 @@ class ConvolutionModule(nn.Module):
                              grad_scale=0.01)
         self.out_proj = LinearWithAuxLoss(
-            channels, channels,
+            bottleneck_dim, channels,
             aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out(),
             initial_scale=0.05,
         )
@@ -1658,12 +1664,13 @@ class ConvolutionModule(nn.Module):
         """
         x = self.in_proj(x)  # (time, batch, 2*channels)
-        x = self.balancer1(x)
         x, s = x.chunk(2, dim=-1)
-        s = self.pre_sigmoid(s)
+        s = self.balancer1(s)
         s = self.sigmoid(s)
+        x = self.activation1(x)  # identity.
         x = x * s
+        x = self.activation2(x)  # identity
         # (time, batch, channels)
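
In the reordered forward pass, balancer1 now acts only on the gate branch s just before the sigmoid, rather than on the full 2*channels tensor, and the two Identity calls add observation points around the multiply. The gating math itself is unchanged. Sketch with the identity/diagnostic modules elided:

    import torch

    channels = 256
    x = torch.randn(100, 8, 2 * channels)  # output of in_proj: (time, batch, 2*bottleneck_dim)
    x, s = x.chunk(2, dim=-1)              # split activations from gates
    x = x * torch.sigmoid(s)               # GLU-style gating -> (time, batch, bottleneck_dim)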
@@ -1679,7 +1686,7 @@ class ConvolutionModule(nn.Module):
         x = self.balancer2(x)
         x = x.permute(2, 0, 1)  # (time, batch, channels)
-        x = self.activation(x)
+        x = self.activation3(x)
         x = self.whiten(x)  # (time, batch, channels)
         x = self.out_proj(x)  # (time, batch, channels)
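
Taken together, the commit is behavior-preserving: bottleneck_dim equals channels everywhere, and the ActivationBalancer, Whiten and Identity modules, as used here, pass activations through unchanged (they shape gradients or serve diagnostics). An end-to-end shape sketch of the revised path with those modules elided; F.silu stands in for icefall's SwooshR only to keep the sketch self-contained:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    channels, kernel_size, T, B = 256, 31, 100, 8
    in_proj = nn.Linear(channels, 2 * channels)
    depthwise = nn.Conv1d(channels, channels, kernel_size,
                          padding=(kernel_size - 1) // 2, groups=channels)
    out_proj = nn.Linear(channels, channels)

    x = torch.randn(T, B, channels)  # (time, batch, channels)
    x = in_proj(x)                   # (time, batch, 2*channels)
    x, s = x.chunk(2, dim=-1)
    x = x * torch.sigmoid(s)         # gating
    x = x.permute(1, 2, 0)           # (batch, channels, time) for Conv1d
    x = depthwise(x)
    x = x.permute(2, 0, 1)           # back to (time, batch, channels)
    x = F.silu(x)                    # stand-in for SwooshR
    x = out_proj(x)
    print(x.shape)                   # torch.Size([100, 8, 256])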