From 7b5c0382f9d52e7178abe44a576f4cb09668bb24 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sat, 26 Nov 2022 14:16:53 +0800
Subject: [PATCH] Fix to LinearWithAuxLoss for bias=False case

---
 .../ASR/pruned_transducer_stateless7/scaling.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index eb556b4e9..84c408c12 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -491,10 +491,15 @@ class LinearWithAuxLoss(nn.Module):
         aux_grad_scale = float(self.aux_grad_scale)
         if (not self.training or torch.jit.is_scripting() or
                 aux_grad_scale == 0.0 or random.random() > float(self.prob)):
-            return torch.matmul(x, self.weight.t()) + self.bias
+            ans = torch.matmul(x, self.weight.t())
         else:
-            return LinearWithAuxLossFunction.apply(x, self.weight, self.alpha,
-                                                   aux_grad_scale) + self.bias
+            ans = LinearWithAuxLossFunction.apply(x, self.weight, self.alpha,
+                                                  aux_grad_scale)
+        if self.bias is None:
+            return ans
+        else:
+            return ans + self.bias
+


 def ScaledLinear(*args, initial_scale: float = 1.0,