mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Cosmetic improvements to convolution module; enable more stats.
This commit is contained in:
parent
6845da4351
commit
f4f3d057e7
@ -1579,8 +1579,11 @@ class ConvolutionModule(nn.Module):
|
||||
# kernel_size should be an odd number for 'SAME' padding
|
||||
assert (kernel_size - 1) % 2 == 0
|
||||
|
||||
bottleneck_dim = channels
|
||||
|
||||
|
||||
self.in_proj = LinearWithAuxLoss(
|
||||
channels, 2 * channels,
|
||||
channels, 2 * bottleneck_dim,
|
||||
aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_in()
|
||||
)
|
||||
|
||||
@ -1599,35 +1602,38 @@ class ConvolutionModule(nn.Module):
|
||||
# it will be in a better position to start learning something, i.e. to latch onto
|
||||
# the correct range.
|
||||
self.balancer1 = ActivationBalancer(
|
||||
2 * channels, channel_dim=-1,
|
||||
bottleneck_dim, channel_dim=-1,
|
||||
min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
|
||||
max_positive=1.0,
|
||||
min_abs=1.5,
|
||||
max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0),
|
||||
)
|
||||
|
||||
self.pre_sigmoid = Identity() # before sigmoid; for diagnostics.
|
||||
self.activation1 = Identity() # for diagnostics
|
||||
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
self.activation2 = Identity() # for diagnostics
|
||||
|
||||
self.depthwise_conv = nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
bottleneck_dim,
|
||||
bottleneck_dim,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
groups=channels,
|
||||
groups=bottleneck_dim,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
self.balancer2 = ActivationBalancer(
|
||||
channels, channel_dim=1,
|
||||
bottleneck_dim, channel_dim=1,
|
||||
min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
|
||||
max_positive=1.0,
|
||||
min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 1.0)),
|
||||
max_abs=10.0,
|
||||
)
|
||||
|
||||
self.activation = SwooshR()
|
||||
self.activation3 = SwooshR()
|
||||
|
||||
self.whiten = Whiten(num_groups=1,
|
||||
whitening_limit=_whitening_schedule(7.5),
|
||||
@ -1635,7 +1641,7 @@ class ConvolutionModule(nn.Module):
|
||||
grad_scale=0.01)
|
||||
|
||||
self.out_proj = LinearWithAuxLoss(
|
||||
channels, channels,
|
||||
bottleneck_dim, channels,
|
||||
aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out(),
|
||||
initial_scale=0.05,
|
||||
)
|
||||
@ -1658,12 +1664,13 @@ class ConvolutionModule(nn.Module):
|
||||
"""
|
||||
|
||||
x = self.in_proj(x) # (time, batch, 2*channels)
|
||||
x = self.balancer1(x)
|
||||
|
||||
x, s = x.chunk(2, dim=-1)
|
||||
s = self.pre_sigmoid(s)
|
||||
s = self.balancer1(s)
|
||||
s = self.sigmoid(s)
|
||||
x = self.activation1(x) # identity.
|
||||
x = x * s
|
||||
x = self.activation2(x) # identity
|
||||
|
||||
# (time, batch, channels)
|
||||
|
||||
@ -1679,7 +1686,7 @@ class ConvolutionModule(nn.Module):
|
||||
x = self.balancer2(x)
|
||||
x = x.permute(2, 0, 1) # (time, batch, channels)
|
||||
|
||||
x = self.activation(x)
|
||||
x = self.activation3(x)
|
||||
x = self.whiten(x) # (time, batch, channels)
|
||||
x = self.out_proj(x) # (time, batch, channels)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user