diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py
index 500cacca8..0a39b0f33 100644
--- a/egs/librispeech/ASR/conformer_ctc/subsampling.py
+++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py
@@ -47,12 +47,12 @@ class Conv2dSubsampling(nn.Module):
             ScaledConv2d(
                 in_channels=1, out_channels=odim, kernel_size=3, stride=2
             ),
-            DerivBalancer(channel_dim=1),
+            ActivationBalancer(channel_dim=1),
             DoubleSwish(),
             ScaledConv2d(
                 in_channels=odim, out_channels=odim, kernel_size=3, stride=2
             ),
-            DerivBalancer(channel_dim=1),
+            ActivationBalancer(channel_dim=1),
             DoubleSwish(),
         )
         self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim)
@@ -61,7 +61,9 @@ class Conv2dSubsampling(nn.Module):
         # needed.
         self.out_norm = BasicNorm(odim, learn_eps=False)
         # constrain median of output to be close to zero.
-        self.out_balancer = DerivBalancer(channel_dim=-1, min_positive=0.45, max_positive=0.55)
+        self.out_balancer = ActivationBalancer(channel_dim=-1,
+                                               min_positive=0.45,
+                                               max_positive=0.55)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -177,7 +179,7 @@ class VggSubsampling(nn.Module):
 
 
-class DerivBalancerFunction(torch.autograd.Function):
+class ActivationBalancerFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x: Tensor,
                 channel_dim: int,
@@ -428,44 +430,33 @@ class ScaledConv2d(nn.Conv2d):
 
 
-class DerivBalancer(torch.nn.Module):
+class ActivationBalancer(torch.nn.Module):
     """
     Modifies the backpropped derivatives of a function to try to encourage, for
    each channel, that it is positive at least a proportion `min_positive` of the
    time.  It does this by multiplying negative derivative values by up to
    (1+max_factor), and positive derivative values by up to (1-max_factor),
-   interpolated from 0 at the threshold to those extremal values when none
+   interpolated from 1 at the threshold to those extremal values when none
    of the inputs are positive.
-   When all grads are zero for a channel, this
-   module sets all the input derivatives for that channel to -epsilon; the
-   idea is to bring completely dead neurons back to life this way.
 
    Args:
-          channel_dim: the dimension/axis corresponding to the channel, e.g.
+          channel_dim: the dimension/axis corresponding to the channel, e.g.
             -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
          min_positive: the minimum, per channel, of the proportion of the time
            that (x > 0), below which we start to modify the derivatives.
         max_positive: the maximum, per channel, of the proportion of the time
            that (x > 0), above which we start to modify the derivatives.
-         max_factor: the maximum factor by which we modify the derivatives,
+         max_factor: the maximum factor by which we modify the derivatives for
+           either the sign constraint or the magnitude constraint;
           e.g. with max_factor=0.02, the derivatives would be multiplied by
-           values in the range [0.98..1.01].
-         zero: we use this value in the comparison (x > 0), i.e. we actually use
-           (x > zero).  The reason for using a threshold slightly greater
-           than zero is that it will tend to prevent situations where the
-           inputs shrink close to zero and the nonlinearity (e.g. swish)
-           behaves like a linear function and we learn nothing.
+           values in the range [0.98..1.02].
         min_abs:  the minimum average-absolute-value per channel, which
           we allow, before we start to modify the derivatives to prevent
-           this.  This is to prevent a failure mode where the activations
-           become so small that the nonlinearity effectively becomes linear,
-           which makes the module useless and it gets even smaller
-           to try to "turn it off" completely.
+           this.
        max_abs:  the maximum average-absolute-value per channel, which
          we allow, before we start to modify the derivatives to prevent
-           this.  This is to prevent the possibility of activations getting
-           out of floating point numerical range (especially in half precision).
+           this.
    """
    def __init__(self, channel_dim: int,
                 min_positive: float = 0.05,
@@ -473,7 +464,7 @@ class DerivBalancer(torch.nn.Module):
                 max_factor: float = 0.01,
                 min_abs: float = 0.2,
                 max_abs: float = 100.0):
-        super(DerivBalancer, self).__init__()
+        super(ActivationBalancer, self).__init__()
         self.channel_dim = channel_dim
         self.min_positive = min_positive
         self.max_positive = max_positive
@@ -482,10 +473,10 @@ class DerivBalancer(torch.nn.Module):
         self.max_abs = max_abs
 
     def forward(self, x: Tensor) -> Tensor:
-        return DerivBalancerFunction.apply(x, self.channel_dim,
-                                           self.min_positive, self.max_positive,
-                                           self.max_factor, self.min_abs,
-                                           self.max_abs)
+        return ActivationBalancerFunction.apply(x, self.channel_dim,
+                                                self.min_positive, self.max_positive,
+                                                self.max_factor, self.min_abs,
+                                                self.max_abs)
 
 
 def _double_swish(x: Tensor) -> Tensor:
@@ -524,8 +515,8 @@ def _test_deriv_balancer_sign():
     x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))
     x = x.detach()
     x.requires_grad = True
-    m = DerivBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95,
-                      max_factor=0.2, min_abs=0.0)
+    m = ActivationBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95,
+                           max_factor=0.2, min_abs=0.0)
 
     y_grad = torch.sign(torch.randn(probs.numel(), N))
 
@@ -542,10 +533,10 @@ def _test_deriv_balancer_magnitude():
     x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1)
     x = x.detach()
     x.requires_grad = True
-    m = DerivBalancer(channel_dim=0,
-                      min_positive=0.0, max_positive=1.0,
-                      max_factor=0.2,
-                      min_abs=0.2, max_abs=0.8)
+    m = ActivationBalancer(channel_dim=0,
+                           min_positive=0.0, max_positive=1.0,
+                           max_factor=0.2,
+                           min_abs=0.2, max_abs=0.8)
 
     y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
 
diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py
index 8de02628d..6278734e5 100644
--- a/egs/librispeech/ASR/transducer_stateless/conformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/conformer.py
@@ -19,7 +19,7 @@ import copy
 import math
 import warnings
 from typing import Optional, Tuple, Sequence
-from subsampling import DoubleSwish, DerivBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d
+from subsampling import DoubleSwish, ActivationBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d
 
 import torch
 from torch import Tensor, nn
@@ -159,7 +159,7 @@ class ConformerEncoderLayer(nn.Module):
         self.feed_forward = nn.Sequential(
             ScaledLinear(d_model, dim_feedforward),
-            DerivBalancer(channel_dim=-1),
+            ActivationBalancer(channel_dim=-1),
             DoubleSwish(),
             nn.Dropout(dropout),
             ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
         )
@@ -167,7 +167,7 @@ class ConformerEncoderLayer(nn.Module):
         self.feed_forward_macaron = nn.Sequential(
             ScaledLinear(d_model, dim_feedforward),
-            DerivBalancer(channel_dim=-1),
+            ActivationBalancer(channel_dim=-1),
             DoubleSwish(),
             nn.Dropout(dropout),
             ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
         )
@@ -180,7 +180,9 @@ class ConformerEncoderLayer(nn.Module):
         self.norm_final = BasicNorm(d_model)
 
         # try to ensure the output is close to zero-mean (or at least, zero-median).
-        self.balancer = DerivBalancer(channel_dim=-1, min_positive=0.45, max_positive=0.55)
+        self.balancer = ActivationBalancer(channel_dim=-1,
+                                           min_positive=0.45,
+                                           max_positive=0.55)
 
         self.dropout = nn.Dropout(dropout)
@@ -858,8 +860,9 @@ class ConvolutionModule(nn.Module):
         # constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
         # it will be in a better position to start learning something, i.e. to latch onto
         # the correct range.
-        self.deriv_balancer1 = DerivBalancer(channel_dim=1, max_abs=10.0,
-                                             min_positive=0.05, max_positive=1.0)
+        self.deriv_balancer1 = ActivationBalancer(channel_dim=1, max_abs=10.0,
+                                                  min_positive=0.05,
+                                                  max_positive=1.0)
 
         self.depthwise_conv = ScaledConv1d(
             channels,
@@ -871,8 +874,9 @@ class ConvolutionModule(nn.Module):
             channels,
             kernel_size,
             stride=1,
             padding=(kernel_size - 1) // 2,
             groups=channels,
             bias=bias,
         )
 
-        self.deriv_balancer2 = DerivBalancer(channel_dim=1,
-                                             min_positive=0.05, max_positive=1.0)
+        self.deriv_balancer2 = ActivationBalancer(channel_dim=1,
+                                                  min_positive=0.05,
+                                                  max_positive=1.0)
 
         self.activation = DoubleSwish()
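
For reference, the gradient trick that the ActivationBalancer docstring above describes can be written as a minimal, self-contained sketch. This is not the icefall implementation: the name SignBalancerFunction is invented for illustration, only the sign constraint (min_positive/max_positive) is modeled, the magnitude constraint (min_abs/max_abs) is omitted, and 0 < min_positive <= max_positive < 1 is assumed.

```python
# Hypothetical sketch of the sign constraint only -- not the icefall code.
import torch
from torch import Tensor


class SignBalancerFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: Tensor, channel_dim: int, min_positive: float,
                max_positive: float, max_factor: float) -> Tensor:
        if channel_dim < 0:
            channel_dim += x.ndim
        # Proportion of positive values per channel, averaged over all
        # other dimensions.
        sum_dims = [d for d in range(x.ndim) if d != channel_dim]
        proportion_positive = torch.mean((x > 0).to(x.dtype),
                                         dim=sum_dims, keepdim=True)
        # factor ramps linearly from 0 (constraint satisfied) up to
        # +max_factor (no positive values in the channel) or down to
        # -max_factor (all values positive).
        too_few = ((min_positive - proportion_positive).clamp(min=0)
                   * (max_factor / min_positive))
        too_many = ((proportion_positive - max_positive).clamp(min=0)
                    * (max_factor / (1.0 - max_positive)))
        ctx.save_for_backward(too_few - too_many)
        return x  # identity in the forward direction

    @staticmethod
    def backward(ctx, x_grad: Tensor):
        (factor,) = ctx.saved_tensors
        # Subtracting factor * |grad| biases SGD-style updates
        # (x -= lr * grad): a positive factor makes the returned gradient
        # more negative, so a channel with too few positives drifts
        # upward, and vice versa.  Equivalently, negative derivatives are
        # scaled by up to (1 + max_factor) and positive ones by up to
        # (1 - max_factor), matching the docstring.
        return x_grad - x_grad.abs() * factor, None, None, None, None


if __name__ == "__main__":
    # Mirrors the pattern of _test_deriv_balancer_sign() in the diff.
    x = torch.randn(20, 100, requires_grad=True)  # (channels, frames)
    y = SignBalancerFunction.apply(x, 0, 0.05, 0.95, 0.2)
    y.backward(gradient=torch.sign(torch.randn_like(x)))
    print(x.grad[:2, :8])
```

Since the forward pass is the identity, a balancer of this kind costs nothing at inference time and can be removed without changing the model's outputs; only the training-time gradients are altered.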