diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index eb556b4e9..84c408c12 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -491,10 +491,15 @@ class LinearWithAuxLoss(nn.Module):
         aux_grad_scale = float(self.aux_grad_scale)
         if (not self.training or torch.jit.is_scripting() or
             aux_grad_scale == 0.0 or random.random() > float(self.prob)):
-            return torch.matmul(x, self.weight.t()) + self.bias
+            ans = torch.matmul(x, self.weight.t())
         else:
-            return LinearWithAuxLossFunction.apply(x, self.weight, self.alpha,
-                                                   aux_grad_scale) + self.bias
+            ans = LinearWithAuxLossFunction.apply(x, self.weight, self.alpha,
+                                                  aux_grad_scale)
+        if self.bias is None:
+            return ans
+        else:
+            return ans + self.bias
+


 def ScaledLinear(*args, initial_scale: float = 1.0,