diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py
index 3c7e0fa4e..164cc7bfd 100644
--- a/egs/librispeech/ASR/zipformer/scaling.py
+++ b/egs/librispeech/ASR/zipformer/scaling.py
@@ -1033,7 +1033,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
                         w.prob = w.max_prob
                         metric.backward()
                         penalty_grad = x_detached.grad
-                        scale = w.grad_scale * (
+                        scale = float(w.grad_scale) * (
                             x_grad.to(torch.float32).norm()
                             / (penalty_grad.norm() + 1.0e-20)
                         )
@@ -1075,7 +1075,7 @@ class Whiten(nn.Module):
         super(Whiten, self).__init__()
         assert num_groups >= 1
         assert float(whitening_limit) >= 1
-        assert grad_scale >= 0
+        assert float(grad_scale) >= 0
         self.num_groups = num_groups
         self.whitening_limit = whitening_limit
         self.grad_scale = grad_scale
diff --git a/egs/librispeech/ASR/zipformer_lora/scaling.py b/egs/librispeech/ASR/zipformer_lora/scaling.py
index 3149db9f3..8d7aa8027 100644
--- a/egs/librispeech/ASR/zipformer_lora/scaling.py
+++ b/egs/librispeech/ASR/zipformer_lora/scaling.py
@@ -1137,7 +1137,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
                         w.prob = w.max_prob
                         metric.backward()
                         penalty_grad = x_detached.grad
-                        scale = w.grad_scale * (
+                        scale = float(w.grad_scale) * (
                             x_grad.to(torch.float32).norm()
                             / (penalty_grad.norm() + 1.0e-20)
                         )
@@ -1179,7 +1179,7 @@ class Whiten(nn.Module):
         super(Whiten, self).__init__()
         assert num_groups >= 1
         assert float(whitening_limit) >= 1
-        assert grad_scale >= 0
+        assert float(grad_scale) >= 0
         self.num_groups = num_groups
         self.whitening_limit = whitening_limit
         self.grad_scale = grad_scale
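
Why the cast matters: in icefall's scaling.py, grad_scale is annotated as a FloatLike, so it may arrive as a ScheduledFloat (whose current value is only reachable through __float__) rather than a plain Python float. Comparing or multiplying such an object directly can raise a TypeError; wrapping it in float() evaluates it to a number first, which is what both hunks do. A minimal sketch of the failure mode, using a hypothetical ScheduledFloatStub stand-in rather than the real ScheduledFloat class:

import torch

class ScheduledFloatStub:
    # Hypothetical stand-in for a FloatLike such as icefall's ScheduledFloat:
    # the scheduled value is only reachable through __float__, and the object
    # defines no comparison or arithmetic operators of its own.
    def __init__(self, value: float):
        self.value = value

    def __float__(self) -> float:
        return self.value

grad_scale = ScheduledFloatStub(0.02)

# assert grad_scale >= 0        # TypeError: '>=' not supported
assert float(grad_scale) >= 0   # evaluate to a plain float first

x_grad = torch.randn(4, 8)
penalty_grad = torch.randn(4, 8)

# scale = grad_scale * (...)    # TypeError: the stub defines no __mul__
scale = float(grad_scale) * (
    x_grad.to(torch.float32).norm() / (penalty_grad.norm() + 1.0e-20)
)
print(float(scale))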