Fix to LinearWithAuxLoss for bias=False case

Daniel Povey 2022-11-26 14:16:53 +08:00
parent 5f80807027
commit 7b5c0382f9


@@ -491,10 +491,15 @@ class LinearWithAuxLoss(nn.Module):
         aux_grad_scale = float(self.aux_grad_scale)
         if (not self.training or torch.jit.is_scripting() or
             aux_grad_scale == 0.0 or random.random() > float(self.prob)):
-            return torch.matmul(x, self.weight.t()) + self.bias
+            ans = torch.matmul(x, self.weight.t())
         else:
-            return LinearWithAuxLossFunction.apply(x, self.weight, self.alpha,
-                                                   aux_grad_scale) + self.bias
+            ans = LinearWithAuxLossFunction.apply(x, self.weight, self.alpha,
+                                                  aux_grad_scale)
+        if self.bias is None:
+            return ans
+        else:
+            return ans + self.bias
+

 def ScaledLinear(*args,
                  initial_scale: float = 1.0,
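
A note on why the guard is needed: like nn.Linear, this module stores
self.bias = None when constructed with bias=False, and adding None to a
tensor raises a TypeError, so the unconditional "+ self.bias" broke the
bias=False case. Below is a minimal runnable sketch of the fixed pattern;
TinyLinear is a hypothetical stand-in, not the icefall class itself, and
it omits the auxiliary-loss branch.

import torch
import torch.nn as nn

class TinyLinear(nn.Module):
    # Hypothetical stand-in mirroring LinearWithAuxLoss's bias handling.
    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        # As in nn.Linear, bias=False leaves self.bias as None.
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ans = torch.matmul(x, self.weight.t())
        # The fix: only add the bias when it exists.
        if self.bias is None:
            return ans
        else:
            return ans + self.bias

x = torch.randn(2, 4)
print(TinyLinear(4, 3, bias=True)(x).shape)   # torch.Size([2, 3])
print(TinyLinear(4, 3, bias=False)(x).shape)  # torch.Size([2, 3]); the old code would raise TypeError here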