From af67140ad2e71fe6239208d5c53f6332849f1243 Mon Sep 17 00:00:00 2001 From: marcoyang Date: Thu, 8 Aug 2024 10:54:44 +0800 Subject: [PATCH] minor changes --- egs/librispeech/ASR/zipformer/scaling_bf16.py | 26 ++++++++++++++++--- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling_bf16.py b/egs/librispeech/ASR/zipformer/scaling_bf16.py index 77b431203..a0ddc6b8a 100644 --- a/egs/librispeech/ASR/zipformer/scaling_bf16.py +++ b/egs/librispeech/ASR/zipformer/scaling_bf16.py @@ -1008,7 +1008,6 @@ class WhiteningPenaltyFunction(torch.autograd.Function): try: with torch.enable_grad(): with torch.cuda.amp.autocast(enabled=False): - dtype = x_orig.dtype x_detached = x_orig.detach() x_detached.requires_grad = True @@ -1028,7 +1027,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function): metric.backward() penalty_grad = x_detached.grad scale = float(w.grad_scale) * ( - x_grad.to(dtype).norm() + x_grad.to(x_orig.dtype).norm() / (penalty_grad.norm() + 1.0e-20) ) penalty_grad = penalty_grad * scale @@ -1391,10 +1390,8 @@ class SwooshL(torch.nn.Module): zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035 if not x.requires_grad: - # return k2.swoosh_l_forward(x) return SwooshLForward(x) else: - # return k2.swoosh_l(x) return SwooshLFunction.apply(x) # this support bf16 @@ -1895,11 +1892,32 @@ def _test_activation_dropout_and_linear(): # storage of it. 
assert isclose(x1.grad, x2.grad)


+def _test_swoosh_bf16():
+    """Compare SwooshLFunction outputs for bf16 and fp32 copies of one input.
+
+    Prints two summed absolute differences: fp32 output vs. bf16 output, and
+    fp32 output vs. the same fp32 output rounded to bf16.
+    """
+    x_bf16 = torch.randn(1, 100, 128).to(torch.bfloat16)
+    x_fp32 = x_bf16.clone().to(torch.float32)
+
+    # Same values in both tensors, so only the dtype of the input differs.
+    y_bf16 = SwooshLFunction.apply(x_bf16)
+    y_fp32 = SwooshLFunction.apply(x_fp32)
+
+    # diff_2 is the error from rounding the fp32 result to bf16 alone.
+    diff_1 = torch.abs(y_fp32 - y_bf16).sum()
+    diff_2 = torch.abs(y_fp32 - y_fp32.to(torch.bfloat16)).sum()
+
+    print(diff_1, diff_2)
+
+
 if __name__ == "__main__":
     logging.getLogger().setLevel(logging.INFO)
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
+    _test_swoosh_bf16()
     _test_piecewise_linear()
     _test_softmax()
     _test_whiten()