From eb9800e1218f2adc00906ab1abeef177e1a267bb Mon Sep 17 00:00:00 2001
From: marcoyang
Date: Wed, 24 Jul 2024 15:05:13 +0800
Subject: [PATCH] minor updates

---
 egs/librispeech/ASR/zipformer/scaling_bf16.py | 7 -------
 1 file changed, 7 deletions(-)

diff --git a/egs/librispeech/ASR/zipformer/scaling_bf16.py b/egs/librispeech/ASR/zipformer/scaling_bf16.py
index 4b82c3acd..77b431203 100644
--- a/egs/librispeech/ASR/zipformer/scaling_bf16.py
+++ b/egs/librispeech/ASR/zipformer/scaling_bf16.py
@@ -786,17 +786,10 @@ class BalancerFunction(torch.autograd.Function):
 
                 loss_grad = loss_grad * (grad_scale / loss_grad_rms)
 
-                # if x_grad.dtype == torch.float16:
-                #     x_grad_float = x_grad.to(torch.float32)
-                # else:
-                #     x_grad_float = x_grad
-
                 # scale each element of loss_grad by the absolute value of the corresponding
                 # element of x_grad, which we view as a noisy estimate of its magnitude for that
                 # (frame and dimension). later we can consider factored versions.
-                # x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad)
                 x_grad = x_grad + (x_grad.abs() * loss_grad)
-                # x_grad = x_grad_mod.to(x_grad.dtype)
             except Exception as e:
                 logging.info(
                     f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue."