Mirror of https://github.com/k2-fsa/icefall.git
Revert BasicNorm to its previous status, without the bias

This commit is contained in:
parent b2125535fb
commit ade7db54e3
@@ -430,51 +430,6 @@ class MaxEigLimiterFunction(torch.autograd.Function):
         return x_grad + x_extra_grad.detach(), None, None, None, None
 
 
-class BasicNormFunction(torch.autograd.Function):
-    # This computes:
-    #   scales = torch.mean((x - bias) ** 2, keepdim=True) + eps.exp()
-    #   return (x - bias) * scales
-    # (after unsqueezing the bias), but it does it in a memory-efficient way so that
-    # it can just store the returned value (chances are, this will also be needed for
-    # some other reason, related to the next operation, so we can save memory).
-    @staticmethod
-    @custom_fwd
-    def forward(ctx, x: Tensor, bias: Tensor, eps: Tensor, channel_dim: int,
-                store_output_for_backprop: bool) -> Tensor:
-        assert bias.ndim == 1
-        if channel_dim < 0:
-            channel_dim = channel_dim + x.ndim
-        ctx.store_output_for_backprop = store_output_for_backprop
-        ctx.channel_dim = channel_dim
-        for _ in range(channel_dim + 1, x.ndim):
-            bias = bias.unsqueeze(-1)
-        scales = (torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) + eps.exp()) ** -0.5
-        ans = x * scales
-        ctx.save_for_backward(ans.detach() if store_output_for_backprop else x,
-                              scales.detach(), bias.detach(), eps.detach())
-        return ans
-
-    @staticmethod
-    @custom_bwd
-    def backward(ctx, ans_grad: Tensor) -> Tensor:
-        ans_or_x, scales, bias, eps = ctx.saved_tensors
-        if ctx.store_output_for_backprop:
-            x = ans_or_x / scales
-        else:
-            x = ans_or_x
-        x = x.detach()
-        x.requires_grad = True
-        bias.requires_grad = True
-        eps.requires_grad = True
-        with torch.enable_grad():
-            # recompute scales from x, bias and eps.
-            scales = (torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) + eps.exp()) ** -0.5
-            ans = x * scales
-            ans.backward(gradient=ans_grad)
-        return x.grad, bias.grad.flatten(), eps.grad, None, None
-
-
 class BasicNorm(torch.nn.Module):
     """
     This is intended to be a simpler, and hopefully cheaper, replacement for
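The comment on the removed BasicNormFunction describes its memory-saving trick: because the output is just the input multiplied by a per-frame scale, saving the (detached) output and the scales is enough to recover the input during backprop. A minimal check of that identity, using made-up shapes and the bias-free formula that the revert keeps (not repository code):

    import torch

    # ans = x * scales, so x can be recovered as ans / scales; this is why the
    # removed BasicNormFunction could store only the output for backprop.
    x = torch.randn(4, 8)                 # made-up (frames, channels) example
    eps_log = torch.tensor(0.25).log()    # eps is kept in log space, as in the module
    scales = (torch.mean(x ** 2, dim=-1, keepdim=True) + eps_log.exp()) ** -0.5
    ans = x * scales
    assert torch.allclose(x, ans / scales, atol=1e-6)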
@@ -501,13 +456,10 @@ class BasicNorm(torch.nn.Module):
           to indicate the connection with conventional LayerNorm.
        learn_eps: if true, we learn epsilon; if false, we keep it
           at the initial value.
-       store_output_for_backprop: this option makes no difference
-           to the output, but may affect memory usage; determines
-           whether, for backprop purposes, we store the input or the output
-           of this module.
        eps_min: float
        eps_max: float
     """
 
     def __init__(
         self,
         num_channels: int,
@@ -516,7 +468,6 @@ class BasicNorm(torch.nn.Module):
         learn_eps: bool = True,
         eps_min: float = -3.0,
         eps_max: float = 3.0,
-        store_output_for_backprop: bool = True
     ) -> None:
         super(BasicNorm, self).__init__()
         self.num_channels = num_channels
@@ -525,24 +476,12 @@ class BasicNorm(torch.nn.Module):
             self.eps = nn.Parameter(torch.tensor(eps).log().detach())
         else:
             self.register_buffer("eps", torch.tensor(eps).log().detach())
-        self.bias = nn.Parameter(torch.zeros(num_channels))
         self.eps_min = eps_min
         self.eps_max = eps_max
-        self.store_output_for_backprop = store_output_for_backprop
 
     def forward(self, x: Tensor) -> Tensor:
         assert x.shape[self.channel_dim] == self.num_channels
         eps = self.eps
-
-        if torch.jit.is_scripting():
-            channel_dim = self.channel_dim
-            if channel_dim < 0:
-                channel_dim = channel_dim + x.ndim
-            for _ in range(channel_dim + 1, x.ndim):
-                bias = bias.unsqueeze(-1)
-            scales = (torch.mean((x + bias) ** 2, dim=channel_dim, keepdim=True) + eps.exp()) ** -0.5
-            return x * scales
-
         if self.training and random.random() < 0.25:
             # with probability 0.25, in training mode, clamp eps between the min
             # and max; this will encourage it to learn parameters within the
@@ -552,9 +491,10 @@ class BasicNorm(torch.nn.Module):
             # gradients to allow the parameter to get back into the allowed
             # region if it happens to exit it.
             eps = eps.clamp(min=self.eps_min, max=self.eps_max)
-
-        return BasicNormFunction.apply(x, self.bias, eps, self.channel_dim,
-                                       self.store_output_for_backprop)
+        scales = (
+            torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp()
+        ) ** -0.5
+        return x * scales
 
 
 
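Taken together, the hunks above leave BasicNorm scaling each frame by (mean(x ** 2) + exp(eps)) ** -0.5, with no bias parameter and no custom autograd function. A minimal, self-contained sketch of the resulting behaviour; the class name BasicNormSketch, the eps default of 0.25, and the tensor shapes are illustrative, not the repository code:

    import random
    import torch
    import torch.nn as nn
    from torch import Tensor

    class BasicNormSketch(nn.Module):
        # Illustrative stand-in for the reverted BasicNorm: a learned log-eps,
        # a per-frame scale of (mean(x ** 2) + exp(eps)) ** -0.5, nothing else.
        def __init__(self, num_channels: int, channel_dim: int = -1,
                     eps: float = 0.25, eps_min: float = -3.0,
                     eps_max: float = 3.0) -> None:
            super().__init__()
            self.num_channels = num_channels
            self.channel_dim = channel_dim
            self.eps = nn.Parameter(torch.tensor(eps).log().detach())
            self.eps_min = eps_min
            self.eps_max = eps_max

        def forward(self, x: Tensor) -> Tensor:
            assert x.shape[self.channel_dim] == self.num_channels
            eps = self.eps
            if self.training and random.random() < 0.25:
                # occasionally clamp log-eps so the learned value stays in range
                eps = eps.clamp(min=self.eps_min, max=self.eps_max)
            scales = (torch.mean(x ** 2, dim=self.channel_dim, keepdim=True)
                      + eps.exp()) ** -0.5
            return x * scales

    norm = BasicNormSketch(512)
    y = norm(torch.randn(2, 100, 512))    # (batch, time, channels), made-up shapes
    print(y.shape)                        # torch.Size([2, 100, 512])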
@@ -451,7 +451,7 @@ class ZipformerEncoderLayer(nn.Module):
 
         self.attention_squeeze = AttentionSqueeze(embed_dim, embed_dim // 2)
 
-        self.norm_final = BasicNorm(embed_dim, store_output_for_backprop=False)
+        self.norm_final = BasicNorm(embed_dim)
 
         self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
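Continuing the sketch above (BasicNormSketch stands in for the repository's BasicNorm), the ZipformerEncoderLayer call site after the revert reduces to passing just the embedding dimension; the removed store_output_for_backprop keyword would now be rejected. embed_dim below is an arbitrary example value:

    # Hypothetical call-site sketch, reusing BasicNormSketch from the previous example.
    embed_dim = 384
    norm_final = BasicNormSketch(embed_dim)  # was: BasicNorm(embed_dim, store_output_for_backprop=False)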