mirror of https://github.com/k2-fsa/icefall.git
Implement bias in BasicNorm
parent 5aa874d8e3
commit b39cde85c8
@@ -430,6 +430,35 @@ class MaxEigLimiterFunction(torch.autograd.Function):
         return x_grad + x_extra_grad.detach(), None, None, None, None
 
 
+class ComputeSquaredMeanWithOffset(torch.autograd.Function):
+    @staticmethod
+    @custom_fwd
+    def forward(ctx, x: Tensor, bias: Tensor, channel_dim: int) -> Tensor:
+        assert bias.ndim == 1
+        if channel_dim < 0:
+            channel_dim = channel_dim + x.ndim
+        ctx.channel_dim = channel_dim
+        for _ in range(channel_dim + 1, x.ndim):
+            bias = bias.unsqueeze(-1)
+        ans = torch.mean((x + bias) ** 2, dim=channel_dim, keepdim=True)
+        ctx.save_for_backward(x, bias)
+        return ans
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, ans_grad: Tensor) -> Tensor:
+        x, bias = ctx.saved_tensors
+        x = x.detach()
+        bias = bias.detach()
+        x.requires_grad = True
+        bias.requires_grad = True
+        with torch.enable_grad():
+            ans = torch.mean((x + bias) ** 2, dim=ctx.channel_dim, keepdim=True)
+            ans.backward(gradient=ans_grad)
+        return x.grad, bias.grad.flatten(), None
+
+
+
 class BasicNorm(torch.nn.Module):
     """
     This is intended to be a simpler, and hopefully cheaper, replacement for
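The new ComputeSquaredMeanWithOffset recomputes its forward expression inside backward() under torch.enable_grad() instead of keeping an autograd graph alive from the forward pass, and the flatten() on bias.grad undoes the unsqueeze(-1) broadcasting applied in forward(). A quick sanity check of this kind of custom autograd.Function, not part of the commit, assuming the class above is in scope and that custom_fwd/custom_bwd are the usual torch.cuda.amp decorators imported elsewhere in the file:

import torch

# Double precision makes gradcheck's numerical differentiation reliable.
x = torch.randn(3, 8, 5, dtype=torch.float64, requires_grad=True)
bias = torch.randn(8, dtype=torch.float64, requires_grad=True)

# Reference: the same computation written directly, broadcasting bias over dim 1.
ref = torch.mean((x + bias.unsqueeze(-1)) ** 2, dim=1, keepdim=True)
out = ComputeSquaredMeanWithOffset.apply(x, bias, 1)
assert torch.allclose(out, ref)

# gradcheck compares the custom backward() against finite differences.
torch.autograd.gradcheck(
    lambda x, b: ComputeSquaredMeanWithOffset.apply(x, b, 1), (x, bias)
)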
@@ -476,6 +505,7 @@ class BasicNorm(torch.nn.Module):
             self.eps = nn.Parameter(torch.tensor(eps).log().detach())
         else:
             self.register_buffer("eps", torch.tensor(eps).log().detach())
+        self.bias = nn.Parameter(torch.zeros(num_channels))
         self.eps_min = eps_min
         self.eps_max = eps_max
 
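Because the new bias parameter is initialized to zeros, a freshly constructed BasicNorm is numerically identical to the pre-commit module: mean((x + 0) ** 2) is just mean(x ** 2), so only training can move the behaviour away from the old one. A small illustrative check, not from the commit:

import torch

x = torch.randn(4, 256)
bias = torch.zeros(256)  # same initialization as nn.Parameter(torch.zeros(num_channels))

old_norms = torch.mean(x ** 2, dim=-1, keepdim=True)
new_norms = torch.mean((x + bias) ** 2, dim=-1, keepdim=True)
assert torch.allclose(old_norms, new_norms)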
@@ -491,8 +521,10 @@ class BasicNorm(torch.nn.Module):
             # gradients to allow the parameter to get back into the allowed
             # region if it happens to exit it.
             eps = eps.clamp(min=self.eps_min, max=self.eps_max)
+
+        norms = ComputeSquaredMeanWithOffset.apply(x, self.bias, self.channel_dim)
         scales = (
-            torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp()
+            norms + eps.exp()
         ) ** -0.5
         return x * scales
 
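With both changes applied, forward() scales x by the inverse root-mean-square of x + bias over the channel dimension rather than of x alone. A condensed, stand-alone sketch of the updated computation, not from the commit (the eps value and shapes are illustrative; the real module keeps eps in log space and routes the mean through ComputeSquaredMeanWithOffset):

import torch

def basic_norm_forward(x, bias, eps_log, channel_dim=-1):
    # norms = mean((x + bias)^2) over the channel dim, as ComputeSquaredMeanWithOffset computes.
    norms = torch.mean((x + bias) ** 2, dim=channel_dim, keepdim=True)
    scales = (norms + eps_log.exp()) ** -0.5
    return x * scales

x = torch.randn(10, 384)
bias = torch.zeros(384)
eps_log = torch.tensor(0.25).log()  # illustrative eps, stored in log space like self.eps
print(basic_norm_forward(x, bias, eps_log).shape)  # torch.Size([10, 384])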
@@ -1071,11 +1103,9 @@ class Whiten(nn.Module):
         self.grad_scale = grad_scale
 
         if isinstance(prob, float):
-            assert 0 < prob <= 1
-            self.prob = prob
-        else:
-            (self.min_prob, self.max_prob) = prob
-            assert 0 < self.min_prob < self.max_prob <= 1
-            self.prob = self.max_prob
+            prob = (prob, prob)
+        (self.min_prob, self.max_prob) = prob
+        assert 0 < self.min_prob <= self.max_prob <= 1
+        self.prob = self.max_prob
         self.name = None  # will be set in training loop
 
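The Whiten change collapses the separate float and tuple branches: a float prob is first promoted to a (prob, prob) pair, and the relaxed assertion then admits min_prob == max_prob. The argument handling in isolation, as a hypothetical helper whose name is not from the commit:

from typing import Tuple, Union

def normalize_prob(prob: Union[float, Tuple[float, float]]) -> Tuple[float, float]:
    if isinstance(prob, float):
        prob = (prob, prob)
    (min_prob, max_prob) = prob
    assert 0 < min_prob <= max_prob <= 1
    return (min_prob, max_prob)

print(normalize_prob(0.5))            # (0.5, 0.5)
print(normalize_prob((0.025, 0.25)))  # (0.025, 0.25)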
@@ -1718,7 +1748,7 @@ def _test_activation_balancer_magnitude():
         max_factor=0.2,
         min_abs=0.2,
         max_abs=0.8,
-        min_prob=1.0,
+        prob=1.0,
     )
 
     y_grad = torch.sign(torch.randn(magnitudes.numel(), N))