Mirror of https://github.com/k2-fsa/icefall.git
Remove some unused variables.

commit b9f6ba1aa2
parent 78f3cba58c
@@ -201,7 +201,6 @@ def random_cast_to_half(x: Tensor,
     """
     if x.dtype == torch.float16:
         return x
-    x_sign = x.sign()
     x_abs = x.abs()
     is_too_small = (x_abs < min_abs)
     # for elements where is_too_small is true, random_val will contain +-min_abs with
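For context, here is a minimal sketch of what random_cast_to_half plausibly looks like after this change: the unused x_sign binding is gone, and x.sign() is used inline instead. The early return, x_abs, and is_too_small lines come from the hunk above; the default value of min_abs and the random_val line are assumptions reconstructed from the truncated comment, since the diff cuts off before them.

import torch
from torch import Tensor

def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor:
    # Randomized cast to float16 that preserves expectations for values
    # whose magnitude is below min_abs, instead of flushing them to zero.
    if x.dtype == torch.float16:
        return x
    x_abs = x.abs()
    is_too_small = (x_abs < min_abs)
    # For elements where is_too_small is true, random_val contains +-min_abs
    # with probability x.abs() / min_abs, and 0.0 otherwise, so the expected
    # value equals x. (Assumption: reconstructed from the truncated comment
    # above; this line is not shown in the diff.)
    random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs)
    return torch.where(is_too_small, random_val, x).to(torch.float16)

For x_abs below min_abs, the comparison fires with probability x_abs / min_abs, so the expected value of random_val is sign(x) * min_abs * (x_abs / min_abs) = x, which is the point of the randomized rounding.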
@@ -223,7 +222,6 @@ class RandomGradFunction(torch.autograd.Function):
 
     @staticmethod
     def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]:
-        min_abs = ctx.min_abs
         if ans_grad.dtype == torch.float16:
             return random_cast_to_half(ans_grad.to(torch.float32),
                                        min_abs=ctx.min_abs), None
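The second hunk removes the same kind of dead binding: backward reads ctx.min_abs directly at the call site, so the local min_abs was never used. A sketch of the surrounding autograd Function follows, assuming the random_cast_to_half sketch above; the forward method and the non-float16 return path are assumptions, since the diff shows only the start of backward.

from typing import Tuple

import torch
from torch import Tensor

class RandomGradFunction(torch.autograd.Function):
    # Identity in the forward pass; in the backward pass, float16 gradients
    # are upcast to float32 and randomly re-cast to float16 so that tiny
    # gradient components survive in expectation.
    @staticmethod
    def forward(ctx, x: Tensor, min_abs: float) -> Tensor:
        # Assumption: forward is not shown in the diff; presumably it just
        # stashes min_abs on the context and passes x through unchanged.
        ctx.min_abs = min_abs
        return x

    @staticmethod
    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]:
        if ans_grad.dtype == torch.float16:
            # Read ctx.min_abs at the call site; no local binding needed.
            return random_cast_to_half(ans_grad.to(torch.float32),
                                       min_abs=ctx.min_abs), None
        # Assumption: non-float16 gradients pass through unchanged.
        return ans_grad, None

Usage would look like y = RandomGradFunction.apply(x, 5.0e-06). The second element of backward's return is None because min_abs is a non-tensor argument to forward and therefore receives no gradient.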