Fix bug regarding dims_to_mean

This commit is contained in:
Daniel Povey 2022-11-28 10:42:06 +08:00
parent 109825cafb
commit 0bfd81d721

View File

@@ -422,6 +422,7 @@ class LinearWithAuxLossFunction(torch.autograd.Function):
y = torch.matmul(x, weight.t())
z = torch.matmul(y, weight)
# subtract mean
dims_to_mean = tuple(range(x.ndim-1))
x = x - x.mean(dim=dims_to_mean)
z = z - z.mean(dim=dims_to_mean)
# compute optimal scale on z