diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index 9f6180eab..90acc99f7 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -34,7 +34,8 @@ from scaling import (
     Whiten,
     Identity,
     _diag,
-    random_clamp
+    random_clamp,
+    softmax,
 )
 
 from torch import Tensor, nn
@@ -1148,7 +1149,7 @@ class RelPositionMultiheadAttention(nn.Module):
             bsz * num_heads, seq_len, seq_len
         )
 
-        attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
+        attn_output_weights = softmax(attn_output_weights, dim=-1)
         attn_output_weights = nn.functional.dropout(
             attn_output_weights, p=dropout_p, training=training
         )
@@ -1561,7 +1562,7 @@ class AttentionCombine(nn.Module):
                                       single_prob_mask)
 
             weights = weights.masked_fill(mask, float('-inf'))
-        weights = weights.softmax(dim=1)
+        weights = softmax(weights, dim=1)
 
         # (num_frames, num_channels, num_inputs) * (num_frames, num_inputs, 1) -> (num_frames, num_channels, 1),
         ans = torch.matmul(stacked_inputs, weights.unsqueeze(2))
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 87626b780..d8a380b9d 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -185,6 +185,57 @@ def random_clamp(x: Tensor,
     return RandomClampFunction.apply(x, min, max, prob)
 
 
+def random_cast_to_half(x: Tensor,
+                        min_abs: float = 1.0e-03) -> Tensor:
+    """
+    A randomized way of casting a floating point value to half precision.
+    """
+    if x.dtype == torch.float16:
+        return x
+    x_sign = x.sign()
+    x_abs = x.abs()
+    is_too_small = (x_abs < min_abs)
+    # for elements where is_too_small is true, random_val will contain +-min_abs with
+    # probability (x.abs() / min_abs), and 0.0 otherwise.  [so this preserves expectations,
+    # for those elements].
+    random_val = min_abs * x_sign * (torch.rand_like(x) * min_abs < x_abs)
+    return torch.where(is_too_small, random_val, x).to(torch.float16)
+
+
+class SoftmaxFunction(torch.autograd.Function):
+    """
+    Tries to handle half-precision derivatives in a randomized way that should
+    be more accurate for training than the default behavior.
+    """
+    @staticmethod
+    def forward(ctx, x: Tensor, dim: int):
+        ans = x.softmax(dim=dim)
+        ctx.save_for_backward(ans)
+        ctx.dim = dim
+        return ans
+
+    @staticmethod
+    def backward(ctx, ans_grad: Tensor):
+        ans, = ctx.saved_tensors
+
+        if ans.dtype == torch.float16 or ans_grad.dtype == torch.float16:
+            # use a randomized approach to convert to float16
+            with torch.cuda.amp.autocast(enabled=False):
+                ans_grad = ans_grad.to(torch.float32)
+                ans = ans.to(torch.float32)
+                x_grad = ans_grad * ans
+                x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
+            return random_cast_to_half(x_grad), None
+        else:
+            x_grad = ans_grad * ans
+            x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
+            return x_grad, None
+
+
+def softmax(x: Tensor,
+            dim: int):
+    return SoftmaxFunction.apply(x, dim)
+
 
 class MaxEigLimiterFunction(torch.autograd.Function):
     @staticmethod
@@ -942,11 +993,24 @@ def _test_double_swish_deriv():
     torch.autograd.gradcheck(m, x)
 
 
+def _test_softmax():
+    a = torch.randn(2, 10, dtype=torch.float64)
+    b = a.clone()
+    a.requires_grad = True
+    b.requires_grad = True
+    a.softmax(dim=1)[:, 0].sum().backward()
+    print("a grad = ", a.grad)
+    softmax(b, dim=1)[:, 0].sum().backward()
+    print("b grad = ", b.grad)
+    assert torch.allclose(a.grad, b.grad)
+
+
 if __name__ == "__main__":
     logging.getLogger().setLevel(logging.INFO)
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
+    _test_softmax()
     _test_whiten()
     _test_max_eig()
     _test_activation_balancer_sign()