Work out alpha (scale on z) in LinearWithAuxLossFunction
This commit is contained in:
parent 0307252832
commit 9e7add6be8
@@ -389,7 +389,7 @@ class BasicNorm(torch.nn.Module):
 
 class LinearWithAuxLossFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x: Tensor, weight: Tensor, alpha: Tensor,
+    def forward(ctx, x: Tensor, weight: Tensor,
                 aux_grad_scale: float) -> Tensor:
         """
         Returns matmul(x, weight.t()).
@@ -398,14 +398,14 @@ class LinearWithAuxLossFunction(torch.autograd.Function):
         """
         if torch.is_autocast_enabled():
             x = x.to(torch.float16)
-        ctx.save_for_backward(x, weight, alpha)
+        ctx.save_for_backward(x, weight)
         ctx.aux_grad_scale = aux_grad_scale
         return torch.matmul(x, weight.t())
 
 
     @staticmethod
     def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, Tensor, Tensor, None]:
-        x, weight, alpha = ctx.saved_tensors
+        x, weight = ctx.saved_tensors
 
         x_grad = torch.matmul(ans_grad, weight.to(ans_grad.dtype))
         weight_grad = torch.matmul(ans_grad.reshape(-1, ans_grad.shape[-1]).t(),
@@ -415,14 +415,15 @@ class LinearWithAuxLossFunction(torch.autograd.Function):
         with torch.cuda.amp.autocast(enabled=False):
             with torch.enable_grad():
                 x = x.to(weight.dtype)
-                x, weight, alpha = x.detach(), weight.detach(), alpha.detach()
+                x, weight = x.detach(), weight.detach()
                 weight.requires_grad = True
-                alpha.requires_grad = True
                 # recompute y as we need the gradient; this is easier to implement than
                 # saving y in the context.
                 y = torch.matmul(x, weight.t())
-                z = alpha.exp() * torch.matmul(y, weight)
-                diff = x - z
+                z = torch.matmul(y, weight)
+                with torch.no_grad():
+                    alpha = (x * z).sum() / ((z * z).sum() + 1.0e-20)
+                diff = x - alpha * z
                 dims_to_mean = tuple(range(x.ndim-1))
                 mean = diff.mean(dim=dims_to_mean)
                 diff = diff - mean  # subtract mean.
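A note on the new with torch.no_grad(): block above: for fixed x and z, the alpha that minimizes ||x - alpha * z||^2 is (x . z) / (z . z), which is exactly what the added lines compute (the 1.0e-20 just guards against division by zero). A minimal sketch checking that identity numerically; the shapes and the float64 dtype are illustrative, not taken from icefall:

import torch

torch.manual_seed(0)
# Hypothetical (batch, time, channel) activations; shapes are only for illustration.
x = torch.randn(20, 50, 256, dtype=torch.float64)
z = torch.randn(20, 50, 256, dtype=torch.float64)

# Closed-form minimizer of ||x - alpha * z||^2 over the scalar alpha,
# written the same way as the no_grad() block in backward():
alpha = (x * z).sum() / ((z * z).sum() + 1.0e-20)

# Cross-check with autograd: at the closed-form alpha, the derivative of the
# squared error w.r.t. alpha should be (numerically) zero.
a = alpha.detach().clone().requires_grad_(True)
((x - a * z) ** 2).sum().backward()
print(alpha.item(), a.grad.item())   # a.grad is ~0 up to rounding

Because alpha is computed under no_grad(), no gradient flows through it, which is why the learnable alpha parameter and its gradient can be deleted in the hunks below.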
@@ -430,7 +431,6 @@ class LinearWithAuxLossFunction(torch.autograd.Function):
                 meansq = (diff ** 2).mean()
                 meansq.backward()
                 weight_aux_grad = weight.grad
-                alpha_grad = alpha.grad
 
 
         with torch.cuda.amp.autocast(enabled=False):
@@ -439,7 +439,7 @@ class LinearWithAuxLossFunction(torch.autograd.Function):
             weight_grad_scale = ctx.aux_grad_scale * weight_grad_norm / (aux_grad_norm + 1.0e-20)
             weight_grad = weight_grad + (weight_grad_scale * weight_aux_grad).to(weight_grad.dtype)
 
-        return x_grad, weight_grad, alpha_grad, None
+        return x_grad, weight_grad, None
 
 
 
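For context, the unchanged lines around the return statement rescale the auxiliary gradient so that its contribution has norm equal to aux_grad_scale times the norm of the main weight gradient, regardless of the raw magnitude of the auxiliary loss. The norm computations themselves fall outside the hunks above, so the sketch below reconstructs them in the obvious way; the shapes and the 0.1 value are made up for illustration:

import torch

aux_grad_scale = 0.1                              # hypothetical ctx.aux_grad_scale
weight_grad = torch.randn(512, 256)               # main gradient from the matmul
weight_aux_grad = torch.randn(512, 256) * 100.0   # aux-loss gradient, arbitrary scale

weight_grad_norm = weight_grad.norm()
aux_grad_norm = weight_aux_grad.norm()
# Same formula as in backward(): shrink (or grow) the aux gradient so that its
# norm becomes aux_grad_scale times the norm of the main gradient.
weight_grad_scale = aux_grad_scale * weight_grad_norm / (aux_grad_norm + 1.0e-20)
weight_grad = weight_grad + weight_grad_scale * weight_aux_grad

print((weight_grad_scale * weight_aux_grad).norm() / weight_grad_norm)   # ~0.1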
@@ -485,7 +485,6 @@ class LinearWithAuxLoss(nn.Module):
                                      0.01 * initial_scale)
         else:
             self.register_parameter('bias', None)
-        self.alpha = nn.Parameter(torch.tensor(0.0))
 
 
     def forward(self,
@@ -495,7 +494,7 @@ class LinearWithAuxLoss(nn.Module):
             aux_grad_scale == 0.0 or random.random() > float(self.prob)):
             return torch.nn.functional.linear(x, self.weight, self.bias)
         else:
-            ans = LinearWithAuxLossFunction.apply(x, self.weight, self.alpha,
+            ans = LinearWithAuxLossFunction.apply(x, self.weight,
                                                   aux_grad_scale)
            if self.bias is not None:
                 ans += self.bias
 
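Putting the hunks together, here is a condensed, self-contained sketch of what the reworked function now computes end to end. It drops the autocast/float16 handling and guesses at the lines that fall outside the hunks above (e.g. how the gradient norms are computed), so treat it as an approximation for reading purposes, not a drop-in replacement for icefall's LinearWithAuxLossFunction:

import torch
from torch import Tensor
from typing import Tuple


class AuxLossLinearSketch(torch.autograd.Function):
    """Simplified sketch: forward is a plain linear projection; backward mixes in
    the gradient of an auxiliary reconstruction loss on the weight."""

    @staticmethod
    def forward(ctx, x: Tensor, weight: Tensor, aux_grad_scale: float) -> Tensor:
        ctx.save_for_backward(x, weight)
        ctx.aux_grad_scale = aux_grad_scale
        return torch.matmul(x, weight.t())

    @staticmethod
    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, Tensor, None]:
        x, weight = ctx.saved_tensors

        # Ordinary gradients of the matmul.
        x_grad = torch.matmul(ans_grad, weight)
        weight_grad = torch.matmul(ans_grad.reshape(-1, ans_grad.shape[-1]).t(),
                                   x.reshape(-1, x.shape[-1]))

        # Auxiliary loss: how well does z = (x W^T) W, optimally scaled,
        # reconstruct the (mean-subtracted) input x?
        with torch.enable_grad():
            x_d, weight_d = x.detach(), weight.detach()
            weight_d.requires_grad = True
            y = torch.matmul(x_d, weight_d.t())
            z = torch.matmul(y, weight_d)
            with torch.no_grad():
                # closed-form best scale on z, as introduced by this commit
                alpha = (x_d * z).sum() / ((z * z).sum() + 1.0e-20)
            diff = x_d - alpha * z
            diff = diff - diff.mean(dim=tuple(range(x.ndim - 1)))
            (diff ** 2).mean().backward()
            weight_aux_grad = weight_d.grad

        # Blend the aux gradient in at a fixed fraction of the main gradient's norm.
        scale = ctx.aux_grad_scale * weight_grad.norm() / (weight_aux_grad.norm() + 1.0e-20)
        weight_grad = weight_grad + scale * weight_aux_grad

        return x_grad, weight_grad, None


# Toy usage:
x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(3, 8, requires_grad=True)
out = AuxLossLinearSketch.apply(x, w, 0.1)
out.sum().backward()
print(x.grad.shape, w.grad.shape)   # torch.Size([4, 8]) torch.Size([3, 8])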