Use the balancer; remove the unused sigmoid module.

This commit is contained in:
Daniel Povey 2022-11-03 19:21:37 +08:00
parent a9c384e69e
commit f625810de1

View File

@@ -1436,14 +1436,9 @@ class ModifiedSEModule(nn.Module):
d_model: int,
bottleneck_dim: int = 8):
super().__init__()
self.squeeze_proj = nn.Linear(d_model, d_model,
self.squeeze_proj = nn.Linear(d_model, bottleneck_dim,
bias=False)
# caution: this won't work well if the batch size is extremely small.
self.squeeze_whiten = Whiten(num_groups=1,
whitening_limit=10.0,
prob=(0.025, 0.25),
grad_scale=0.01)
self.in_proj = nn.Linear(d_model, d_model,
bias=False)
@@ -1456,23 +1451,14 @@ class ModifiedSEModule(nn.Module):
self.balancer = ActivationBalancer(
d_model, channel_dim=-1,
min_positive=0.05, max_positive=0.95,
min_abs=0.1,
max_abs=50.0,
max_factor=0.01,
max_factor=0.02,
min_prob=0.2,
)
self.activation = DoubleSwish()
self.to_bottleneck_proj = ScaledLinear(d_model, bottleneck_dim)
self.bottleneck_balancer = ActivationBalancer(
bottleneck_dim, channel_dim=-1,
min_positive=0.05, max_positive=0.95,
max_abs=5.0,
min_abs=0.5,
max_factor=0.01,
min_prob=0.2,
)
self.from_bottleneck_proj = ScaledLinear(bottleneck_dim, d_model)
self.sigmoid = nn.Sigmoid() # make it a submodule for diagnostics purposes.
self.out_proj = ScaledLinear(d_model, d_model,
bias=False, initial_scale=0.1)
@@ -1501,17 +1487,9 @@ class ModifiedSEModule(nn.Module):
squeezed = (x * pooling_mask).sum(dim=0, keepdim=True)
squeezed = self.squeeze_proj(squeezed)
squeezed = self.squeeze_whiten(squeezed)
squeezed = self.balancer(squeezed)
squeezed = self.activation(squeezed)
squeezed = self.to_bottleneck_proj(squeezed)
squeezed = self.bottleneck_balancer(squeezed)
squeezed = self.from_bottleneck_proj(squeezed)
if random.random() < 0.05:
# to stop a hopefully-unlikely failure mode where the inputs to the sigmoid
# get too large and the grads get mostly too small.
squeezed = penalize_abs_values_gt(squeezed, limit=10.0, penalty=1.0e-04)
scales = self.sigmoid(squeezed)
x = self.in_proj(x)
x = x * squeezed