Changes to balancer schedules: start max_abs from 5.0 not 4.0, start min_positive from 0.1 more consistently; finish at 8k not 12k.

This commit is contained in:
Daniel Povey 2022-11-26 23:07:18 +08:00
parent 633b6785f1
commit c91014f104

View File

@ -1392,8 +1392,8 @@ class FeedforwardModule(nn.Module):
self.hidden_balancer = ActivationBalancer(feedforward_dim,
channel_dim=-1,
min_positive=ScheduledFloat((0.0, 0.1), (12000.0, 0.05)),
max_abs=ScheduledFloat((0.0, 4.0), (12000.0, 10.0), default=10),
min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=10),
min_prob=0.25)
self.activation = DoubleSwish()
self.dropout = nn.Dropout(dropout)
@ -1435,10 +1435,10 @@ class NonlinAttentionModule(nn.Module):
# balancer that goes before the sigmoid.
self.balancer = ActivationBalancer(
channels // 2, channel_dim=-1,
min_positive=0.05, max_positive=1.0,
min_abs=0.2, max_abs=ScheduledFloat((0.0, 2.0),
(12000.0, 10.0),
default=1.0),
min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
max_positive=1.0,
min_abs=0.2,
max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0),
)
self.sigmoid = nn.Sigmoid()
@ -1528,12 +1528,11 @@ class ConvolutionModule(nn.Module):
# constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
# it will be in a better position to start learning something, i.e. to latch onto
# the correct range.
self.deriv_balancer1 = ActivationBalancer(
self.balancer1 = ActivationBalancer(
2 * channels, channel_dim=-1,
max_abs=ScheduledFloat((0.0, 2.0),
(12000.0, 10.0),
default=1.0),
min_positive=0.05, max_positive=1.0
min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
max_positive=1.0,
max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0),
)
self.pre_sigmoid = Identity() # before sigmoid; for diagnostics.
@ -1549,11 +1548,11 @@ class ConvolutionModule(nn.Module):
bias=True,
)
self.deriv_balancer2 = ActivationBalancer(
self.balancer2 = ActivationBalancer(
channels, channel_dim=1,
min_positive=ScheduledFloat((0.0, 0.1), (12000.0, 0.05)),
max_abs=ScheduledFloat((0.0, 4.0), (12000.0, 20.0), default=10),
min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
max_positive=1.0,
max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 20.0), default=10),
)
self.activation = DoubleSwish()
@ -1587,7 +1586,7 @@ class ConvolutionModule(nn.Module):
"""
x = self.in_proj(x) # (time, batch, 2*channels)
x = self.deriv_balancer1(x)
x = self.balancer1(x)
x, s = x.chunk(2, dim=-1)
s = self.pre_sigmoid(s)
@ -1605,7 +1604,7 @@ class ConvolutionModule(nn.Module):
# 1D Depthwise Conv
x = self.depthwise_conv(x)
x = self.deriv_balancer2(x)
x = self.balancer2(x)
x = x.permute(2, 0, 1) # (time, batch, channels)
x = self.activation(x)