diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index e74acb7fe..f80e42edb 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -34,6 +34,7 @@ class ActivationBalancerFunction(torch.autograd.Function):
     def forward(
         ctx,
         x: Tensor,
+        mean: Tensor,
         sign_factor: Tensor,
         scale_factor: Tensor,
         channel_dim: int,
@@ -41,8 +42,13 @@ class ActivationBalancerFunction(torch.autograd.Function):
         if channel_dim < 0:
             channel_dim += x.ndim
         ctx.channel_dim = channel_dim
-        xgt0 = (x > 0)
-        ctx.save_for_backward(xgt0, sign_factor, scale_factor)
+        for _ in range(ctx.channel_dim, x.ndim - 1):
+            mean = mean.unsqueeze(-1)
+            sign_factor = sign_factor.unsqueeze(-1)
+            scale_factor = scale_factor.unsqueeze(-1)
+
+        xgtmean = (x > mean)
+        ctx.save_for_backward(xgtmean, sign_factor, scale_factor)
         return x
 
@@ -50,14 +56,11 @@ class ActivationBalancerFunction(torch.autograd.Function):
     def backward(
         ctx, x_grad: Tensor
    ) -> Tuple[Tensor, None, None, None]:
-        xgt0, sign_factor, scale_factor = ctx.saved_tensors
-        for _ in range(ctx.channel_dim, x_grad.ndim - 1):
-            sign_factor = sign_factor.unsqueeze(-1)
-            scale_factor = scale_factor.unsqueeze(-1)
+        xgtmean, sign_factor, scale_factor = ctx.saved_tensors
 
-        factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
+        factor = sign_factor + scale_factor * (xgtmean.to(x_grad.dtype) - 0.5)
         neg_delta_grad = x_grad.abs() * factor
-        return x_grad - neg_delta_grad, None, None, None,
+        return x_grad - neg_delta_grad, None, None, None, None,
 
@@ -275,6 +278,9 @@ class ActivationBalancer(torch.nn.Module):
         # count measures how many times the forward() function has been called.
         self.count = 0
 
+        # the mean of the data per channel
+        self.register_buffer('mean', torch.zeros(num_channels))
+
         # the mean of the absolute value of the data per channel
         self.register_buffer('abs_mean', torch.zeros(num_channels))
 
@@ -307,7 +313,7 @@ class ActivationBalancer(torch.nn.Module):
             sign_factor = factors[0]
             scale_factor = factors[1]
             return ActivationBalancerFunction.apply(
-                x, sign_factor, scale_factor, self.channel_dim,
+                x, self.mean, sign_factor, scale_factor, self.channel_dim,
             )
         else:
             return x
@@ -322,6 +328,7 @@ class ActivationBalancer(torch.nn.Module):
 
         with torch.no_grad():
             sum_dims = [d for d in range(x.ndim) if d != self.channel_dim]
+            x_mean = torch.mean(x, dim=sum_dims).to(torch.float32)
             x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
             # the random.random() thing is to split the difference if x is zero,
             # between treating it positive or negative
@@ -333,9 +340,11 @@ class ActivationBalancer(torch.nn.Module):
                 mask = (y - y != 0)
                 y.masked_fill_(mask, 0.0)
 
+            filter_inf_nan(x_mean)
             filter_inf_nan(x_abs_mean)
 
             beta = self.beta if count > 0 else 0.0
+            self.mean.mul_(beta).add_(x_mean, alpha=(1-beta))
             self.abs_mean.mul_(beta).add_(x_abs_mean, alpha=(1-beta))
             self.proportion_positive.mul_(beta).add_(proportion_positive, alpha=(1-beta))
 
@@ -363,25 +372,11 @@ class ActivationBalancer(torch.nn.Module):
 
         # the factor of 2.0 below is just to cancel out a factor of 0.5 that gets introduced when, in
         # the backprop, we do (xgt0.to(dtype) - 0.5).
-        #
-        # scale_factor_scale, on the other hand, is a heuristically chosen value between 0 and 1,
-        # that we use to make the gradient changes from the 'scale' constraints (min_abs/max_abs)
-        # less strong than those from the sign constraints.
-        #
-        # This is to get rid of a pathology that can happen if, for instance, a
-        # channel is always positive but is too small (max_positive and min_abs constraints both
-        # violated). If scale_factor_scale were equal to 1.0, then the gradient changes from the
-        # min_positive constraint (trying to make the activation more negative) and from the
-        # min_abs constraint (trying to make the activation more positive) would exactly cancel.
-        # Instead we make the min_positive constraint stronger, so it first makes the value
-        # sometimes negative, and only when that is satisfied, can deal with the absolute-value
-        # constraint.
-        scale_factor_scale = 0.5
         below_threshold = (self.abs_mean < self.min_abs)
         above_threshold = (self.abs_mean > self.max_abs)
         scale_factor[:] = ((below_threshold.to(torch.float32) - above_threshold.to(torch.float32))
-                           * (max_factor * (2.0 * scale_factor_scale)))
+                           * (max_factor * 2.0))
 
 
 class MaxEig(torch.nn.Module):
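
Note: as a minimal standalone sketch (assuming PyTorch; the name `balanced_grad_sketch` and the example shapes below are illustrative, not part of this patch), the gradient adjustment that the patched backward() now applies compares each element against the running per-channel mean rather than against zero:

    import torch

    def balanced_grad_sketch(x, x_grad, mean, sign_factor, scale_factor, channel_dim):
        # Broadcast the per-channel statistics over the trailing dims,
        # mirroring the unsqueeze loop the patched forward() performs.
        if channel_dim < 0:
            channel_dim += x.ndim
        for _ in range(channel_dim, x.ndim - 1):
            mean = mean.unsqueeze(-1)
            sign_factor = sign_factor.unsqueeze(-1)
            scale_factor = scale_factor.unsqueeze(-1)
        # Compare against the running mean instead of 0 (the old xgt0 behaviour).
        xgtmean = (x > mean)
        factor = sign_factor + scale_factor * (xgtmean.to(x_grad.dtype) - 0.5)
        # Nudge the gradient by a term proportional to its magnitude.
        return x_grad - x_grad.abs() * factor

    # Example shapes: x of (batch, channels, time) with channel_dim=1 and
    # mean / sign_factor / scale_factor of shape (channels,).
    x = torch.randn(2, 8, 5)
    g = balanced_grad_sketch(x, torch.ones_like(x), x.mean(dim=(0, 2)),
                             torch.zeros(8), 0.1 * torch.ones(8), channel_dim=1)

Compared with the pre-patch behaviour, only the comparison point changes (running per-channel mean instead of 0); the sign-factor and scale-factor machinery is unchanged.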