Remove eps from BasicNorm and reintroduce bias

parent a2227a07fc
commit 3a5b3f640d
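In rough terms, this commit replaces BasicNorm's eps-stabilized RMS denominator with a denominator computed after subtracting a learnable per-channel bias, and keeps the learnable multiplier in log space. A minimal standalone sketch of the before/after scale computation (plain PyTorch; shapes and example values here are chosen for illustration, not taken from the repo):

```python
import torch

# Toy input: (batch, time, channels) with channel_dim = -1.
x = torch.randn(2, 5, 8)

# Before this commit: eps keeps the denominator away from zero and the
# learnable multiplier `scale` is stored directly.
eps, scale = torch.tensor(0.1), torch.tensor(1.0)
old_scales = ((torch.mean(x ** 2, dim=-1, keepdim=True) + eps) ** -0.5) * scale

# After this commit: a learnable per-channel bias is subtracted inside the
# mean, there is no eps, and the multiplier is stored as log_scale.
bias, log_scale = torch.zeros(8), torch.tensor(1.0)
new_scales = (torch.mean((x - bias) ** 2, dim=-1, keepdim=True) ** -0.5) * log_scale.exp()

print(old_scales.shape, new_scales.shape)  # both torch.Size([2, 5, 1])
```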
@@ -434,42 +434,44 @@ class MaxEigLimiterFunction(torch.autograd.Function):
 class BasicNormFunction(torch.autograd.Function):
     # This computes:
-    #   scales = ((torch.mean(x**2, keepdim=True) + eps) ** -0.5 * scale)
+    #   scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp()
     #   return (x - bias) * scales
     # (after unsqueezing the bias), but it does it in a memory-efficient way so that
     # it can just store the returned value (chances are, this will also be needed for
     # some other reason, related to the next operation, so we can save memory).
     @staticmethod
-    def forward(ctx, x: Tensor, eps: Tensor, scale: Tensor, channel_dim: int,
+    def forward(ctx, x: Tensor, bias: Tensor, log_scale: Tensor, channel_dim: int,
                 store_output_for_backprop: bool) -> Tensor:
+        assert bias.ndim == 1
+        if channel_dim < 0:
+            channel_dim = channel_dim + x.ndim
         ctx.store_output_for_backprop = store_output_for_backprop
         ctx.channel_dim = channel_dim
-        scales = ((torch.mean(x ** 2, dim=channel_dim, keepdim=True) + eps) ** -0.5) * scale
+        for _ in range(channel_dim + 1, x.ndim):
+            bias = bias.unsqueeze(-1)
+        scales = (torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) * log_scale.exp()
         ans = x * scales
         ctx.save_for_backward(ans.detach() if store_output_for_backprop else x,
-                              scales.detach(), eps.detach(), scale.detach())
+                              scales.detach(), bias.detach(), log_scale.detach())
         return ans

     @staticmethod
     def backward(ctx, ans_grad: Tensor) -> Tensor:
-        ans_or_x, scales, eps, scale = ctx.saved_tensors
+        ans_or_x, scales, bias, log_scale = ctx.saved_tensors
         if ctx.store_output_for_backprop:
             x = ans_or_x / scales
         else:
             x = ans_or_x
-        with torch.cuda.amp.autocast(enabled=False):
-            assert eps.dtype != torch.float16 and scale.dtype != torch.float16
-            x = x.to(torch.float32).detach()
+        x = x.detach()
         x.requires_grad = True
-        eps.requires_grad = True
-        scale.requires_grad = True
+        bias.requires_grad = True
+        log_scale.requires_grad = True
         with torch.enable_grad():
-            # recompute scales from x, eps and log_scale.
-            scales = ((torch.mean(x ** 2, dim=ctx.channel_dim, keepdim=True) + eps) ** -0.5) * scale
+            # recompute scales from x, bias and log_scale.
+            scales = (torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5) * log_scale.exp()
             ans = x * scales
-            ans.backward(gradient=ans_grad.to(torch.float32))
-        return x.grad.to(ans_grad.dtype), eps.grad, scale.grad, None, None
+            ans.backward(gradient=ans_grad)
+        return x.grad, bias.grad.flatten(), log_scale.grad, None, None

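The backward above avoids storing extra activations by reconstructing x from the saved output (or using the saved input), re-running the forward arithmetic under torch.enable_grad(), and reading the gradients off the leaf tensors. A minimal sketch of that pattern under simplified assumptions (last-dim normalization only; an illustration of the technique, not the icefall implementation):

```python
import torch
from torch import Tensor


class RecomputeNorm(torch.autograd.Function):
    # Illustrative save-output / recompute-in-backward pattern; normalizes over
    # the last dimension only, with a per-channel bias and a scalar log_scale.

    @staticmethod
    def forward(ctx, x: Tensor, bias: Tensor, log_scale: Tensor) -> Tensor:
        scales = (torch.mean((x - bias) ** 2, dim=-1, keepdim=True) ** -0.5) * log_scale.exp()
        ans = x * scales
        # Save the output instead of the input; x is recoverable as ans / scales.
        ctx.save_for_backward(ans.detach(), scales.detach(), bias.detach(), log_scale.detach())
        return ans

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        ans, scales, bias, log_scale = ctx.saved_tensors
        x = (ans / scales).detach()
        x.requires_grad = True
        bias.requires_grad = True
        log_scale.requires_grad = True
        with torch.enable_grad():
            # Re-run the forward math and let autograd produce the gradients.
            scales = (torch.mean((x - bias) ** 2, dim=-1, keepdim=True) ** -0.5) * log_scale.exp()
            (x * scales).backward(gradient=ans_grad)
        return x.grad, bias.grad, log_scale.grad


# Quick numerical check of the custom backward (an assumed test, not from the repo).
x = torch.randn(3, 4, 8, dtype=torch.double, requires_grad=True)
bias = torch.randn(8, dtype=torch.double, requires_grad=True)
log_scale = torch.tensor(0.3, dtype=torch.double, requires_grad=True)
torch.autograd.gradcheck(RecomputeNorm.apply, (x, bias, log_scale))
```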
@@ -495,11 +497,8 @@ class BasicNorm(torch.nn.Module):
          {-2, -1, 0, 1, 2, 3}.
       log_scale: the initial log-scale that we multiply the output by; this
          is learnable.
-      eps: the initial epsilon value (not in log space)
       log_scale_min: FloatLike, minimum allowed value of log_scale
       log_scale_max: FloatLike, maximum allowed value of log_scale
-      log_eps_min: FloatLike, minimum allowed value of log_eps
-      log_eps_max: FloatLike, maximum allowed value of log_eps
       store_output_for_backprop: only possibly affects memory use; recommend
          to set to True if you think the output of this module is more likely
          than the input of this module to be required to be stored for the
@@ -511,25 +510,19 @@ class BasicNorm(torch.nn.Module):
                  num_channels: int,
                  channel_dim: int = -1,  # CAUTION: see documentation.
                  log_scale: float = 1.0,
-                 eps: float = 0.1,
-                 log_scale_min: FloatLike = -1.5,
-                 log_scale_max: FloatLike = 1.5,
-                 log_eps_min: FloatLike = -3.0,
-                 log_eps_max: FloatLike = 3.0,
+                 log_scale_min: float = -1.5,
+                 log_scale_max: float = 1.5,
                  store_output_for_backprop: bool = False
     ) -> None:
         super(BasicNorm, self).__init__()
         self.num_channels = num_channels
         self.channel_dim = channel_dim
         self.log_scale = nn.Parameter(torch.tensor(log_scale))
-        self.log_eps = nn.Parameter(torch.tensor(eps).log().detach())
+        self.bias = nn.Parameter(torch.zeros(num_channels))

         self.log_scale_min = log_scale_min
         self.log_scale_max = log_scale_max

-        self.log_eps_min = log_eps_min
-        self.log_eps_max = log_eps_max
-
         self.store_output_for_backprop = store_output_for_backprop

     def forward(self, x: Tensor) -> Tensor:
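With this change the module's learnable state becomes a per-channel bias (initialized to zeros) plus the scalar log_scale, instead of log_scale plus a scalar log_eps derived from eps. A small sketch of just that parameter registration (num_channels=512 is an arbitrary example value):

```python
import torch
import torch.nn as nn

num_channels, log_scale, eps = 512, 1.0, 0.1

# Before: a scalar log_eps parameter derived from the initial eps value.
old_log_eps = nn.Parameter(torch.tensor(eps).log().detach())

# After: a per-channel bias initialized to zero; log_scale stays a scalar parameter.
new_bias = nn.Parameter(torch.zeros(num_channels))
new_log_scale = nn.Parameter(torch.tensor(log_scale))

print(old_log_eps.shape, new_bias.shape, new_log_scale.shape)
# torch.Size([]) torch.Size([512]) torch.Size([])
```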
@@ -538,7 +531,12 @@ class BasicNorm(torch.nn.Module):

         if torch.jit.is_scripting():
             channel_dim = self.channel_dim
-            scales = (((torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + self.log_eps.exp()) ** -0.5) *
+            if channel_dim < 0:
+                channel_dim += x.ndim
+            bias = self.bias
+            for _ in range(channel_dim + 1, x.ndim):
+                bias = bias.unsqueeze(-1)
+            scales = ((torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) *
                       self.log_scale.exp())
             return x * scales

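When the channel dimension is not the last one, the unsqueeze loop added above reshapes the (num_channels,) bias so it broadcasts correctly against x. A small sketch of that broadcasting logic in isolation (shapes picked here for illustration):

```python
import torch

x = torch.randn(2, 8, 50)   # (batch, channels, time), so channel_dim = 1
bias = torch.zeros(8)
log_scale = torch.tensor(1.0)
channel_dim = 1

if channel_dim < 0:
    channel_dim += x.ndim
# One trailing unsqueeze per dimension after the channel dim: (8,) -> (8, 1).
for _ in range(channel_dim + 1, x.ndim):
    bias = bias.unsqueeze(-1)

scales = (torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) * log_scale.exp()
out = x * scales
print(bias.shape, scales.shape, out.shape)
# torch.Size([8, 1]) torch.Size([2, 1, 50]) torch.Size([2, 8, 50])
```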
@@ -546,12 +544,8 @@ class BasicNorm(torch.nn.Module):
                                       min=float(self.log_scale_min),
                                       max=float(self.log_scale_max),
                                       training=self.training)
-        log_eps = limit_param_value(self.log_eps,
-                                    min=float(self.log_eps_min),
-                                    max=float(self.log_eps_max),
-                                    training=self.training)

-        return BasicNormFunction.apply(x, log_eps.exp(), log_scale.exp(),
+        return BasicNormFunction.apply(x, self.bias, log_scale,
                                        self.channel_dim,
                                        self.store_output_for_backprop)

@@ -1885,8 +1885,7 @@ class Conv2dSubsampling(nn.Module):

         # max_log_eps=0.0 is to prevent both eps and the output of self.out from
         # getting large, there is an unnecessary degree of freedom.
-        self.out_norm = BasicNorm(out_channels, eps=1.0,
-                                  log_eps_min=-0.1, log_eps_max=0.0)
+        self.out_norm = BasicNorm(out_channels)
         self.dropout = Dropout2(dropout)
