mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Cosmetic improvements to convolution module; enable more stats.
This commit is contained in:
parent
6845da4351
commit
f4f3d057e7
@ -1579,8 +1579,11 @@ class ConvolutionModule(nn.Module):
|
|||||||
# kernel_size should be an odd number for 'SAME' padding
|
# kernel_size should be an odd number for 'SAME' padding
|
||||||
assert (kernel_size - 1) % 2 == 0
|
assert (kernel_size - 1) % 2 == 0
|
||||||
|
|
||||||
|
bottleneck_dim = channels
|
||||||
|
|
||||||
|
|
||||||
self.in_proj = LinearWithAuxLoss(
|
self.in_proj = LinearWithAuxLoss(
|
||||||
channels, 2 * channels,
|
channels, 2 * bottleneck_dim,
|
||||||
aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_in()
|
aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_in()
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1599,35 +1602,38 @@ class ConvolutionModule(nn.Module):
|
|||||||
# it will be in a better position to start learning something, i.e. to latch onto
|
# it will be in a better position to start learning something, i.e. to latch onto
|
||||||
# the correct range.
|
# the correct range.
|
||||||
self.balancer1 = ActivationBalancer(
|
self.balancer1 = ActivationBalancer(
|
||||||
2 * channels, channel_dim=-1,
|
bottleneck_dim, channel_dim=-1,
|
||||||
min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
|
min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
|
||||||
max_positive=1.0,
|
max_positive=1.0,
|
||||||
min_abs=1.5,
|
min_abs=1.5,
|
||||||
max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0),
|
max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.pre_sigmoid = Identity() # before sigmoid; for diagnostics.
|
self.activation1 = Identity() # for diagnostics
|
||||||
|
|
||||||
self.sigmoid = nn.Sigmoid()
|
self.sigmoid = nn.Sigmoid()
|
||||||
|
|
||||||
|
self.activation2 = Identity() # for diagnostics
|
||||||
|
|
||||||
self.depthwise_conv = nn.Conv1d(
|
self.depthwise_conv = nn.Conv1d(
|
||||||
channels,
|
bottleneck_dim,
|
||||||
channels,
|
bottleneck_dim,
|
||||||
kernel_size,
|
kernel_size,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=(kernel_size - 1) // 2,
|
padding=(kernel_size - 1) // 2,
|
||||||
groups=channels,
|
groups=bottleneck_dim,
|
||||||
bias=True,
|
bias=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.balancer2 = ActivationBalancer(
|
self.balancer2 = ActivationBalancer(
|
||||||
channels, channel_dim=1,
|
bottleneck_dim, channel_dim=1,
|
||||||
min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
|
min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
|
||||||
max_positive=1.0,
|
max_positive=1.0,
|
||||||
min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 1.0)),
|
min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 1.0)),
|
||||||
max_abs=10.0,
|
max_abs=10.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.activation = SwooshR()
|
self.activation3 = SwooshR()
|
||||||
|
|
||||||
self.whiten = Whiten(num_groups=1,
|
self.whiten = Whiten(num_groups=1,
|
||||||
whitening_limit=_whitening_schedule(7.5),
|
whitening_limit=_whitening_schedule(7.5),
|
||||||
@ -1635,7 +1641,7 @@ class ConvolutionModule(nn.Module):
|
|||||||
grad_scale=0.01)
|
grad_scale=0.01)
|
||||||
|
|
||||||
self.out_proj = LinearWithAuxLoss(
|
self.out_proj = LinearWithAuxLoss(
|
||||||
channels, channels,
|
bottleneck_dim, channels,
|
||||||
aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out(),
|
aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out(),
|
||||||
initial_scale=0.05,
|
initial_scale=0.05,
|
||||||
)
|
)
|
||||||
@ -1658,12 +1664,13 @@ class ConvolutionModule(nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
x = self.in_proj(x) # (time, batch, 2*channels)
|
x = self.in_proj(x) # (time, batch, 2*channels)
|
||||||
x = self.balancer1(x)
|
|
||||||
|
|
||||||
x, s = x.chunk(2, dim=-1)
|
x, s = x.chunk(2, dim=-1)
|
||||||
s = self.pre_sigmoid(s)
|
s = self.balancer1(s)
|
||||||
s = self.sigmoid(s)
|
s = self.sigmoid(s)
|
||||||
|
x = self.activation1(x) # identity.
|
||||||
x = x * s
|
x = x * s
|
||||||
|
x = self.activation2(x) # identity
|
||||||
|
|
||||||
# (time, batch, channels)
|
# (time, batch, channels)
|
||||||
|
|
||||||
@ -1679,7 +1686,7 @@ class ConvolutionModule(nn.Module):
|
|||||||
x = self.balancer2(x)
|
x = self.balancer2(x)
|
||||||
x = x.permute(2, 0, 1) # (time, batch, channels)
|
x = x.permute(2, 0, 1) # (time, batch, channels)
|
||||||
|
|
||||||
x = self.activation(x)
|
x = self.activation3(x)
|
||||||
x = self.whiten(x) # (time, batch, channels)
|
x = self.whiten(x) # (time, batch, channels)
|
||||||
x = self.out_proj(x) # (time, batch, channels)
|
x = self.out_proj(x) # (time, batch, channels)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user