diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 49e204de1..f685bf112 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -422,6 +422,7 @@ class LinearWithAuxLossFunction(torch.autograd.Function): y = torch.matmul(x, weight.t()) z = torch.matmul(y, weight) # subtract mean + dims_to_mean = tuple(range(x.ndim-1)) x = x - x.mean(dim=dims_to_mean) z = z - z.mean(dim=dims_to_mean) # compute optimal scale on z