Implement bias in BasicNorm
commit b39cde85c8
parent 5aa874d8e3
@@ -430,6 +430,35 @@ class MaxEigLimiterFunction(torch.autograd.Function):
         return x_grad + x_extra_grad.detach(), None, None, None, None
 
 
+class ComputeSquaredMeanWithOffset(torch.autograd.Function):
+    @staticmethod
+    @custom_fwd
+    def forward(ctx, x: Tensor, bias: Tensor, channel_dim: int) -> Tensor:
+        assert bias.ndim == 1
+        if channel_dim < 0:
+            channel_dim = channel_dim + x.ndim
+        ctx.channel_dim = channel_dim
+        for _ in range(channel_dim + 1, x.ndim):
+            bias = bias.unsqueeze(-1)
+        ans = torch.mean((x + bias) ** 2, dim=channel_dim, keepdim=True)
+        ctx.save_for_backward(x, bias)
+        return ans
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, ans_grad: Tensor) -> Tensor:
+        x, bias = ctx.saved_tensors
+        x = x.detach()
+        bias = bias.detach()
+        x.requires_grad = True
+        bias.requires_grad = True
+        with torch.enable_grad():
+            ans = torch.mean((x + bias) ** 2, dim=ctx.channel_dim, keepdim=True)
+            ans.backward(gradient=ans_grad)
+        return x.grad, bias.grad.flatten(), None
+
+
 class BasicNorm(torch.nn.Module):
     """
     This is intended to be a simpler, and hopefully cheaper, replacement for
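For orientation: the new autograd function just computes the mean of (x + bias)**2 over the channel dimension, recomputing the forward pass inside backward() instead of retaining a graph. A minimal sanity check, assuming scaling.py is importable (the import path is an assumption) and running on CPU in double precision, where the @custom_fwd/@custom_bwd AMP casts are inert:

    import torch
    from scaling import ComputeSquaredMeanWithOffset  # assumed import path

    x = torch.randn(3, 512, 7, dtype=torch.double, requires_grad=True)
    bias = torch.randn(512, dtype=torch.double, requires_grad=True)

    # Forward should equal a plain broadcasted computation over channel_dim=1.
    ref = torch.mean((x + bias.reshape(1, 512, 1)) ** 2, dim=1, keepdim=True)
    out = ComputeSquaredMeanWithOffset.apply(x, bias, 1)
    assert torch.allclose(out, ref)

    # The recompute-based backward should agree with numerical gradients.
    torch.autograd.gradcheck(
        lambda x, b: ComputeSquaredMeanWithOffset.apply(x, b, 1), (x, bias)
    )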
@@ -446,7 +475,7 @@ class BasicNorm(torch.nn.Module):
 
     Args:
       num_channels: the number of channels, e.g. 512.
      channel_dim: the axis/dimension corresponding to the channel,
        interpreted as an offset from the input's ndim if negative.
        This is NOT the num_channels; it should typically be one of
        {-2, -1, 0, 1, 2, 3}.
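A quick illustration of the channel_dim convention described in this docstring, using hypothetical shapes:

    import torch

    x = torch.randn(100, 8, 512)   # (seq_len, batch, num_channels)
    channel_dim = -1
    if channel_dim < 0:            # same normalization as in forward() above
        channel_dim += x.ndim      # -1 + 3 == 2
    assert x.shape[channel_dim] == 512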
@@ -476,6 +505,7 @@ class BasicNorm(torch.nn.Module):
             self.eps = nn.Parameter(torch.tensor(eps).log().detach())
         else:
             self.register_buffer("eps", torch.tensor(eps).log().detach())
+        self.bias = nn.Parameter(torch.zeros(num_channels))
         self.eps_min = eps_min
         self.eps_max = eps_max
 
@@ -491,8 +521,10 @@ class BasicNorm(torch.nn.Module):
         # gradients to allow the parameter to get back into the allowed
         # region if it happens to exit it.
         eps = eps.clamp(min=self.eps_min, max=self.eps_max)
 
+        norms = ComputeSquaredMeanWithOffset.apply(x, self.bias, self.channel_dim)
+
         scales = (
-            torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp()
+            norms + eps.exp()
         ) ** -0.5
         return x * scales
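Taken together, the two BasicNorm hunks change the normalization statistic from mean(x**2) to mean((x + bias)**2), while the scale is still applied to the unshifted x; since bias is initialized to zeros, the module behaves identically to the pre-commit BasicNorm at the start of training. A self-contained sketch of the resulting forward computation (eps handling inlined, names as in the diff; an illustration, not the module itself):

    import torch

    def basic_norm_forward(x, bias, log_eps, channel_dim=-1):
        # Second moment of the *biased* input, per position, over channels.
        norms = torch.mean((x + bias) ** 2, dim=channel_dim, keepdim=True)
        scales = (norms + log_eps.exp()) ** -0.5
        # The scale is computed from x + bias but multiplies the raw x.
        return x * scales

    x = torch.randn(10, 512)
    bias = torch.zeros(512)              # zeros at init, matching the new __init__
    log_eps = torch.tensor(0.25).log()   # eps stored in log space; 0.25 is an example value
    y = basic_norm_forward(x, bias, log_eps)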
@@ -1071,12 +1103,10 @@ class Whiten(nn.Module):
         self.grad_scale = grad_scale
 
         if isinstance(prob, float):
-            assert 0 < prob <= 1
-            self.prob = prob
-        else:
-            (self.min_prob, self.max_prob) = prob
-            assert 0 < self.min_prob < self.max_prob <= 1
-            self.prob = self.max_prob
+            prob = (prob, prob)
+        (self.min_prob, self.max_prob) = prob
+        assert 0 < self.min_prob <= self.max_prob <= 1
+        self.prob = self.max_prob
 
         self.name = None  # will be set in training loop
 
     def forward(self,
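The Whiten hunk collapses the float/tuple branches: a scalar prob is first turned into a degenerate (prob, prob) pair, and the assert is relaxed from < to <= so that equal bounds pass. The same logic in isolation:

    def normalize_prob(prob):
        # Scalar probabilities become a degenerate (min, max) range.
        if isinstance(prob, float):
            prob = (prob, prob)
        (min_prob, max_prob) = prob
        assert 0 < min_prob <= max_prob <= 1
        return min_prob, max_prob

    print(normalize_prob(0.5))         # (0.5, 0.5): equal bounds, allowed by the new '<='
    print(normalize_prob((0.1, 0.9)))  # (0.1, 0.9)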
@@ -1718,7 +1748,7 @@ def _test_activation_balancer_magnitude():
         max_factor=0.2,
         min_abs=0.2,
         max_abs=0.8,
-        min_prob=1.0,
+        prob=1.0,
     )
 
     y_grad = torch.sign(torch.randn(magnitudes.numel(), N))