fix bug RE dims_to_mean
This commit is contained in:
parent
109825cafb
commit
0bfd81d721
@ -422,6 +422,7 @@ class LinearWithAuxLossFunction(torch.autograd.Function):
|
||||
y = torch.matmul(x, weight.t())
|
||||
z = torch.matmul(y, weight)
|
||||
# subtract mean
|
||||
dims_to_mean = tuple(range(x.ndim-1))
|
||||
x = x - x.mean(dim=dims_to_mean)
|
||||
z = z - z.mean(dim=dims_to_mean)
|
||||
# compute optimal scale on z
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user