diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 0b9c7d44a..0e10658fc 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -23,6 +23,7 @@ import logging
 from torch.cuda.amp import custom_fwd, custom_bwd
 import random
 import torch
+import math
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor
@@ -30,7 +31,6 @@ from torch.nn import Embedding as ScaledEmbedding
 
 
 
-
 class ScheduledFloat(torch.nn.Module):
     """
     This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();
@@ -975,6 +975,179 @@ class ActivationBalancer(torch.nn.Module):
             return _no_op(x)
 
 
+
+class BalancerFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        x: Tensor,
+        min_mean: float,
+        max_mean: float,
+        min_rms: float,
+        max_rms: float,
+        grad_scale: float,
+        channel_dim: int,
+    ) -> Tensor:
+        if channel_dim < 0:
+            channel_dim += x.ndim
+        ctx.channel_dim = channel_dim
+        ctx.save_for_backward(x)
+        ctx.config = (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim)
+        return x
+
+
+    @staticmethod
+    def backward(
+        ctx, x_grad: Tensor
+    ) -> Tuple[Tensor, None, None, None, None, None, None]:
+        x, = ctx.saved_tensors
+        (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) = ctx.config
+
+        with torch.enable_grad():
+            with torch.cuda.amp.autocast(enabled=False):
+                x = x.to(torch.float32)
+                x = x.detach()
+                x.requires_grad = True
+                mean_dims = [ i for i in range(x.ndim) if i != channel_dim ]
+                uncentered_var = (x ** 2).mean(dim=mean_dims)
+                mean = x.mean(dim=mean_dims)
+                stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt()
+                rms = uncentered_var.clamp(min=1.0e-20).sqrt()
+
+                m = mean / stddev
+                # part of loss that relates to mean / stddev
+                m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs()
+
+                rms_clamped = rms.clamp(min=min_rms, max=max_rms)
+                r_loss = (rms_clamped / rms).log().abs()
+
+                loss = (m_loss + r_loss).sum()
+
+                loss.backward()
+                loss_grad = x.grad
+                loss_grad_rms = (loss_grad ** 2).mean(dim=mean_dims, keepdim=True).sqrt().clamp(min=1.0e-20)
+
+                loss_grad = loss_grad * (grad_scale / loss_grad_rms)
+
+                x_grad_float = x_grad.to(torch.float32)
+                # Scale each element of loss_grad by the absolute value of the corresponding
+                # element of x_grad, which we view as a noisy estimate of its magnitude for that
+                # (frame and dimension).  Later we can consider factored versions.
+                x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad)
+                x_grad = x_grad_mod.to(x_grad.dtype)
+
+        return x_grad, None, None, None, None, None, None
+
+
+class Balancer(torch.nn.Module):
+    """
+    Modifies the backpropped derivatives of a function to try to encourage, for
+    each channel, that the proportion of positive values lies between min_positive
+    and max_positive, and that the average absolute value lies between min_abs and
+    max_abs.  It does this by adding to the existing derivative an extra term, of
+    size roughly grad_scale times the magnitude of the existing derivative, that
+    points in a direction which reduces violations of these constraints.
+
+    Args:
+       num_channels: the number of channels
+       channel_dim: the dimension/axis corresponding to the channel, e.g.
+           -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
+       min_positive: the minimum, per channel, of the proportion of the time
+           that (x > 0), below which we start to modify the derivatives.
+       max_positive: the maximum, per channel, of the proportion of the time
+           that (x > 0), above which we start to modify the derivatives.
+       min_abs: the minimum average absolute value, per channel, that we
+           allow, before we start to modify the derivatives to prevent
+           this.
+       max_abs: the maximum average absolute value, per channel, that we
+           allow, before we start to modify the derivatives to prevent
+           this.
+       grad_scale: determines the size, relative to the magnitude of the
+           existing derivative, of the extra gradient term that we add in
+           order to enforce the constraints described above.
+       prob: determines the minimum probability with which we modify the
+           gradients for the {min,max}_positive and {min,max}_abs constraints,
+           on each forward().  This is done randomly to prevent all layers
+           from doing it at the same time.
+    """
+    def __init__(
+        self,
+        num_channels: int,
+        channel_dim: int,
+        min_positive: FloatLike = 0.05,
+        max_positive: FloatLike = 0.95,
+        min_abs: FloatLike = 0.2,
+        max_abs: FloatLike = 100.0,
+        grad_scale: FloatLike = 0.04,
+        prob: Optional[FloatLike] = None,
+    ):
+        super().__init__()
+
+        if prob is None:
+            prob = ScheduledFloat((0.0, 0.5), (8000.0, 0.125), default=0.4)
+        self.prob = prob
+        # 5% of the time we will return and do nothing because memory usage is
+        # too high.
+        self.mem_cutoff = CutoffEstimator(0.05)
+
+        # actually self.num_channels is no longer needed except for an assertion.
+        self.num_channels = num_channels
+        self.channel_dim = channel_dim
+        self.min_positive = min_positive
+        self.max_positive = max_positive
+        self.min_abs = min_abs
+        self.max_abs = max_abs
+        self.grad_scale = grad_scale
+
+    def forward(self, x: Tensor) -> Tensor:
+        if (torch.jit.is_scripting() or not x.requires_grad or
+            (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))):
+            return _no_op(x)
+
+        prob = float(self.prob)
+        if random.random() < prob:
+            # The following inner functions convert from the way we historically specified
+            # these limits, as limits on the average absolute value and on the proportion of
+            # positive values, to limits on the RMS value and on (mean / stddev).
+            def _abs_to_rms(x):
+                # for normally distributed data, if the expected absolute value is x, the
+                # expected rms value will be sqrt(pi/2) * x.
+                return 1.25331413732 * x
+            def _proportion_positive_to_mean(x):
+                def _atanh(x):
+                    eps = 1.0e-10
+                    # eps is to prevent crashes if x is exactly 0 or 1.
+                    # we'll just end up returning a fairly large value.
+                    return (math.log(1 + x + eps) - math.log(1 - x + eps)) / 2.0
+                def _approx_inverse_erf(x):
+                    # 1 / (sqrt(pi) * ln(2)),
+                    # see https://math.stackexchange.com/questions/321569/approximating-the-error-function-erf-by-analytical-functions
+                    # this approximation is extremely crude and gets progressively worse for
+                    # x very close to -1 or +1, but we mostly care about the "middle" region
+                    # e.g. _approx_inverse_erf(0.05) = 0.0407316414078772,
+                    # and math.erf(0.0407316414078772) = 0.045935330944660666,
+                    # which is pretty close to 0.05.
+ return 0.8139535143 * _atanh(x) + # first convert x from the range 0..1 to the range -1..1 which the error + # function returns + x = -1 + (2 * x) + return _approx_inverse_erf(x) + + min_mean = _proportion_positive_to_mean(float(self.min_positive)) + max_mean = _proportion_positive_to_mean(float(self.max_positive)) + min_rms = _abs_to_rms(float(self.min_abs)) + max_rms = _abs_to_rms(float(self.max_abs)) + grad_scale = float(self.grad_scale) + + assert x.shape[self.channel_dim] == self.num_channels + + return BalancerFunction.apply( + x, min_mean, max_mean, min_rms, max_rms, grad_scale, self.channel_dim + ) + else: + return _no_op(x) + + def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float, name: str = None) -> Tensor: """ @@ -1731,6 +1904,7 @@ def _test_activation_balancer_sign(): max_positive=0.95, max_factor=0.2, min_abs=0.0, + prob=1.0, ) y_grad = torch.sign(torch.randn(probs.numel(), N)) @@ -1742,6 +1916,31 @@ def _test_activation_balancer_sign(): print("_test_activation_balancer_sign: x grad = ", x.grad) +def _test_balancer_sign(): + probs = torch.arange(0, 1, 0.01) + N = 1000 + x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0) + x = x.detach() + x.requires_grad = True + m = Balancer( + probs.numel(), + channel_dim=0, + min_positive=0.05, + max_positive=0.95, + min_abs=0.0, + prob=1.0, + ) + + y_grad = torch.sign(torch.randn(probs.numel(), N)) + + y = m(x) + y.backward(gradient=y_grad) + print("_test_balancer_sign: x = ", x) + print("_test_balancer_sign: y grad = ", y_grad) + print("_test_balancer_sign: x grad = ", x.grad) + + + def _test_activation_balancer_magnitude(): magnitudes = torch.arange(0, 1, 0.01) N = 1000 @@ -1770,6 +1969,34 @@ def _test_activation_balancer_magnitude(): print("_test_activation_balancer_magnitude: x grad = ", x.grad) +def _test_balancer_magnitude(): + magnitudes = torch.arange(0, 1, 0.01) + N = 1000 + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( + -1 + ) + x = x.detach() + x.requires_grad = True + m = Balancer( + magnitudes.numel(), + channel_dim=0, + min_positive=0.0, + max_positive=1.0, + min_abs=0.2, + max_abs=0.7, + prob=1.0, + ) + + y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) + + y = m(x) + y.backward(gradient=y_grad) + print("_test_balancer_magnitude: x = ", x) + print("_test_balancer_magnitude: y grad = ", y_grad) + print("_test_balancer_magnitude: x grad = ", x.grad) + + + def _test_basic_norm(): num_channels = 128 m = BasicNorm(num_channels=num_channels, channel_dim=1) @@ -1862,7 +2089,9 @@ if __name__ == "__main__": _test_whiten() _test_max_eig() _test_activation_balancer_sign() + _test_balancer_sign() _test_activation_balancer_magnitude() + _test_balancer_magnitude() _test_basic_norm() _test_double_swish_deriv() _test_tan_swish_deriv() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 763183b92..bbfb292b4 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -26,6 +26,7 @@ import random from encoder_interface import EncoderInterface from scaling import ( ActivationBalancer, + Balancer, BasicNorm, ConvNorm1d, ConvNorm2d, @@ -456,7 +457,7 @@ class ZipformerEncoderLayer(nn.Module): self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) # try to ensure the output is close to zero-mean (or at least, zero-median). 
- self.balancer = ActivationBalancer( + self.balancer = Balancer( embed_dim, channel_dim=-1, min_positive=0.45, max_positive=0.55, min_abs=1.0, max_abs=6.0, @@ -1302,7 +1303,7 @@ class AttentionSqueeze(nn.Module): # instances of this module, the mean absolute values per channel are in # the range 0.1 to 0.4. We apply the upper limit of 0.4 at the # beginning, and make it looser over time. - self.bottleneck_balancer = ActivationBalancer( + self.bottleneck_balancer = Balancer( bottleneck_dim, channel_dim=-1, min_positive=0.2, max_positive=0.8, min_abs=0.05, @@ -1316,13 +1317,13 @@ class AttentionSqueeze(nn.Module): # many degrees of freedom for the scales of the various activations. # Make them run with very low probability, since only a small # application of these balancers should be enough to stop such "drift". - self.scale_balancer = ActivationBalancer( + self.scale_balancer = Balancer( hidden_dim, channel_dim=-1, min_positive=0.2, max_positive=0.8, min_abs=0.2, max_abs=1.0, prob=_balancer_schedule(0.05), ) - self.activation_balancer = ActivationBalancer( + self.activation_balancer = Balancer( hidden_dim, channel_dim=-1, min_positive=0.2, max_positive=0.8, min_abs=0.2, max_abs=1.0, @@ -1339,7 +1340,7 @@ class AttentionSqueeze(nn.Module): self.out_proj = ScaledLinear(hidden_dim, embed_dim, bias=False, initial_scale=0.05) - self.out_balancer = ActivationBalancer( + self.out_balancer = Balancer( embed_dim, channel_dim=-1, min_positive=0.3, max_positive=0.7, min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)), @@ -1396,7 +1397,7 @@ class FeedforwardModule(nn.Module): super(FeedforwardModule, self).__init__() self.in_proj = nn.Linear(embed_dim, feedforward_dim) - self.hidden_balancer = ActivationBalancer(feedforward_dim, + self.hidden_balancer = Balancer(feedforward_dim, channel_dim=-1, min_positive=0.3, max_positive=1.0, @@ -1447,7 +1448,7 @@ class NonlinAttentionModule(nn.Module): # because we noticed that well-trained instances of this module have abs-value before the sigmoid # starting from about 3, and poorly-trained instances of the module have smaller abs values # before the sigmoid. - self.balancer1 = ActivationBalancer( + self.balancer1 = Balancer( hidden_channels, channel_dim=-1, min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)), max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)), @@ -1473,7 +1474,7 @@ class NonlinAttentionModule(nn.Module): prob=(0.025, 0.25), grad_scale=0.01) - self.balancer2 = ActivationBalancer( + self.balancer2 = Balancer( channels, channel_dim=-1, min_positive=0.3, max_positive=0.7, min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)), @@ -1566,7 +1567,7 @@ class ConvolutionModule(nn.Module): # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, # it will be in a better position to start learning something, i.e. to latch onto # the correct range. 
- self.balancer1 = ActivationBalancer( + self.balancer1 = Balancer( bottleneck_dim, channel_dim=-1, min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)), max_positive=1.0, @@ -1590,7 +1591,7 @@ class ConvolutionModule(nn.Module): bias=True, ) - self.balancer2 = ActivationBalancer( + self.balancer2 = Balancer( bottleneck_dim, channel_dim=1, min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)), max_positive=1.0, @@ -1695,7 +1696,7 @@ class ConvNeXt(nn.Module): out_channels=hidden_channels, kernel_size=1) - self.hidden_balancer = ActivationBalancer(hidden_channels, + self.hidden_balancer = Balancer(hidden_channels, channel_dim=1, min_positive=0.3, max_positive=1.0, @@ -1709,7 +1710,7 @@ class ConvNeXt(nn.Module): kernel_size=1, initial_scale=0.01) - self.out_balancer = ActivationBalancer( + self.out_balancer = Balancer( channels, channel_dim=1, min_positive=0.4, max_positive=0.6, min_abs=1.0, max_abs=6.0, @@ -1800,7 +1801,7 @@ class Conv2dSubsampling(nn.Module): padding=(0, 1), # (time, freq) ), ScaleGrad(0.2), - ActivationBalancer(layer1_channels, + Balancer(layer1_channels, channel_dim=1, max_abs=1.0), SwooshR(), @@ -1811,7 +1812,7 @@ class Conv2dSubsampling(nn.Module): stride=2, padding=0, ), - ActivationBalancer(layer2_channels, + Balancer(layer2_channels, channel_dim=1, max_abs=4.0), SwooshR(), @@ -1833,7 +1834,7 @@ class Conv2dSubsampling(nn.Module): kernel_size=3, stride=(1, 2), # (time, freq) ), - ActivationBalancer(layer3_channels, + Balancer(layer3_channels, channel_dim=1, max_abs=4.0), SwooshR(),
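A note on the two conversions inside Balancer.forward(): the limits on average absolute value (min_abs, max_abs) are turned into RMS limits via the factor sqrt(pi/2) = 1.25331413732 (for zero-mean Gaussian data, rms = sqrt(pi/2) * E|x|), and the limits on the proportion of positive values are turned into limits on mean/stddev via a crude inverse of erf, based on the approximation erf(y) ~= tanh(sqrt(pi) * ln(2) * y). The standalone sketch below is not part of the patch; the helper names are made up here and simply mirror the inner functions of forward(), so that the constants can be checked numerically.

    import math

    def abs_to_rms(a):
        # For x ~ N(0, sigma): E|x| = sigma * sqrt(2/pi) and rms = sigma,
        # so rms = sqrt(pi/2) * E|x|.
        return math.sqrt(math.pi / 2) * a

    def proportion_positive_to_mean(p):
        # Crude inverse of erf, using erf(y) ~= tanh(sqrt(pi) * ln(2) * y),
        # i.e. inverse_erf(y) ~= atanh(y) / (sqrt(pi) * ln(2)).
        c = 1.0 / (math.sqrt(math.pi) * math.log(2))  # ~= 0.8139535143
        eps = 1.0e-10
        y = -1 + 2 * p  # map a proportion in 0..1 to the erf range -1..1
        return c * 0.5 * (math.log(1 + y + eps) - math.log(1 - y + eps))

    print(abs_to_rms(1.0))                               # ~1.2533, matching 1.25331413732
    print(proportion_positive_to_mean(0.95))             # ~1.198
    print(math.erf(proportion_positive_to_mean(0.95)))   # ~0.91, close to 2*0.95-1 = 0.9

As the comments in the patch already note, the inverse-erf approximation is crude away from the middle of the range; it is only used to pick rough targets for mean/stddev, not for anything that needs to be exact.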
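The essential mechanism of the new Balancer is that the forward pass is an identity (BalancerFunction.forward() returns x unchanged) and only the backward pass is modified: a penalty on violating the mean/stddev and RMS limits is differentiated with respect to x, normalized to a per-channel RMS of grad_scale, scaled elementwise by |x_grad|, and added to x_grad. The following is a condensed, standalone restatement of that backward logic; the function name is illustrative and not part of the patch, float32 input is assumed, and the autocast/scripting/memory-cutoff handling is omitted.

    import torch

    def balancer_grad_correction(x, x_grad, min_mean, max_mean,
                                 min_rms, max_rms, grad_scale, channel_dim=-1):
        channel_dim = channel_dim % x.ndim
        mean_dims = [i for i in range(x.ndim) if i != channel_dim]
        # Recompute the constraint penalty on a detached copy of x and differentiate it.
        x = x.detach().float().requires_grad_(True)
        uncentered_var = (x ** 2).mean(dim=mean_dims)
        mean = x.mean(dim=mean_dims)
        stddev = (uncentered_var - mean * mean).clamp(min=1.0e-20).sqrt()
        rms = uncentered_var.clamp(min=1.0e-20).sqrt()
        m = mean / stddev
        m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs()
        r_loss = (rms.clamp(min=min_rms, max=max_rms) / rms).log().abs()
        (m_loss + r_loss).sum().backward()
        # Normalize the penalty gradient to a per-channel RMS of grad_scale, then scale it
        # elementwise by |x_grad| so the correction tracks the gradient already flowing here.
        g = x.grad
        g = g * (grad_scale / (g ** 2).mean(dim=mean_dims, keepdim=True).sqrt().clamp(min=1.0e-20))
        return x_grad + x_grad.abs() * g

    # e.g. for x, x_grad of shape (batch, time, channels) with channel_dim=-1, using
    # roughly the Balancer defaults after conversion:
    # new_grad = balancer_grad_correction(x, x_grad, min_mean=-1.2, max_mean=1.2,
    #                                     min_rms=0.25, max_rms=125.0, grad_scale=0.04)

Note that if all constraints are satisfied, both penalty terms are identically zero and the correction vanishes, so a well-behaved activation sees its gradient pass through unchanged.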
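Since Balancer.forward() either returns _no_op(x) or BalancerFunction.apply(...), whose forward() returns x unchanged, the module never changes the values flowing through the network; with probability prob it only perturbs the gradients. A quick check of both halves of that property, in the same style as the _test_balancer_* functions added in this patch (assuming scaling.py is importable, e.g. when run from pruned_transducer_stateless7/):

    import torch
    from scaling import Balancer

    x = torch.randn(500, 16).detach()
    x.requires_grad = True
    # randn data has rms ~1.0; min_abs=1.0 maps to min_rms ~1.25, so this constraint
    # is violated and the backward pass should add a correction to the gradient.
    m = Balancer(16, channel_dim=-1, min_positive=0.05, max_positive=0.95,
                 min_abs=1.0, max_abs=2.0, prob=1.0)

    y = m(x)
    assert torch.equal(y, x)      # forward pass is an identity
    y_grad = torch.ones_like(y)
    y.backward(gradient=y_grad)
    # per-channel rms of the gradient change should be about grad_scale (default 0.04)
    print(((x.grad - y_grad) ** 2).mean(dim=0).sqrt())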