mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Fix problem with mean offset in LinearWithAuxLoss.
This commit is contained in:
parent
a3b07fd098
commit
109825cafb
@ -421,12 +421,13 @@ class LinearWithAuxLossFunction(torch.autograd.Function):
|
||||
# saving y in the context.
|
||||
y = torch.matmul(x, weight.t())
|
||||
z = torch.matmul(y, weight)
|
||||
# subtract mean
|
||||
x = x - x.mean(dim=dims_to_mean)
|
||||
z = z - z.mean(dim=dims_to_mean)
|
||||
# compute optimal scale on z
|
||||
with torch.no_grad():
|
||||
alpha = (x * z).sum() / ((z * z).sum() + 1.0e-20)
|
||||
diff = x - alpha * z
|
||||
dims_to_mean = tuple(range(x.ndim-1))
|
||||
mean = diff.mean(dim=dims_to_mean)
|
||||
diff = diff - mean # subtract mean.
|
||||
# meansq is the loss function.
|
||||
meansq = (diff ** 2).mean()
|
||||
meansq.backward()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user