Work out alpha (scale on z) in LinearWithAuxLossFunction

Daniel Povey 2022-11-27 23:45:38 +08:00
parent 0307252832
commit 9e7add6be8
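
Previously alpha was a learnable parameter of LinearWithAuxLoss (initialized to 0.0 and applied as alpha.exp() when forming z). As the diff below shows, this commit removes that parameter and works alpha out in closed form inside backward(): with y = matmul(x, weight.t()) and z = matmul(y, weight), it uses alpha = (x * z).sum() / ((z * z).sum() + 1e-20), the least-squares scale of z against x. That value minimizes ((x - a * z) ** 2).sum() over a, since setting the derivative -2 * (z * (x - a * z)).sum() to zero gives a = (x * z).sum() / (z * z).sum(). The auxiliary loss is then, as before, the mean square of the mean-subtracted x - alpha * z, and forward() takes one fewer argument while backward() returns one fewer gradient.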


@@ -389,7 +389,7 @@ class BasicNorm(torch.nn.Module):
 class LinearWithAuxLossFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x: Tensor, weight: Tensor, alpha: Tensor,
+    def forward(ctx, x: Tensor, weight: Tensor,
                 aux_grad_scale: float) -> Tensor:
         """
         Returns matmul(x, weight.t()).
@@ -398,14 +398,14 @@ class LinearWithAuxLossFunction(torch.autograd.Function):
         """
         if torch.is_autocast_enabled():
             x = x.to(torch.float16)
-        ctx.save_for_backward(x, weight, alpha)
+        ctx.save_for_backward(x, weight)
         ctx.aux_grad_scale = aux_grad_scale
         return torch.matmul(x, weight.t())
 
     @staticmethod
     def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, Tensor, Tensor, None]:
-        x, weight, alpha = ctx.saved_tensors
+        x, weight = ctx.saved_tensors
         x_grad = torch.matmul(ans_grad, weight.to(ans_grad.dtype))
         weight_grad = torch.matmul(ans_grad.reshape(-1, ans_grad.shape[-1]).t(),
@@ -415,14 +415,15 @@ class LinearWithAuxLossFunction(torch.autograd.Function):
         with torch.cuda.amp.autocast(enabled=False):
             with torch.enable_grad():
                 x = x.to(weight.dtype)
-                x, weight, alpha = x.detach(), weight.detach(), alpha.detach()
+                x, weight = x.detach(), weight.detach()
                 weight.requires_grad = True
-                alpha.requires_grad = True
                 # recompute y as we need the gradient; this is easier to implement than
                 # saving y in the context.
                 y = torch.matmul(x, weight.t())
-                z = alpha.exp() * torch.matmul(y, weight)
-                diff = x - z
+                z = torch.matmul(y, weight)
+                with torch.no_grad():
+                    alpha = (x * z).sum() / ((z * z).sum() + 1.0e-20)
+                diff = x - alpha * z
                 dims_to_mean = tuple(range(x.ndim-1))
                 mean = diff.mean(dim=dims_to_mean)
                 diff = diff - mean  # subtract mean.
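
The following standalone sketch is not part of the commit (the shapes, seed, and variable names are illustrative); it only checks numerically that the closed-form alpha above matches a generic least-squares solve:

import torch

torch.manual_seed(0)
x = torch.randn(50, 8)            # stand-in for the activations
weight = torch.randn(4, 8)        # stand-in for the weight matrix
y = torch.matmul(x, weight.t())
z = torch.matmul(y, weight)

# closed form used in backward() above
alpha = (x * z).sum() / ((z * z).sum() + 1.0e-20)

# reference: minimize |a * z.flatten() - x.flatten()|^2 with torch.linalg.lstsq
a_ref = torch.linalg.lstsq(z.reshape(-1, 1), x.reshape(-1, 1)).solution
print(alpha.item(), a_ref.item())  # the two values agree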
@@ -430,7 +431,6 @@ class LinearWithAuxLossFunction(torch.autograd.Function):
                 meansq = (diff ** 2).mean()
                 meansq.backward()
                 weight_aux_grad = weight.grad
-                alpha_grad = alpha.grad
 
         with torch.cuda.amp.autocast(enabled=False):
@@ -439,7 +439,7 @@ class LinearWithAuxLossFunction(torch.autograd.Function):
             weight_grad_scale = ctx.aux_grad_scale * weight_grad_norm / (aux_grad_norm + 1.0e-20)
             weight_grad = weight_grad + (weight_grad_scale * weight_aux_grad).to(weight_grad.dtype)
-        return x_grad, weight_grad, alpha_grad, None
+        return x_grad, weight_grad, None
@@ -485,7 +485,6 @@ class LinearWithAuxLoss(nn.Module):
                                  0.01 * initial_scale)
         else:
             self.register_parameter('bias', None)
-        self.alpha = nn.Parameter(torch.tensor(0.0))
 
     def forward(self,
@@ -495,7 +494,7 @@ class LinearWithAuxLoss(nn.Module):
                 aux_grad_scale == 0.0 or random.random() > float(self.prob)):
             return torch.nn.functional.linear(x, self.weight, self.bias)
         else:
-            ans = LinearWithAuxLossFunction.apply(x, self.weight, self.alpha,
+            ans = LinearWithAuxLossFunction.apply(x, self.weight,
                                                   aux_grad_scale)
             if self.bias is not None:
                 ans += self.bias
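
A short usage sketch of the updated interface (it assumes the LinearWithAuxLossFunction class defined in this file is in scope; the tensor shapes and the aux_grad_scale value are illustrative only):

import torch

x = torch.randn(8, 32, requires_grad=True)        # (batch, in_features)
weight = torch.randn(64, 32, requires_grad=True)  # (out_features, in_features)

# alpha is no longer passed: only x, weight and aux_grad_scale
ans = LinearWithAuxLossFunction.apply(x, weight, 0.1)
ans.sum().backward()

# backward() now returns (x_grad, weight_grad, None), so gradients reach x and
# weight and there is no separate alpha gradient any more
print(x.grad.shape, weight.grad.shape)  # torch.Size([8, 32]) torch.Size([64, 32])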