diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index e7c3f4ab1..60a348010 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1032,7 +1032,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function): w.prob = w.max_prob metric.backward() penalty_grad = x_detached.grad - scale = w.grad_scale * ( + scale = float(w.grad_scale) * ( x_grad.to(torch.float32).norm() / (penalty_grad.norm() + 1.0e-20) ) @@ -1074,7 +1074,7 @@ class Whiten(nn.Module): super(Whiten, self).__init__() assert num_groups >= 1 assert float(whitening_limit) >= 1 - assert grad_scale >= 0 + assert float(grad_scale) >= 0 self.num_groups = num_groups self.whitening_limit = whitening_limit self.grad_scale = grad_scale