Fix problem with mean offset in LinearWithAuxLoss.

This commit is contained in:
Daniel Povey 2022-11-28 09:46:01 +08:00
parent a3b07fd098
commit 109825cafb

View File

@@ -421,12 +421,13 @@ class LinearWithAuxLossFunction(torch.autograd.Function):
# saving y in the context.
y = torch.matmul(x, weight.t())
z = torch.matmul(y, weight)
# subtract mean
x = x - x.mean(dim=dims_to_mean)
z = z - z.mean(dim=dims_to_mean)
# compute optimal scale on z
with torch.no_grad():
    alpha = (x * z).sum() / ((z * z).sum() + 1.0e-20)
diff = x - alpha * z
dims_to_mean = tuple(range(x.ndim-1))
mean = diff.mean(dim=dims_to_mean)
diff = diff - mean # subtract mean.
# meansq is the loss function.
meansq = (diff ** 2).mean()
meansq.backward()