Changes for speed

This commit is contained in:
Daniel Povey 2022-11-26 14:35:26 +08:00
parent c653c66413
commit 110c2601ab

View File

@ -493,14 +493,13 @@ class LinearWithAuxLoss(nn.Module):
aux_grad_scale = float(self.aux_grad_scale)
if (not self.training or torch.jit.is_scripting() or
aux_grad_scale == 0.0 or random.random() > float(self.prob)):
ans = torch.matmul(x, self.weight.t())
return torch.nn.functional.linear(x, self.weight, self.bias)
else:
ans = LinearWithAuxLossFunction.apply(x, self.weight, self.alpha,
aux_grad_scale)
if self.bias is None:
if self.bias is not None:
ans += self.bias
return ans
else:
return ans + self.bias
def ScaledLinear(*args,