diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 22fc2608a..10069a0ac 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -274,33 +274,43 @@ class SoftmaxFunction(torch.autograd.Function):
     @staticmethod
     def backward(ctx, ans_grad: Tensor):
         ans, = ctx.saved_tensors
-        try:
+
+        def get_x_grad(y, y_grad, dim):
             with torch.cuda.amp.autocast(enabled=False):
-                if ans.dtype != torch.float32:
-                    ans = ans.to(torch.float32)
-                    x_grad = ans.mul_(ans_grad.to(torch.float32))
-                else:
-                    # out-of-place since it's not a copy
-                    x_grad = ans_grad.to(torch.float32) * ans
-                ans *= x_grad.sum(dim=ctx.dim, keepdim=True)
-                x_grad -= ans
-                return x_grad, None
+                y_grad = y_grad.to(torch.float32)
+                y = y.to(torch.float32)
+                x_grad = y_grad * y
+                y *= x_grad.sum(dim=dim, keepdim=True)
+                x_grad -= y
+                return x_grad
+
+
+
+        try:
+            if __name__ == '__main__':
+                raise RuntimeError("For testing")
+            x_grad = get_x_grad(ans, ans_grad, ctx.dim)
+            return x_grad, None
         except Exception as e:
-            logging.info(f"Caught exception in SoftmaxFunction backward: {e}, size={list(ans.shape)}, dim={ctx.dim}, will try in half precision.")
-            x_grad = None
+            logging.info(f"Caught exception in SoftmaxFunction backward: {e}, size={list(ans.shape)}, dim={ctx.dim}, will try another method.")
+            dim = ctx.dim
+            if dim < 0:
+                dim = dim + ans.ndim
+            split_dim = 0 if dim != 0 else 1
+            # split_dim is the dimension we split up ans on.
+            num_split = min(8, ans.shape[split_dim])
+            x_grad_split = [
+                get_x_grad(ans_part, ans_grad_part, ctx.dim) for ans_part, ans_grad_part in
+                zip(torch.split(ans, num_split, dim=split_dim), torch.split(ans_grad, num_split, dim=split_dim))
+            ]
+            x_grad = torch.cat(x_grad_split, dim=split_dim)

-            ans, = ctx.saved_tensors
-            ans_grad.mul_(ans)
-            x_grad = ans_grad
-            ans *= x_grad.sum(dim=ctx.dim, keepdim=True)
-            x_grad -= ans
             return x_grad, None

-
 def softmax(x: Tensor, dim: int):
     return SoftmaxFunction.apply(x, dim)
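
Note (not part of the patch): below is a minimal standalone sketch of the fallback idea in the hunk above, i.e. computing the softmax backward chunk by chunk along a dimension other than the softmax dimension and concatenating the per-chunk gradients. Because the softmax normalization only couples elements along ctx.dim, splitting on any other dimension leaves each chunk's gradient identical to the corresponding slice of the full gradient. The names softmax_backward and chunk_size and the test shapes are illustrative, not taken from scaling.py.

import torch


def softmax_backward(y: torch.Tensor, y_grad: torch.Tensor, dim: int) -> torch.Tensor:
    # Gradient of the loss w.r.t. the softmax input x, given y = softmax(x, dim)
    # and y_grad = dloss/dy:  dloss/dx = y * (y_grad - sum(y * y_grad, dim)).
    y = y.to(torch.float32)
    y_grad = y_grad.to(torch.float32)
    x_grad = y_grad * y
    return x_grad - y * x_grad.sum(dim=dim, keepdim=True)


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(6, 5, 4, requires_grad=True)
    dim = -1                      # softmax dimension
    y = x.softmax(dim=dim)
    y_grad = torch.randn_like(y)  # some upstream gradient

    full = softmax_backward(y, y_grad, dim)

    # Split along a dimension other than the softmax dimension (here dim 0),
    # mirroring split_dim in the patch, and concatenate the chunk gradients.
    split_dim, chunk_size = 0, 2
    chunks = [
        softmax_backward(y_part, g_part, dim)
        for y_part, g_part in zip(
            torch.split(y, chunk_size, dim=split_dim),
            torch.split(y_grad, chunk_size, dim=split_dim),
        )
    ]
    chunked = torch.cat(chunks, dim=split_dim)
    assert torch.allclose(full, chunked, atol=1e-6)

    # Cross-check against autograd using the same upstream gradient.
    (x_grad_ref,) = torch.autograd.grad(y, x, grad_outputs=y_grad)
    assert torch.allclose(full, x_grad_ref, atol=1e-6)
    print("chunked softmax backward matches the full computation")

The sketch also shows why the patch passes ctx.dim unchanged to get_x_grad for each chunk: the softmax dimension is untouched by the split, so no index adjustment is needed.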