mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Fix problem with mean offset in LinearWithAuxLoss.
This commit is contained in:
parent
a3b07fd098
commit
109825cafb
@ -421,12 +421,13 @@ class LinearWithAuxLossFunction(torch.autograd.Function):
|
|||||||
# saving y in the context.
|
# saving y in the context.
|
||||||
y = torch.matmul(x, weight.t())
|
y = torch.matmul(x, weight.t())
|
||||||
z = torch.matmul(y, weight)
|
z = torch.matmul(y, weight)
|
||||||
|
# subtract mean
|
||||||
|
x = x - x.mean(dim=dims_to_mean)
|
||||||
|
z = z - z.mean(dim=dims_to_mean)
|
||||||
|
# compute optimal scale on z
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
alpha = (x * z).sum() / ((z * z).sum() + 1.0e-20)
|
alpha = (x * z).sum() / ((z * z).sum() + 1.0e-20)
|
||||||
diff = x - alpha * z
|
diff = x - alpha * z
|
||||||
dims_to_mean = tuple(range(x.ndim-1))
|
|
||||||
mean = diff.mean(dim=dims_to_mean)
|
|
||||||
diff = diff - mean # subtract mean.
|
|
||||||
# meansq is the loss function.
|
# meansq is the loss function.
|
||||||
meansq = (diff ** 2).mean()
|
meansq = (diff ** 2).mean()
|
||||||
meansq.backward()
|
meansq.backward()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user