From 110c2601abe352c400ba66eb6fd85c74499c197a Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sat, 26 Nov 2022 14:35:26 +0800
Subject: [PATCH] Changes for speed

---
 .../ASR/pruned_transducer_stateless7/scaling.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index acc0defa6..db341a1c9 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -493,14 +493,13 @@ class LinearWithAuxLoss(nn.Module):
         aux_grad_scale = float(self.aux_grad_scale)
         if (not self.training or torch.jit.is_scripting() or aux_grad_scale == 0.0
                 or random.random() > float(self.prob)):
-            ans = torch.matmul(x, self.weight.t())
+            return torch.nn.functional.linear(x, self.weight, self.bias)
         else:
             ans = LinearWithAuxLossFunction.apply(x, self.weight, self.alpha,
                                                   aux_grad_scale)
-        if self.bias is None:
+        if self.bias is not None:
+            ans += self.bias
             return ans
-        else:
-            return ans + self.bias


 def ScaledLinear(*args,