diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index a2200c04b..4ea5e0558 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -201,7 +201,6 @@ def random_cast_to_half(x: Tensor, """ if x.dtype == torch.float16: return x - x_sign = x.sign() x_abs = x.abs() is_too_small = (x_abs < min_abs) # for elements where is_too_small is true, random_val will contain +-min_abs with @@ -223,7 +222,6 @@ class RandomGradFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]: - min_abs = ctx.min_abs if ans_grad.dtype == torch.float16: return random_cast_to_half(ans_grad.to(torch.float32), min_abs=ctx.min_abs), None