diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 1a230b24b..36aa5f660 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -105,17 +105,18 @@ class ActivationBalancerFunction(torch.autograd.Function):
         ctx,
         x: Tensor,
         scale_factor: Tensor,
+        mean: Tensor,
         sign_factor: Optional[Tensor],
         channel_dim: int,
     ) -> Tensor:
         if channel_dim < 0:
             channel_dim += x.ndim
         ctx.channel_dim = channel_dim
-        xgt0 = (x > 0)
+        xgtmean = (x > mean)
         if sign_factor is None:
-            ctx.save_for_backward(xgt0, scale_factor)
+            ctx.save_for_backward(xgtmean, scale_factor)
         else:
-            ctx.save_for_backward(xgt0, scale_factor, sign_factor)
+            ctx.save_for_backward(xgtmean, scale_factor, sign_factor)
         return x

@@ -124,29 +125,48 @@ class ActivationBalancerFunction(torch.autograd.Function):
         ctx, x_grad: Tensor
     ) -> Tuple[Tensor, None, None, None]:
         if len(ctx.saved_tensors) == 3:
-            xgt0, scale_factor, sign_factor = ctx.saved_tensors
+            xgtmean, scale_factor, sign_factor = ctx.saved_tensors
             for _ in range(ctx.channel_dim, x_grad.ndim - 1):
                 scale_factor = scale_factor.unsqueeze(-1)
                 sign_factor = sign_factor.unsqueeze(-1)
-            factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
+            factor = sign_factor + scale_factor * (xgtmean.to(x_grad.dtype) - 0.5)
         else:
-            xgt0, scale_factor = ctx.saved_tensors
+            xgtmean, scale_factor = ctx.saved_tensors
             for _ in range(ctx.channel_dim, x_grad.ndim - 1):
                 scale_factor = scale_factor.unsqueeze(-1)
-            factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
+            factor = scale_factor * (xgtmean.to(x_grad.dtype) - 0.5)
         neg_delta_grad = x_grad.abs() * factor
-        return x_grad - neg_delta_grad, None, None, None,
+        return x_grad - neg_delta_grad, None, None, None, None


 def _compute_scale_factor(x: Tensor,
                           channel_dim: int,
                           min_abs: float,
                           max_abs: float,
                           gain_factor: float,
-                          max_factor: float) -> Tensor:
+                          max_factor: float) -> Tuple[Tensor, Tensor]:
+    """
+    Computes a factor, used in ActivationBalancer, that dictates how much we penalize
+    (or anti-penalize) the scale of the features.
+
+    Returns: (scale_factor, mean)
+
+    scale_factor: can be positive or negative, between -max_factor and max_factor; dictates
+        penalty or anti-penalty.  It is of shape (num_channels,).
+    mean: the per-channel mean that we use for purposes of scale_factor; it is clamped to
+        -min_abs..min_abs.  Its shape is like (1, num_channels, 1, 1), depending on the
+        shape of x and on channel_dim.
+
+
+    """
     if channel_dim < 0:
         channel_dim += x.ndim
     sum_dims = [d for d in range(x.ndim) if d != channel_dim]
-    x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
+
+    x_mean = torch.mean(x, dim=sum_dims, keepdim=True).to(torch.float32)
+    # The idea is that, for purposes of applying max_abs, we effectively regress
+    # toward zero (assuming min_abs is much less than max_abs).
+    x_mean = x_mean.clamp(min=-min_abs, max=min_abs)
+    x_abs_mean = torch.mean((x - x_mean).abs(), dim=sum_dims).to(torch.float32)

     if min_abs == 0.0:
         below_threshold = 0.0
@@ -157,7 +177,7 @@ def _compute_scale_factor(x: Tensor,
     above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(min=0, max=max_factor)

-    return below_threshold - above_threshold
+    return below_threshold - above_threshold, x_mean


 def _compute_sign_factor(x: Tensor,
                          channel_dim: int,
@@ -679,13 +699,13 @@ class ActivationBalancer(torch.nn.Module):
                 sign_factor = None

-            scale_factor = _compute_scale_factor(x, self.channel_dim,
-                                                 min_abs=float(self.min_abs),
-                                                 max_abs=float(self.max_abs),
-                                                 gain_factor=float(self.scale_gain_factor) / prob,
-                                                 max_factor=float(self.max_factor))
+            scale_factor, mean = _compute_scale_factor(x, self.channel_dim,
+                                                       min_abs=float(self.min_abs),
+                                                       max_abs=float(self.max_abs),
+                                                       gain_factor=float(self.scale_gain_factor) / prob,
+                                                       max_factor=float(self.max_factor))
             return ActivationBalancerFunction.apply(
-                x, scale_factor, sign_factor, self.channel_dim,
+                x, scale_factor, mean, sign_factor, self.channel_dim,
             )
         else:
             return _no_op(x)
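
For illustration (not part of the patch), here is a minimal standalone sketch of what the
changed _compute_scale_factor does: the per-channel mean is clamped to [-min_abs, min_abs]
before the mean absolute deviation is measured, so for purposes of the max_abs penalty we
effectively regress toward zero. The helper name toy_scale_factor and the constants below
are illustrative assumptions, not values taken from scaling.py.

import torch

def toy_scale_factor(x, channel_dim=1, min_abs=0.2, max_abs=100.0,
                     gain_factor=0.02, max_factor=0.04):
    # Mirrors the patched logic: per-channel mean, clamped to [-min_abs, min_abs],
    # then mean absolute deviation of x from that clamped mean.
    sum_dims = [d for d in range(x.ndim) if d != channel_dim]
    x_mean = torch.mean(x, dim=sum_dims, keepdim=True).to(torch.float32)
    x_mean = x_mean.clamp(min=-min_abs, max=min_abs)
    x_abs_mean = torch.mean((x - x_mean).abs(), dim=sum_dims).to(torch.float32)
    below = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(min=0, max=max_factor)
    above = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(min=0, max=max_factor)
    return below - above, x_mean

x = torch.randn(8, 4, 16) * 0.05           # activations well below min_abs
scale_factor, mean = toy_scale_factor(x)
print(scale_factor.shape, mean.shape)      # torch.Size([4]) torch.Size([1, 4, 1])
print((scale_factor > 0).all().item())     # True: the "below min_abs" branch dominates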