Merge branch 'scaled_adam_exp758' into scaled_adam_exp759

# Conflicts:
#	egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
This commit is contained in:
Daniel Povey 2022-12-22 17:37:22 +08:00
commit 56fcb14e18

View File

@@ -432,8 +432,8 @@ class MaxEigLimiterFunction(torch.autograd.Function):
class BasicNormFunction(torch.autograd.Function): class BasicNormFunction(torch.autograd.Function):
# This computes: # This computes:
# scales = torch.mean((x + bias) ** 2, keepdim=True) + eps.exp() # scales = torch.mean((x - bias) ** 2, keepdim=True) + eps.exp()
# return x * scales # return (x - bias) * scales
# (after unsqueezing the bias), but it does it in a memory-efficient way so that # (after unsqueezing the bias), but it does it in a memory-efficient way so that
# it can just store the returned value (chances are, this will also be needed for # it can just store the returned value (chances are, this will also be needed for
# some other reason, related to the next operation, so we can save memory). # some other reason, related to the next operation, so we can save memory).
@@ -448,9 +448,9 @@ class BasicNormFunction(torch.autograd.Function):
ctx.channel_dim = channel_dim ctx.channel_dim = channel_dim
for _ in range(channel_dim + 1, x.ndim): for _ in range(channel_dim + 1, x.ndim):
bias = bias.unsqueeze(-1) bias = bias.unsqueeze(-1)
scales = (torch.mean((x + bias) ** 2, dim=channel_dim, keepdim=True) + eps.exp()) ** -0.5 scales = (torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) + eps.exp()) ** -0.5
ans = x * scales - bias ans = x * scales
ctx.save_for_backward(ans.detach() if store_output_for_backprop else x.detach(), ctx.save_for_backward(ans.detach() if store_output_for_backprop else x,
scales.detach(), bias.detach(), eps.detach()) scales.detach(), bias.detach(), eps.detach())
return ans return ans
@@ -468,8 +468,8 @@ class BasicNormFunction(torch.autograd.Function):
eps.requires_grad = True eps.requires_grad = True
with torch.enable_grad(): with torch.enable_grad():
# recompute scales from x, bias and eps. # recompute scales from x, bias and eps.
scales = (torch.mean((x + bias) ** 2, dim=ctx.channel_dim, keepdim=True) + eps.exp()) ** -0.5 scales = (torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) + eps.exp()) ** -0.5
ans = x * scales - bias ans = x * scales
ans.backward(gradient=ans_grad) ans.backward(gradient=ans_grad)
return x.grad, bias.grad.flatten(), eps.grad, None, None return x.grad, bias.grad.flatten(), eps.grad, None, None