From 049174722f1f065e2941b6fff192443109473e39 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Fri, 23 Dec 2022 13:16:51 +0800
Subject: [PATCH] Change BasicNorm by adding 1+eps denominator; fix to (unused)
 DoubleSwish, revert to old status.

---
 .../pruned_transducer_stateless7/scaling.py   | 33 +++++--------------
 1 file changed, 9 insertions(+), 24 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index c5b748480..679fe552c 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -491,8 +491,10 @@ class BasicNorm(torch.nn.Module):
         # gradients to allow the parameter to get back into the allowed
         # region if it happens to exit it.
         eps = eps.clamp(min=self.eps_min, max=self.eps_max)
+        eps = eps.exp()
         scales = (
-            torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp()
+            (torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps) /
+            (1.0 + eps)
         ) ** -0.5
         return x * scales
@@ -1330,9 +1332,8 @@ class MaxEig(torch.nn.Module):
 
 class DoubleSwishFunction(torch.autograd.Function):
     """
-      double_swish(x) = x * (torch.sigmoid(x-1) + alpha)
+      double_swish(x) = x * torch.sigmoid(x-1)
 
-    for e.g. alpha=-0.05 (user supplied).
 
     This is a definition, originally motivated by its close numerical
     similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).
@@ -1356,18 +1357,8 @@ class DoubleSwishFunction(torch.autograd.Function):
         s = torch.sigmoid(x - 1.0)
         y = x * s
 
-        alpha = -0.05
-        beta = 0.05
-        x_limit = 0.15
-
-        # another part of this formula is:
-        #  ... + 0.2 * x.clamp(min=-0.15, max=0.15)
-        # the deriv of this is
-        #   beta * (x.abs() < x_limit).
-
         if requires_grad:
-            deriv = (y * (1 - s) + s)  # ignores the alpha part.
-            deriv = deriv + (x.abs() < x_limit) * beta
+            deriv = (y * (1 - s) + s)
 
             # notes on derivative of x * sigmoid(x - 1):
             # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
@@ -1376,7 +1367,7 @@ class DoubleSwishFunction(torch.autograd.Function):
             # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which
             # floors), should be expectation-preserving.
             floor = -0.044
-            ceil = 1.2 + beta
+            ceil = 1.2
             d_scaled = ((deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv))
             if __name__ == "__main__":
                 # for self-testing only.
                 assert d_scaled.max() < 256.0
             d_int = d_scaled.to(torch.uint8)
             ctx.save_for_backward(d_int)
-# on wolframalpha, do: (x * sigmoid(x-1) - 0.05 * x + 0.05 * min(0.15, max(-0.15, x)) + 0.025) from x=-3 to 2
-        y = y + alpha * x + beta * x.clamp(min=-x_limit, max=x_limit) - 0.025
         if x.dtype == torch.float16 or torch.is_autocast_enabled():
             y = y.to(torch.float16)
         return y
@@ -1394,13 +1383,11 @@
     def backward(ctx, y_grad: Tensor) -> Tensor:
         d, = ctx.saved_tensors
         # the same constants as used in forward pass.
-        alpha = -0.05
-        beta = 0.05
         floor = -0.043637
-        ceil = 1.2 + beta
+        ceil = 1.2
         d = (d * ((ceil - floor) / 255.0) + floor)
-        return (y_grad * (d + alpha))
+        return y_grad * d
 
 
 class DoubleSwish(torch.nn.Module):
     def __init__(self):
@@ -1412,7 +1399,7 @@ class DoubleSwish(torch.nn.Module):
         that we approximate closely with x * sigmoid(x-1).
         """
         if torch.jit.is_scripting():
-            return x * (torch.sigmoid(x - 1.0) - 0.05) + 0.05 * x.clamp(min=-0.15, max=0.15)
+            return x * torch.sigmoid(x - 1.0)
         return DoubleSwishFunction.apply(x)
@@ -1741,8 +1728,6 @@ def _test_basic_norm():
     y_rms = (y ** 2).mean().sqrt()
     print("x rms = ", x_rms)
     print("y rms = ", y_rms)
-    assert y_rms < x_rms
-    assert y_rms > 0.5 * x_rms
 
 
 def _test_double_swish_deriv():