diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 8b136bf10..a70450aa4 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -389,7 +389,7 @@ class BasicNorm(torch.nn.Module):

 class LinearWithAuxLossFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x: Tensor, weight: Tensor, alpha: Tensor,
+    def forward(ctx, x: Tensor, weight: Tensor,
                 aux_grad_scale: float) -> Tensor:
         """
         Returns matmul(x, weight.t()).
@@ -398,14 +398,14 @@ class LinearWithAuxLossFunction(torch.autograd.Function):
         """
         if torch.is_autocast_enabled():
             x = x.to(torch.float16)
-        ctx.save_for_backward(x, weight, alpha)
+        ctx.save_for_backward(x, weight)
         ctx.aux_grad_scale = aux_grad_scale
         return torch.matmul(x, weight.t())

     @staticmethod
     def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, Tensor, Tensor, None]:
-        x, weight, alpha = ctx.saved_tensors
+        x, weight = ctx.saved_tensors
         x_grad = torch.matmul(ans_grad, weight.to(ans_grad.dtype))
         weight_grad = torch.matmul(ans_grad.reshape(-1, ans_grad.shape[-1]).t(),
@@ -415,14 +415,15 @@ class LinearWithAuxLossFunction(torch.autograd.Function):
         with torch.cuda.amp.autocast(enabled=False):
             with torch.enable_grad():
                 x = x.to(weight.dtype)
-                x, weight, alpha = x.detach(), weight.detach(), alpha.detach()
+                x, weight = x.detach(), weight.detach()
                 weight.requires_grad = True
-                alpha.requires_grad = True
                 # recompute y as we need the gradient; this is easier to implement than
                 # saving y in the context.
                 y = torch.matmul(x, weight.t())
-                z = alpha.exp() * torch.matmul(y, weight)
-                diff = x - z
+                z = torch.matmul(y, weight)
+                with torch.no_grad():
+                    alpha = (x * z).sum() / ((z * z).sum() + 1.0e-20)
+                diff = x - alpha * z
                 dims_to_mean = tuple(range(x.ndim-1))
                 mean = diff.mean(dim=dims_to_mean)
                 diff = diff - mean # subtract mean.
@@ -430,7 +431,6 @@ class LinearWithAuxLossFunction(torch.autograd.Function):
                 meansq = (diff ** 2).mean()
                 meansq.backward()
                 weight_aux_grad = weight.grad
-                alpha_grad = alpha.grad

         with torch.cuda.amp.autocast(enabled=False):
@@ -439,7 +439,7 @@ class LinearWithAuxLossFunction(torch.autograd.Function):
             weight_grad_scale = ctx.aux_grad_scale * weight_grad_norm / (aux_grad_norm + 1.0e-20)
             weight_grad = weight_grad + (weight_grad_scale * weight_aux_grad).to(weight_grad.dtype)

-        return x_grad, weight_grad, alpha_grad, None
+        return x_grad, weight_grad, None
@@ -485,7 +485,6 @@ class LinearWithAuxLoss(nn.Module):
                                      0.01 * initial_scale)
         else:
             self.register_parameter('bias', None)
-        self.alpha = nn.Parameter(torch.tensor(0.0))

     def forward(self,
@@ -495,7 +494,7 @@ class LinearWithAuxLoss(nn.Module):
             aux_grad_scale == 0.0 or random.random() > float(self.prob)):
             return torch.nn.functional.linear(x, self.weight, self.bias)
         else:
-            ans = LinearWithAuxLossFunction.apply(x, self.weight, self.alpha,
+            ans = LinearWithAuxLossFunction.apply(x, self.weight,
                                                   aux_grad_scale)
            if self.bias is not None:
                ans += self.bias
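
For context, a minimal standalone sketch (not part of the patch; the tensor shapes and the helper name recon_err are made up for illustration, and the mean-subtraction step of the aux loss is omitted for brevity). It checks that the closed-form scale introduced above, alpha = (x * z).sum() / ((z * z).sum() + eps), is the least-squares minimizer of ((x - alpha * z) ** 2).mean(), so it plays the role of the learned alpha parameter that the patch removes and no alpha gradient needs to be returned from backward().

    import torch

    torch.manual_seed(0)
    x = torch.randn(100, 256)        # activations entering the layer (illustrative shape)
    weight = torch.randn(64, 256)    # weight of shape (out_features, in_features)

    y = torch.matmul(x, weight.t())  # what forward() returns
    z = torch.matmul(y, weight)      # reconstruction of x through the weight

    # closed-form scale, as computed in the patched backward()
    alpha = (x * z).sum() / ((z * z).sum() + 1.0e-20)

    def recon_err(a: torch.Tensor) -> torch.Tensor:
        # squared reconstruction error as a function of the scale
        return ((x - a * z) ** 2).mean()

    # alpha should beat any nearby scale ...
    assert recon_err(alpha) <= recon_err(alpha * 1.01)
    assert recon_err(alpha) <= recon_err(alpha * 0.99)

    # ... equivalently, the gradient of the error w.r.t. the scale is close to
    # zero there (up to float rounding), which is why a learned alpha and its
    # gradient can be dropped.
    a = alpha.clone().requires_grad_(True)
    recon_err(a).backward()
    print(a.grad)  # ~0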