diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 67671c4a5..1fc46259b 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -1217,8 +1217,6 @@ class TanSwish(torch.nn.Module):
 class SwooshLFunction(torch.autograd.Function):
     """
       swoosh(x) =  log(1 + exp(x-4)) - 0.08*x - 0.035
-
-    derivatives are between -0.08 and 0.92.
     """
 
     @staticmethod
@@ -1231,19 +1229,21 @@ class SwooshLFunction(torch.autograd.Function):
         zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
 
+        coeff = -0.08
+
         with torch.cuda.amp.autocast(enabled=False):
             with torch.enable_grad():
                 x = x.detach()
                 x.requires_grad = True
-                y = torch.logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
+                y = torch.logaddexp(zero, x - 4.0) + coeff * x - 0.035
 
                 if not requires_grad:
                     return y
                 y.backward(gradient = torch.ones_like(y))
 
                 grad = x.grad
-                floor = -0.1
-                ceil = 0.905 # real ceil would be 0.09, give it extra room for roundoff.
+                floor = coeff
+                ceil = 1.0 + coeff + 0.005
 
                 d_scaled = ((grad - floor) * (255.0 / (ceil - floor))
                             + torch.rand_like(grad))
                 if __name__ == "__main__":
@@ -1261,8 +1261,10 @@ class SwooshLFunction(torch.autograd.Function):
     def backward(ctx, y_grad: Tensor) -> Tensor:
         d, = ctx.saved_tensors
         # the same constants as used in forward pass.
-        floor = -0.1
-        ceil = 0.905
+
+        coeff = -0.08
+        floor = coeff
+        ceil = 1.0 + coeff + 0.005
         d = (d * ((ceil - floor) / 255.0) + floor)
         return (y_grad * d)
 
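
A standalone sanity check for the new constants (an illustrative sketch using the patch's own formulas, not part of the commit; torch is the only dependency). The derivative of SwooshL is sigmoid(x - 4) + coeff, which lies in (coeff, 1 + coeff) = (-0.08, 0.92), so floor = coeff and ceil = 1.0 + coeff + 0.005 bracket it with 0.005 of headroom, and the uint8 round-trip used by forward()/backward() recovers the gradient to within one quantization step:

# Illustrative sanity check (not part of the patch): verify the derivative
# bounds that floor/ceil encode, and the uint8 gradient round-trip.
import torch

coeff = -0.08
floor = coeff               # true lower bound of d(swoosh)/dx
ceil = 1.0 + coeff + 0.005  # true upper bound is 1 + coeff = 0.92; 0.005 is headroom

# d/dx [logaddexp(0, x - 4) + coeff * x - 0.035] = sigmoid(x - 4) + coeff,
# which lies strictly inside (coeff, 1 + coeff).
x = torch.linspace(-20.0, 20.0, 10001, requires_grad=True)
zero = torch.zeros_like(x)
y = torch.logaddexp(zero, x - 4.0) + coeff * x - 0.035
y.backward(gradient=torch.ones_like(y))
grad = x.grad
assert grad.min() >= floor and grad.max() < ceil

# Round-trip through the uint8 encoding used in forward()/backward().
# The uniform noise makes the truncation in .to(torch.uint8) an unbiased
# stochastic rounding; reconstruction error stays within one quantization
# step, (ceil - floor) / 255 ~= 0.0039 (small float slack added below).
d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(grad)
assert d_scaled.min() >= 0.0 and d_scaled.max() < 256.0
d_int = d_scaled.to(torch.uint8)
d = d_int * ((ceil - floor) / 255.0) + floor
assert (d - grad).abs().max() < 1.01 * (ceil - floor) / 255.0

With the old hard-coded constants (floor = -0.1, ceil = 0.905) the d_scaled check above fails: the true derivative ceiling is 0.92, so gradients near the maximum map to roughly (0.92 + 0.1) * 255 / 1.005 + 1 ≈ 260, past the uint8 range, and would wrap around in the cast. That is presumably why floor and ceil are now derived from coeff rather than written out by hand.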