Remove eps from BasicNorm and reintroduce bias

Daniel Povey 2023-01-02 00:02:31 +08:00
parent a2227a07fc
commit 3a5b3f640d
2 changed files with 35 additions and 42 deletions


@@ -434,42 +434,44 @@ class MaxEigLimiterFunction(torch.autograd.Function):
 class BasicNormFunction(torch.autograd.Function):
     # This computes:
-    # scales = ((torch.mean(x**2, keepdim=True) + eps) ** -0.5 * scale)
+    # scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp()
     # return (x - bias) * scales
     # (after unsqueezing the bias), but it does it in a memory-efficient way so that
     # it can just store the returned value (chances are, this will also be needed for
     # some other reason, related to the next operation, so we can save memory).
     @staticmethod
-    def forward(ctx, x: Tensor, eps: Tensor, scale: Tensor, channel_dim: int,
+    def forward(ctx, x: Tensor, bias: Tensor, log_scale: Tensor, channel_dim: int,
                 store_output_for_backprop: bool) -> Tensor:
+        assert bias.ndim == 1
+        if channel_dim < 0:
+            channel_dim = channel_dim + x.ndim
         ctx.store_output_for_backprop = store_output_for_backprop
         ctx.channel_dim = channel_dim
-        scales = ((torch.mean(x ** 2, dim=channel_dim, keepdim=True) + eps) ** -0.5) * scale
+        for _ in range(channel_dim + 1, x.ndim):
+            bias = bias.unsqueeze(-1)
+        scales = (torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) * log_scale.exp()
         ans = x * scales
         ctx.save_for_backward(ans.detach() if store_output_for_backprop else x,
-                              scales.detach(), eps.detach(), scale.detach())
+                              scales.detach(), bias.detach(), log_scale.detach())
         return ans
 
     @staticmethod
     def backward(ctx, ans_grad: Tensor) -> Tensor:
-        ans_or_x, scales, eps, scale = ctx.saved_tensors
+        ans_or_x, scales, bias, log_scale = ctx.saved_tensors
         if ctx.store_output_for_backprop:
             x = ans_or_x / scales
         else:
             x = ans_or_x
-        with torch.cuda.amp.autocast(enabled=False):
-            assert eps.dtype != torch.float16 and scale.dtype != torch.float16
-            x = x.to(torch.float32).detach()
-            x.requires_grad = True
-            eps.requires_grad = True
-            scale.requires_grad = True
-            with torch.enable_grad():
-                # recompute scales from x, eps and log_scale.
-                scales = ((torch.mean(x ** 2, dim=ctx.channel_dim, keepdim=True) + eps) ** -0.5) * scale
-                ans = x * scales
-                ans.backward(gradient=ans_grad.to(torch.float32))
-        return x.grad.to(ans_grad.dtype), eps.grad, scale.grad, None, None
+        x = x.detach()
+        x.requires_grad = True
+        bias.requires_grad = True
+        log_scale.requires_grad = True
+        with torch.enable_grad():
+            # recompute scales from x, bias and log_scale.
+            scales = (torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5) * log_scale.exp()
+            ans = x * scales
+            ans.backward(gradient=ans_grad)
+        return x.grad, bias.grad.flatten(), log_scale.grad, None, None
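For reference, the new BasicNormFunction is a memory-efficient autograd wrapper around the following computation. This is an illustrative sketch, not part of the patch; the helper name and the reshape-based broadcasting are assumptions, but the arithmetic mirrors the forward pass above. Note that the bias enters only the variance estimate; the returned value is still x * scales.

import torch

def basic_norm_reference(x: torch.Tensor,
                         bias: torch.Tensor,
                         log_scale: torch.Tensor,
                         channel_dim: int = -1) -> torch.Tensor:
    # Broadcast the 1-D bias over the channel dimension, equivalent to the
    # unsqueeze loop in BasicNormFunction.forward.
    if channel_dim < 0:
        channel_dim += x.ndim
    shape = [1] * x.ndim
    shape[channel_dim] = -1
    b = bias.reshape(shape)
    # RMS of (x - bias) over the channel dim, times a learned scalar gain.
    scales = (torch.mean((x - b) ** 2, dim=channel_dim, keepdim=True) ** -0.5) * log_scale.exp()
    # The bias affects only the statistics; the output is x * scales.
    return x * scales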
@@ -495,11 +497,8 @@ class BasicNorm(torch.nn.Module):
          {-2, -1, 0, 1, 2, 3}.
       log_scale: the initial log-scale that we multiply the output by; this
          is learnable.
-      eps: the initial epsilon value (not in log space)
       log_scale_min: FloatLike, minimum allowed value of log_scale
       log_scale_max: FloatLike, maximum allowed value of log_scale
-      log_eps_min: FloatLike, minimum allowed value of log_eps
-      log_eps_max: FloatLike, maximum allowed value of log_eps
       store_output_for_backprop: only possibly affects memory use; recommend
          to set to True if you think the output of this module is more likely
          than the input of this module to be required to be stored for the
@@ -511,25 +510,19 @@ class BasicNorm(torch.nn.Module):
                  num_channels: int,
                  channel_dim: int = -1,  # CAUTION: see documentation.
                  log_scale: float = 1.0,
-                 eps: float = 0.1,
-                 log_scale_min: FloatLike = -1.5,
-                 log_scale_max: FloatLike = 1.5,
-                 log_eps_min: FloatLike = -3.0,
-                 log_eps_max: FloatLike = 3.0,
+                 log_scale_min: float = -1.5,
+                 log_scale_max: float = 1.5,
                  store_output_for_backprop: bool = False
                  ) -> None:
         super(BasicNorm, self).__init__()
         self.num_channels = num_channels
         self.channel_dim = channel_dim
         self.log_scale = nn.Parameter(torch.tensor(log_scale))
-        self.log_eps = nn.Parameter(torch.tensor(eps).log().detach())
+        self.bias = nn.Parameter(torch.zeros(num_channels))
         self.log_scale_min = log_scale_min
         self.log_scale_max = log_scale_max
-        self.log_eps_min = log_eps_min
-        self.log_eps_max = log_eps_max
         self.store_output_for_backprop = store_output_for_backprop
 
     def forward(self, x: Tensor) -> Tensor:
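After this change the module's learnable state is a zero-initialized per-channel bias plus the scalar log_scale; there is no log_eps parameter any more. A small illustrative sketch (the import path is an assumption):

import torch
from scaling import BasicNorm  # assumed import path; BasicNorm lives in the scaling module being patched

norm = BasicNorm(num_channels=4)
for name, p in norm.named_parameters():
    print(name, tuple(p.shape))
# log_scale ()   <- scalar, initialized from log_scale=1.0
# bias (4,)      <- per-channel, initialized to zeros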
@@ -538,7 +531,12 @@ class BasicNorm(torch.nn.Module):
         if torch.jit.is_scripting():
             channel_dim = self.channel_dim
-            scales = (((torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + self.log_eps.exp()) ** -0.5) *
+            if channel_dim < 0:
+                channel_dim += x.ndim
+            bias = self.bias
+            for _ in range(channel_dim + 1, x.ndim):
+                bias = bias.unsqueeze(-1)
+            scales = ((torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) *
                       self.log_scale.exp())
             return x * scales
@@ -546,12 +544,8 @@ class BasicNorm(torch.nn.Module):
         log_scale = limit_param_value(self.log_scale,
                                       min=float(self.log_scale_min),
                                       max=float(self.log_scale_max),
                                       training=self.training)
-        log_eps = limit_param_value(self.log_eps,
-                                    min=float(self.log_eps_min),
-                                    max=float(self.log_eps_max),
-                                    training=self.training)
-        return BasicNormFunction.apply(x, log_eps.exp(), log_scale.exp(),
+        return BasicNormFunction.apply(x, self.bias, log_scale,
                                        self.channel_dim,
                                        self.store_output_for_backprop)
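A quick consistency check, as a sketch rather than part of the commit (it assumes BasicNormFunction is importable from the scaling module being patched; the file name is an assumption): the memory-efficient autograd path used in forward() should agree with the direct formula of the torch.jit.is_scripting() branch, for both the output and the gradients.

import torch
from scaling import BasicNormFunction  # assumed import path

torch.manual_seed(0)
x = torch.randn(4, 7, 16, dtype=torch.double, requires_grad=True)
bias = torch.randn(16, dtype=torch.double, requires_grad=True)
log_scale = torch.tensor(0.3, dtype=torch.double, requires_grad=True)

# Memory-efficient custom autograd Function, channels last.
out_fn = BasicNormFunction.apply(x, bias, log_scale, -1, False)

# Direct computation, as in the scripting branch of BasicNorm.forward.
scales = (torch.mean((x - bias) ** 2, dim=-1, keepdim=True) ** -0.5) * log_scale.exp()
out_ref = x * scales

assert torch.allclose(out_fn, out_ref)

g = torch.randn_like(out_ref)
grads_fn = torch.autograd.grad(out_fn, (x, bias, log_scale), g)
grads_ref = torch.autograd.grad(out_ref, (x, bias, log_scale), g)
assert all(torch.allclose(a, b) for a, b in zip(grads_fn, grads_ref))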


@@ -1885,8 +1885,7 @@ class Conv2dSubsampling(nn.Module):
         # max_log_eps=0.0 is to prevent both eps and the output of self.out from
         # getting large, there is an unnecessary degree of freedom.
-        self.out_norm = BasicNorm(out_channels, eps=1.0,
-                                  log_eps_min=-0.1, log_eps_max=0.0)
+        self.out_norm = BasicNorm(out_channels)
         self.dropout = Dropout2(dropout)
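Correspondingly, Conv2dSubsampling now builds its output norm from the channel count alone. An illustrative sketch, not taken from the patch (the import path is an assumption):

import torch
from scaling import BasicNorm  # assumed import path

out_channels = 384
out_norm = BasicNorm(out_channels)       # eps / log_eps_min / log_eps_max are gone
x = torch.randn(8, 100, out_channels)    # (batch, time, channels)
y = out_norm(x)                          # same shape, normalized over the last dim

# The old call BasicNorm(out_channels, eps=1.0, log_eps_min=-0.1, log_eps_max=0.0)
# would now raise a TypeError, since those keyword arguments were removed.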