diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py
index 6bf0aefe4..ea0204138 100644
--- a/egs/librispeech/ASR/conformer_ctc/subsampling.py
+++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py
@@ -325,9 +325,11 @@ class DerivBalancerFunction(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, x: Tensor, channel_dim: int,
-                threshold: float = 0.05,
-                max_factor: float = 0.05,
-                min_abs: float = 0.5) -> Tensor:
+                threshold: float,  # e.g. 0.05
+                max_factor: float,  # e.g. 0.01
+                min_abs: float,  # e.g. 0.2
+                max_abs: float,  # e.g. 1000.0
+                ) -> Tensor:
         if x.requires_grad:
             if channel_dim < 0:
                 channel_dim += x.ndim
@@ -336,23 +338,26 @@ class DerivBalancerFunction(torch.autograd.Function):
             proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True)
             factor = (threshold - proportion_positive).relu() * (max_factor / threshold)
 
-            below_threshold = (torch.mean(x.abs(), dim=sum_dims, keepdim=True) < min_abs)
+            mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True)
+            below_threshold = (mean_abs < min_abs)
+            above_threshold = (mean_abs > max_abs)
 
-            ctx.save_for_backward(factor, xgt0, below_threshold)
+            ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold)
             ctx.max_factor = max_factor
             ctx.sum_dims = sum_dims
         return x
 
     @staticmethod
-    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None]:
-        factor, xgt0, below_threshold = ctx.saved_tensors
+    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]:
+        factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors
         dtype = x_grad.dtype
-        too_small_factor = below_threshold.to(dtype) * (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0)
+        scale_factor = ((below_threshold.to(dtype) - above_threshold.to(dtype)) *
+                        (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0))
 
-        neg_delta_grad = x_grad.abs() * (factor + too_small_factor)
+        neg_delta_grad = x_grad.abs() * (factor + scale_factor)
 
-        return x_grad - neg_delta_grad, None, None, None, None
+        return x_grad - neg_delta_grad, None, None, None, None, None
 
 
 class BasicNorm(torch.nn.Module):
@@ -521,20 +526,33 @@ class DerivBalancer(torch.nn.Module):
           than zero is that it will tend to prevent situations where the
           inputs shrink close to zero and the nonlinearity (e.g. swish)
           behaves like a linear function and we learn nothing.
+       min_abs: the minimum average-absolute-value per channel that we
+          allow before we start to modify the derivatives to counteract
+          it.  This is to prevent a failure mode where the activations
+          become so small that the nonlinearity effectively becomes
+          linear, which makes the module useless; the activations then
+          shrink even further as the model tries to "turn it off"
+          completely.
+       max_abs: the maximum average-absolute-value per channel that we
+          allow before we start to modify the derivatives to counteract
+          it.  This is to prevent activations from going out of floating
+          point numerical range (especially in half precision).
     """
     def __init__(self, channel_dim: int,
                  threshold: float = 0.05,
                  max_factor: float = 0.01,
-                 min_abs: float = 0.2):
+                 min_abs: float = 0.2,
+                 max_abs: float = 1000.0):
         super(DerivBalancer, self).__init__()
         self.channel_dim = channel_dim
         self.threshold = threshold
         self.max_factor = max_factor
         self.min_abs = min_abs
+        self.max_abs = max_abs
 
     def forward(self, x: Tensor) -> Tensor:
         return DerivBalancerFunction.apply(x, self.channel_dim, self.threshold,
-                                           self.max_factor, self.min_abs)
+                                           self.max_factor, self.min_abs,
+                                           self.max_abs)
 
 
 class DoubleSwish(torch.nn.Module):
diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py
index c2202fe1e..1434d6da4 100755
--- a/egs/librispeech/ASR/transducer_stateless/train.py
+++ b/egs/librispeech/ASR/transducer_stateless/train.py
@@ -110,7 +110,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="transducer_stateless/randcombine1_expscale3_rework2c",
+        default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved