mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Merge branch 'scaled_adam_exp800' into scaled_adam_exp807
This commit is contained in:
commit
a0c35adca0
@ -434,46 +434,43 @@ class MaxEigLimiterFunction(torch.autograd.Function):
|
||||
|
||||
class BasicNormFunction(torch.autograd.Function):
|
||||
# This computes:
|
||||
# scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp()
|
||||
# scales = ((torch.mean(x**2, keepdim=True) + eps) ** -0.5 * scale)
|
||||
# return (x - bias) * scales
|
||||
# (after unsqueezing the bias), but it does it in a memory-efficient way so that
|
||||
# it can just store the returned value (chances are, this will also be needed for
|
||||
# some other reason, related to the next operation, so we can save memory).
|
||||
@staticmethod
|
||||
@custom_fwd
|
||||
def forward(ctx, x: Tensor, bias: Tensor, log_scale: Tensor, channel_dim: int,
|
||||
def forward(ctx, x: Tensor, eps: Tensor, scale: Tensor, channel_dim: int,
|
||||
store_output_for_backprop: bool) -> Tensor:
|
||||
assert bias.ndim == 1
|
||||
if channel_dim < 0:
|
||||
channel_dim = channel_dim + x.ndim
|
||||
ctx.store_output_for_backprop = store_output_for_backprop
|
||||
ctx.channel_dim = channel_dim
|
||||
for _ in range(channel_dim + 1, x.ndim):
|
||||
bias = bias.unsqueeze(-1)
|
||||
scales = (torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) * log_scale.exp()
|
||||
scales = ((torch.mean(x ** 2, dim=channel_dim, keepdim=True) + eps) ** -0.5) * scale
|
||||
ans = x * scales
|
||||
ctx.save_for_backward(ans.detach() if store_output_for_backprop else x,
|
||||
scales.detach(), bias.detach(), log_scale.detach())
|
||||
scales.detach(), eps.detach(), scale.detach())
|
||||
return ans
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, ans_grad: Tensor) -> Tensor:
|
||||
ans_or_x, scales, bias, log_scale = ctx.saved_tensors
|
||||
ans_or_x, scales, eps, scale = ctx.saved_tensors
|
||||
if ctx.store_output_for_backprop:
|
||||
x = ans_or_x / scales
|
||||
else:
|
||||
x = ans_or_x
|
||||
x = x.detach()
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
assert eps.dtype != torch.float16 and scale.dtype != torch.float16
|
||||
x = x.to(torch.float32).detach()
|
||||
x.requires_grad = True
|
||||
bias.requires_grad = True
|
||||
log_scale.requires_grad = True
|
||||
eps.requires_grad = True
|
||||
scale.requires_grad = True
|
||||
with torch.enable_grad():
|
||||
# recompute scales from x, bias and log_scale.
|
||||
scales = (torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5) * log_scale.exp()
|
||||
# recompute scales from x, epsand log_scale.
|
||||
scales = ((torch.mean(x ** 2, dim=ctx.channel_dim, keepdim=True) + eps) ** -0.5) * scale
|
||||
ans = x * scales
|
||||
ans.backward(gradient=ans_grad)
|
||||
return x.grad, bias.grad.flatten(), log_scale.grad, None, None
|
||||
ans.backward(gradient=ans_grad.to(torch.float32))
|
||||
return x.grad.to(ans_grad.dtype), eps.grad, scale.grad, None, None
|
||||
|
||||
|
||||
|
||||
@ -500,8 +497,11 @@ class BasicNorm(torch.nn.Module):
|
||||
{-2, -1, 0, 1, 2, 3}.
|
||||
log_scale: the initial log-scale that we multiply the output by; this
|
||||
is learnable.
|
||||
eps: the initial epsilon value (not in log space)
|
||||
log_scale_min: FloatLike, minimum allowed value of log_scale
|
||||
log_scale_max: FloatLike, maximum allowed value of log_scale
|
||||
log_eps_min: FloatLike, minimum allowed value of log_eps
|
||||
log_eps_max: FloatLike, maximum allowed value of log_eps
|
||||
store_output_for_backprop: only possibly affects memory use; recommend
|
||||
to set to True if you think the output of this module is more likely
|
||||
than the input of this module to be required to be stored for the
|
||||
@ -513,17 +513,25 @@ class BasicNorm(torch.nn.Module):
|
||||
num_channels: int,
|
||||
channel_dim: int = -1, # CAUTION: see documentation.
|
||||
log_scale: float = 1.0,
|
||||
log_scale_min: float = -1.5,
|
||||
log_scale_max: float = 1.5,
|
||||
eps: float = 0.1,
|
||||
log_scale_min: FloatLike = -1.5,
|
||||
log_scale_max: FloatLike = 1.5,
|
||||
log_eps_min: FloatLike = -3.0,
|
||||
log_eps_max: FloatLike = 3.0,
|
||||
store_output_for_backprop: bool = False
|
||||
) -> None:
|
||||
super(BasicNorm, self).__init__()
|
||||
self.num_channels = num_channels
|
||||
self.channel_dim = channel_dim
|
||||
self.log_scale = nn.Parameter(torch.tensor(log_scale))
|
||||
self.bias = nn.Parameter(torch.zeros(num_channels))
|
||||
self.log_eps = nn.Parameter(torch.tensor(eps).log().detach())
|
||||
|
||||
self.log_scale_min = log_scale_min
|
||||
self.log_scale_max = log_scale_max
|
||||
|
||||
self.log_eps_min = log_eps_min
|
||||
self.log_eps_max = log_eps_max
|
||||
|
||||
self.store_output_for_backprop = store_output_for_backprop
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
@ -537,7 +545,7 @@ class BasicNorm(torch.nn.Module):
|
||||
bias = self.bias
|
||||
for _ in range(channel_dim + 1, x.ndim):
|
||||
bias = bias.unsqueeze(-1)
|
||||
scales = ((torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) *
|
||||
scales = (((torch.mean(x ** 2, dim=channel_dim, keepdim=True) + self.eps.exp()) ** -0.5) *
|
||||
self.log_scale.exp())
|
||||
return x * scales
|
||||
|
||||
@ -545,8 +553,12 @@ class BasicNorm(torch.nn.Module):
|
||||
min=float(self.log_scale_min),
|
||||
max=float(self.log_scale_max),
|
||||
training=self.training)
|
||||
log_eps = limit_param_value(self.log_eps,
|
||||
min=float(self.log_eps_min),
|
||||
max=float(self.log_eps_max),
|
||||
training=self.training)
|
||||
|
||||
return BasicNormFunction.apply(x, self.bias, log_scale,
|
||||
return BasicNormFunction.apply(x, log_eps.exp(), log_scale.exp(),
|
||||
self.channel_dim,
|
||||
self.store_output_for_backprop)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user