diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index fe97ae0a0..cdbd781f1 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -30,6 +30,104 @@ from torch.nn import Embedding as ScaledEmbedding
 
 
 class ActivationBalancerFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        x: Tensor,
+        scale_factor: Tensor,
+        sign_factor: Optional[Tensor],
+        channel_dim: int,
+    ) -> Tensor:
+        if channel_dim < 0:
+            channel_dim += x.ndim
+        ctx.channel_dim = channel_dim
+        xgt0 = (x > 0)
+        if sign_factor is None:
+            ctx.save_for_backward(xgt0, scale_factor)
+        else:
+            ctx.save_for_backward(xgt0, scale_factor, sign_factor)
+        return x
+
+    @staticmethod
+    def backward(
+        ctx, x_grad: Tensor
+    ) -> Tuple[Tensor, None, None, None]:
+        if len(ctx.saved_tensors) == 3:
+            xgt0, scale_factor, sign_factor = ctx.saved_tensors
+            for _ in range(ctx.channel_dim, x_grad.ndim - 1):
+                scale_factor = scale_factor.unsqueeze(-1)
+                sign_factor = sign_factor.unsqueeze(-1)
+            factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
+        else:
+            xgt0, scale_factor = ctx.saved_tensors
+            for _ in range(ctx.channel_dim, x_grad.ndim - 1):
+                scale_factor = scale_factor.unsqueeze(-1)
+            factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
+        neg_delta_grad = x_grad.abs() * factor
+        return x_grad - neg_delta_grad, None, None, None
+
+
+def _compute_scale_factor(x: Tensor,
+                          channel_dim: int,
+                          min_abs: float,
+                          max_abs: float,
+                          gain_factor: float,
+                          max_factor: float) -> Tensor:
+    if channel_dim < 0:
+        channel_dim += x.ndim
+    sum_dims = [d for d in range(x.ndim) if d != channel_dim]
+    x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
+
+    if min_abs == 0.0:
+        below_threshold = 0.0
+    else:
+        # below_threshold is 0 if x_abs_mean > min_abs; it can be at most
+        # max_factor if x_abs_mean < min_abs.
+        below_threshold = ((min_abs - x_abs_mean) *
+                           (gain_factor / min_abs)).clamp(min=0, max=max_factor)
+
+    above_threshold = ((x_abs_mean - max_abs) *
+                       (gain_factor / max_abs)).clamp(min=0, max=max_factor)
+
+    return below_threshold - above_threshold
+
+
+def _compute_sign_factor(x: Tensor,
+                         channel_dim: int,
+                         min_positive: float,
+                         max_positive: float,
+                         gain_factor: float,
+                         max_factor: float) -> Tensor:
+    if channel_dim < 0:
+        channel_dim += x.ndim
+    sum_dims = [d for d in range(x.ndim) if d != channel_dim]
+    proportion_positive = torch.mean((x > 0).to(torch.float32),
+                                     dim=sum_dims)
+    if min_positive == 0.0:
+        factor1 = 0.0
+    else:
+        # 0 if proportion_positive >= min_positive, else can be
+        # as large as max_factor.
+        factor1 = ((min_positive - proportion_positive) *
+                   (gain_factor / min_positive)).clamp_(min=0, max=max_factor)
+
+    if max_positive == 1.0:
+        factor2 = 0.0
+    else:
+        # 0 if proportion_positive <= max_positive; else can be as large as
+        # max_factor (it is subtracted below, pushing sign_factor negative).
+        factor2 = ((proportion_positive - max_positive) *
+                   (gain_factor / (1.0 - max_positive))).clamp_(min=0, max=max_factor)
+    sign_factor = factor1 - factor2
+    # we require min_positive != 0 or max_positive != 1, so sign_factor
+    # must be a Tensor here, not a python float:
+    assert not isinstance(sign_factor, float)
+    return sign_factor
+
+
+class ActivationScaleBalancerFunction(torch.autograd.Function):
+    """
+    This object is used in class ActivationBalancer when the user specifies
+    min_positive=0, max_positive=1, so there are no constraints on the signs
+    of the activations and only the absolute value has a constraint.
+ """ @staticmethod def forward( ctx, @@ -62,6 +160,7 @@ class ActivationBalancerFunction(torch.autograd.Function): + class MaxEigLimiterFunction(torch.autograd.Function): @staticmethod def forward( @@ -218,7 +317,6 @@ class ActivationBalancer(torch.nn.Module): interpolated from 1 at the threshold to those extremal values when none of the inputs are positive. - Args: num_channels: the number of channels channel_dim: the dimension/axis corresponding to the channel, e.g. @@ -231,34 +329,36 @@ class ActivationBalancer(torch.nn.Module): either the sign constraint or the magnitude constraint; e.g. with max_factor=0.02, the the derivatives would be multiplied by values in the range [0.98..1.02]. + sign_gain_factor: determines the 'gain' with which we increase the + change in gradient once the constraints on min_positive and max_positive + are violated. + scale_gain_factor: determines the 'gain' with which we increase the + change in gradient once the constraints on min_abs and max_abs + are violated. min_abs: the minimum average-absolute-value difference from the mean value per channel, which we allow, before we start to modify the derivatives to prevent this. max_abs: the maximum average-absolute-value difference from the mean value per channel, which we allow, before we start to modify the derivatives to prevent this. - beta: a constant used in decaying stats for the {min,max}_positive and - {min,max}_abs constraints. Likely not critical. - prob: determines the probability with which we modify the + min_prob: determines the minimum probability with which we modify the gradients for the {min,max}_positive and {min,max}_abs constraints, on each forward(). This is done randomly to prevent all layers - from doing it at the same time. - stats_period: the periodicity with which we update the statistics on - the activations. + from doing it at the same time. Early in training we may use + higher probabilities than this; it will decay to this value. """ def __init__( - self, - num_channels: int, - channel_dim: int, - min_positive: float = 0.05, - max_positive: float = 0.95, - max_factor: float = 0.01, - min_abs: float = 0.2, - max_abs: float = 100.0, - max_var_per_eig: float = 0.0, - beta: float = 0.0, - prob: float = 0.25, - stats_period: int = 4, + self, + num_channels: int, + channel_dim: int, + min_positive: float = 0.05, + max_positive: float = 0.95, + max_factor: float = 0.02, + sign_gain_factor: float = 0.01, + scale_gain_factor: float = 0.02, + min_abs: float = 0.2, + max_abs: float = 100.0, + min_prob: float = 0.1, ): super(ActivationBalancer, self).__init__() self.num_channels = num_channels @@ -268,124 +368,58 @@ class ActivationBalancer(torch.nn.Module): self.max_factor = max_factor self.min_abs = min_abs self.max_abs = max_abs - self.beta = beta - self.prob = prob - self.stats_period = stats_period + self.min_prob = min_prob + self.sign_gain_factor = sign_gain_factor + self.scale_gain_factor = scale_gain_factor # count measures how many times the forward() function has been called. - self.count = 0 + # We occasionally sync this to a tensor called `count`, that exists to + # make sure it is synced to disk when we load and save the model. + self.cpu_count = 0 + self.register_buffer('count', torch.tensor(0, dtype=torch.int64)) - # the mean of the absolute value of the data per channel - self.register_buffer('abs_mean', torch.zeros(num_channels)) - - # the proportion of activations that are positive, per channel. 
-        self.register_buffer('proportion_positive', torch.zeros(num_channels))
-
-        # `factors` contains two buffers of shape (num_channels,).
-        # `sign_factor` is an expression that will be used to scale the
-        # gradients in backprop; it will be 0 if the max_positive and min_positive
-        # contstraints are satisfied.
-        # `scale_factor` is an expression that will be used to encourage the
-        # data to satisfy our min_abs and max_abs constraints; it will be zero if
-        # all constraints are satisfied.
-        self.register_buffer('factors', torch.zeros(2, num_channels))
 
     def forward(self, x: Tensor) -> Tensor:
         if torch.jit.is_scripting() or not x.requires_grad:
             return x
 
-        count = self.count
-        self.count += 1
+        count = self.cpu_count
+        self.cpu_count += 1
 
-        if count % self.stats_period == 0:
-            self._update_stats(x, count)
+        if random.random() < 0.01:
+            # Occasionally sync self.cpu_count with self.count.
+            # count affects the decay of 'prob'; don't do this on every iter,
+            # because syncing with the GPU is slow.
+            self.cpu_count = max(self.cpu_count, self.count.item())
+            self.count.fill_(self.cpu_count)
 
-        if random.random() < self.prob:
-            # The .clone() is in case the forward() gets called multiple times befor
-            factors = self.factors.clone()
-            sign_factor = factors[0]
-            scale_factor = factors[1]
+        # the prob of doing some work exponentially decreases from 0.5 till it
+        # hits a floor at min_prob (==0.1, by default)
+        prob = max(self.min_prob, 0.5 ** (1 + (count / 2000.0)))
+
+        if random.random() < prob:
+            if self.min_positive != 0.0 or self.max_positive != 1.0:
+                sign_factor = _compute_sign_factor(x, self.channel_dim,
+                                                   self.min_positive,
+                                                   self.max_positive,
+                                                   gain_factor=self.sign_gain_factor / prob,
+                                                   max_factor=self.max_factor)
+            else:
+                sign_factor = None
+
+            scale_factor = _compute_scale_factor(x, self.channel_dim,
+                                                 min_abs=self.min_abs,
+                                                 max_abs=self.max_abs,
+                                                 gain_factor=self.scale_gain_factor / prob,
+                                                 max_factor=self.max_factor)
             return ActivationBalancerFunction.apply(
-                x, sign_factor, scale_factor, self.channel_dim,
+                x, scale_factor, sign_factor, self.channel_dim,
             )
         else:
             return x
 
-    def _update_stats(self,
-                      x: Tensor,
-                      count: int):
-        """
-        Updates some statistics that we maintain, describing the average activations per
-        channel.
-        """
-        with torch.no_grad():
-            channel_dim = self.channel_dim
-            if channel_dim < 0:
-                channel_dim += x.ndim
-            sum_dims = [d for d in range(x.ndim) if d != channel_dim]
-
-            x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
-            # the random.random() thing is to split the difference if x is zero,
-            # between treating it positive or negative
-            proportion_positive = torch.mean(
-                ((x > 0) if random.random() < 0.5 else (x >= 0)).to(torch.float32), dim=sum_dims,
-            )
-
-            def filter_inf_nan(y):
-                mask = (y - y != 0)
-                y.masked_fill_(mask, 0.0)
-
-            filter_inf_nan(x_abs_mean)
-
-            beta = self.beta if count > 0 else 0.0
-            self.abs_mean.mul_(beta).add_(x_abs_mean, alpha=(1 - beta))
-            self.proportion_positive.mul_(beta).add_(proportion_positive, alpha=(1 - beta))
-
-            max_factor = self.max_factor / self.prob
-            min_positive = self.min_positive
-            max_positive = self.max_positive
-
-            if min_positive == 0.0:
-                factor1 = 0.0
-            else:
-                # 0 if self.proportion_positive >= min_positive, else can be
-                # as large as max_factor.
-                factor1 = ((min_positive - self.proportion_positive).relu() *
-                           (max_factor / min_positive))
-            if max_positive == 1.0:
-                factor2 = 0.0
-            else:
-                # 0 if self.proportion_positive <= max_positive, else can be
-                # as large as -max_factor.
-                factor2 = ((self.proportion_positive - max_positive).relu()
-                           * (max_factor / (max_positive - 1.0)))
-            sign_factor = self.factors[0]
-            scale_factor = self.factors[1]
-            sign_factor[:] = factor1 + factor2
-
-            # the factor of 2.0 below is just to cancel out a factor of 0.5 that gets introduced when, in
-            # the backprop, we do (xgt0.to(dtype) - 0.5).
-            #
-            # scale_factor_scale, on the other hand, is a heuristically chosen value between 0 and 1,
-            # that we use to make the gradient changes from the 'scale' constraints (min_abs/max_abs)
-            # less strong than those from the sign constraints.
-            #
-            # This is to get rid of a pathology that can happen if, for instance, a
-            # channel is always positive but is too small (max_positive and min_abs constraints both
-            # violated).  If scale_factor_scale were equal to 1.0, then the gradient changes from the
-            # min_positive constraint (trying to make the activation more negative) and from the
-            # min_abs constraint (trying to make the activation more positive) would exactly cancel.
-            # Instead we make the min_positive constraint stronger, so it first makes the value
-            # sometimes negative, and only when that is satisfied, can deal with the absolute-value
-            # constraint.
-            scale_factor_scale = 0.8
-            below_threshold = (self.abs_mean < self.min_abs)
-            above_threshold = (self.abs_mean > self.max_abs)
-            scale_factor[:] = ((below_threshold.to(torch.float32) -
-                                above_threshold.to(torch.float32))
-                               * (max_factor * (2.0 * scale_factor_scale)))
-
 
 class MaxEig(torch.nn.Module):
     """
@@ -612,7 +646,6 @@ def _test_activation_balancer_sign():
         max_positive=0.95,
         max_factor=0.2,
         min_abs=0.0,
-        prob=1.0,
     )
 
     y_grad = torch.sign(torch.randn(probs.numel(), N))
@@ -640,7 +673,7 @@ def _test_activation_balancer_magnitude():
         max_factor=0.2,
         min_abs=0.2,
         max_abs=0.8,
-        prob=1.0,
+        min_prob=1.0,
    )

    y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
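Notes for reviewers (illustrative only, not part of the patch):

The new `ActivationBalancerFunction` leaves the forward activations untouched and only perturbs the incoming gradient in backward(). Below is a minimal standalone sketch of that gradient rule, assuming `channel_dim == -1` and per-channel factors whose values are invented for the example:

```python
import torch

# Toy activations of shape (batch, num_channels), channel_dim == -1,
# and an incoming gradient of the same shape.
x = torch.tensor([[-0.3, 0.7],
                  [0.1, -0.2]])
x_grad = torch.tensor([[1.0, -1.0],
                       [0.5, 0.5]])

# Per-channel factors, as _compute_sign_factor / _compute_scale_factor
# would produce (values made up for illustration).
sign_factor = torch.tensor([0.01, 0.00])
scale_factor = torch.tensor([0.00, 0.02])

# Mirrors the backward() code path where sign_factor is not None:
xgt0 = (x > 0).to(x_grad.dtype)
factor = sign_factor + scale_factor * (xgt0 - 0.5)
neg_delta_grad = x_grad.abs() * factor
new_grad = x_grad - neg_delta_grad
print(new_grad)
# tensor([[ 0.9900, -1.0100],
#         [ 0.4950,  0.5050]])
```

The `(xgt0 - 0.5)` term makes `scale_factor` act with opposite sign on positive and negative activations (nudging the mean |x| toward the [min_abs, max_abs] band), while `sign_factor` shifts all gradients in one direction to move the proportion of positive values.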
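The decaying application probability in `ActivationBalancer.forward()` can be sanity-checked in isolation. With the default `min_prob=0.1`, `prob = max(min_prob, 0.5 ** (1 + count / 2000.0))` starts at 0.5 and reaches its floor by around count 6000; and since the gain factors passed to the helpers are divided by `prob`, the expected strength of the correction should stay roughly constant as `prob` decays (playing the role of the old `max_factor / self.prob`):

```python
# Quick check of the schedule used in forward():
min_prob = 0.1
for count in [0, 2000, 4000, 6000, 20000]:
    prob = max(min_prob, 0.5 ** (1 + count / 2000.0))
    print(count, round(prob, 4))
# 0 -> 0.5, 2000 -> 0.25, 4000 -> 0.125, 6000 and beyond -> 0.1 (floor)
```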
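Similarly, here is a worked example of the per-channel arithmetic inside `_compute_scale_factor`, applied directly to a made-up vector of per-channel mean absolute values (the real function computes this mean from `x` over all non-channel dims):

```python
import torch

# Per-channel mean |x|, against the default constraints [0.2, 100.0].
x_abs_mean = torch.tensor([0.05, 0.5, 200.0])
min_abs, max_abs = 0.2, 100.0
gain_factor, max_factor = 0.02, 0.02

below_threshold = ((min_abs - x_abs_mean) *
                   (gain_factor / min_abs)).clamp(min=0, max=max_factor)
above_threshold = ((x_abs_mean - max_abs) *
                   (gain_factor / max_abs)).clamp(min=0, max=max_factor)
print(below_threshold - above_threshold)
# tensor([ 0.0150,  0.0000, -0.0200]):
# channel 0 is below min_abs (positive factor -> encouraged to grow),
# channel 1 satisfies both constraints (factor 0),
# channel 2 exceeds max_abs (negative factor, clamped at max_factor ->
# encouraged to shrink).
```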