From a55f8c9c1465c7d681699e83ae577a4e43986e6c Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Fri, 29 Jul 2022 09:34:13 +0800
Subject: [PATCH] Modify scaling.py to prevent constant values

---
 .../pruned_transducer_stateless7/scaling.py  | 35 ++++++++++++-------
 1 file changed, 23 insertions(+), 12 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index ed22a6315..8ae390f45 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -58,9 +58,10 @@ class ActivationBalancerFunction(torch.autograd.Function):
         if channel_dim < 0:
             channel_dim += x.ndim
         sum_dims = [d for d in range(x.ndim) if d != channel_dim]
-        xgt0 = x > 0
+        x_normalized = x - torch.mean(x, dim=sum_dims, keepdim=True)
+        xgtmean = (x_normalized > 0)
         proportion_positive = torch.mean(
-            xgt0.to(x.dtype), dim=sum_dims, keepdim=True
+            (x > 0).to(x.dtype), dim=sum_dims, keepdim=True
         )
         factor1 = (
             (min_positive - proportion_positive).relu()
@@ -74,16 +75,24 @@ class ActivationBalancerFunction(torch.autograd.Function):
             if max_positive != 1.0
             else 0.0
         )
+        # `factor` is a tensor of shape something like (1, 1, num_channels,
+        # 1), containing elements between -1 and 1 that are zero if the
+        # proportion of positive features is between min_positive and
+        # max_positive, max_factor if proportion==0.0 (all features are negative),
+        # and -max_factor if proportion==1.0 (all features are positive).  It is
+        # an amount per channel by which we'll modify the gradient; the sign
+        # of the modification will depend on the sign of the gradient.
+
         factor = factor1 + factor2
         if isinstance(factor, float):
             factor = torch.zeros_like(proportion_positive)
 
-        mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True)
+        mean_abs = torch.mean(x_normalized.abs(), dim=sum_dims, keepdim=True)
         below_threshold = mean_abs < min_abs
         above_threshold = mean_abs > max_abs
 
         ctx.save_for_backward(
-            factor, xgt0, below_threshold, above_threshold
+            factor, xgtmean, below_threshold, above_threshold
         )
         ctx.max_factor = max_factor
         ctx.sum_dims = sum_dims
@@ -93,11 +102,11 @@ class ActivationBalancerFunction(torch.autograd.Function):
     def backward(
         ctx, x_grad: Tensor
     ) -> Tuple[Tensor, None, None, None, None, None, None]:
-        factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors
+        factor, xgtmean, below_threshold, above_threshold = ctx.saved_tensors
         dtype = x_grad.dtype
         scale_factor = (
             (below_threshold.to(dtype) - above_threshold.to(dtype))
-            * (xgt0.to(dtype) - 0.5)
+            * (xgtmean.to(dtype) - 0.5)
             * (ctx.max_factor * 2.0)
         )
 
@@ -105,6 +114,8 @@ class ActivationBalancerFunction(torch.autograd.Function):
         return x_grad - neg_delta_grad, None, None, None, None, None, None
 
 
+
+
 class BasicNorm(torch.nn.Module):
     """
     This is intended to be a simpler, and hopefully cheaper, replacement for
@@ -404,12 +415,12 @@ class ActivationBalancer(torch.nn.Module):
          either the sign constraint or the magnitude constraint;
          e.g. with max_factor=0.02, the the derivatives would be multiplied by
          values in the range [0.98..1.02].
-       min_abs: the minimum average-absolute-value per channel, which
-          we allow, before we start to modify the derivatives to prevent
-          this.
-       max_abs: the maximum average-absolute-value per channel, which
-          we allow, before we start to modify the derivatives to prevent
-          this.
+       min_abs: the minimum average-absolute-value difference from the mean
+          value per channel, which we allow, before we start to modify
+          the derivatives to prevent this.
+       max_abs: the maximum average-absolute-value difference from the mean
+          value per channel, which we allow, before we start to modify
+          the derivatives to prevent this.
     """
 
     def __init__(
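For readers skimming the diff, here is a minimal standalone sketch (not part of the patch) of why the mean-subtraction matters; the tensor shapes, example values, and variable names below are invented for illustration, and only the torch.mean/abs usage mirrors the patched forward(). A channel stuck at a constant non-zero value has a large plain mean absolute value, so the old min_abs test never fires; after subtracting the per-channel mean its deviation is ~0, so it falls below min_abs and the backward-pass gradient modification kicks in, which is what the commit title means by preventing constant values.

    import torch

    # Toy activations: (batch, time, channels); channel 0 is stuck at a constant.
    x = torch.randn(4, 10, 8)
    x[..., 0] = 5.0

    sum_dims = [0, 1]  # every dim except the channel dim (-1)

    # Old criterion: plain mean absolute value per channel.
    old_mean_abs = torch.mean(x.abs(), dim=sum_dims)

    # New criterion (as in the patch): mean absolute deviation from the channel mean.
    x_normalized = x - torch.mean(x, dim=sum_dims, keepdim=True)
    new_mean_abs = torch.mean(x_normalized.abs(), dim=sum_dims)

    print(old_mean_abs[0])  # ~5.0: looks fine to the old min_abs test
    print(new_mean_abs[0])  # ~0.0: below min_abs, so the gradient gets nudged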