From 109825cafbeabbee49e0328ba0be21fbc9f33b1b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 28 Nov 2022 09:46:01 +0800 Subject: [PATCH] Fix problem with mean offset in LinearWithAuxLoss. --- .../ASR/pruned_transducer_stateless7/scaling.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index a70450aa4..49e204de1 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -421,12 +421,13 @@ class LinearWithAuxLossFunction(torch.autograd.Function): # saving y in the context. y = torch.matmul(x, weight.t()) z = torch.matmul(y, weight) + # subtract mean + x = x - x.mean(dim=dims_to_mean) + z = z - z.mean(dim=dims_to_mean) + # compute optimal scale on z with torch.no_grad(): alpha = (x * z).sum() / ((z * z).sum() + 1.0e-20) diff = x - alpha * z - dims_to_mean = tuple(range(x.ndim-1)) - mean = diff.mean(dim=dims_to_mean) - diff = diff - mean # subtract mean. # meansq is the loss function. meansq = (diff ** 2).mean() meansq.backward()