From 19048e155b5a07b4dd4d1795815a1a0cd4584e25 Mon Sep 17 00:00:00 2001 From: Teo Wen Shen <36886809+teowenshen@users.noreply.github.com> Date: Thu, 11 Jul 2024 16:12:30 +0900 Subject: [PATCH] Cast grad_scale in whiten to float (#1663) * cast grad_scale in whiten to float * fix cast in zipformer_lora --- egs/librispeech/ASR/zipformer/scaling.py | 4 ++-- egs/librispeech/ASR/zipformer_lora/scaling.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 3c7e0fa4e..164cc7bfd 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1033,7 +1033,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function): w.prob = w.max_prob metric.backward() penalty_grad = x_detached.grad - scale = w.grad_scale * ( + scale = float(w.grad_scale) * ( x_grad.to(torch.float32).norm() / (penalty_grad.norm() + 1.0e-20) ) @@ -1075,7 +1075,7 @@ class Whiten(nn.Module): super(Whiten, self).__init__() assert num_groups >= 1 assert float(whitening_limit) >= 1 - assert grad_scale >= 0 + assert float(grad_scale) >= 0 self.num_groups = num_groups self.whitening_limit = whitening_limit self.grad_scale = grad_scale diff --git a/egs/librispeech/ASR/zipformer_lora/scaling.py b/egs/librispeech/ASR/zipformer_lora/scaling.py index 3149db9f3..8d7aa8027 100644 --- a/egs/librispeech/ASR/zipformer_lora/scaling.py +++ b/egs/librispeech/ASR/zipformer_lora/scaling.py @@ -1137,7 +1137,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function): w.prob = w.max_prob metric.backward() penalty_grad = x_detached.grad - scale = w.grad_scale * ( + scale = float(w.grad_scale) * ( x_grad.to(torch.float32).norm() / (penalty_grad.norm() + 1.0e-20) ) @@ -1179,7 +1179,7 @@ class Whiten(nn.Module): super(Whiten, self).__init__() assert num_groups >= 1 assert float(whitening_limit) >= 1 - assert grad_scale >= 0 + assert float(grad_scale) >= 0 self.num_groups = num_groups self.whitening_limit = whitening_limit self.grad_scale = grad_scale