mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 10:16:14 +00:00
minor changes
This commit is contained in:
parent
a7854dddba
commit
af67140ad2
@ -1008,7 +1008,6 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
|
||||
try:
|
||||
with torch.enable_grad():
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
dtype = x_orig.dtype
|
||||
x_detached = x_orig.detach()
|
||||
x_detached.requires_grad = True
|
||||
|
||||
@ -1028,7 +1027,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
|
||||
metric.backward()
|
||||
penalty_grad = x_detached.grad
|
||||
scale = float(w.grad_scale) * (
|
||||
x_grad.to(dtype).norm()
|
||||
x_grad.to(x_orig.dtype).norm()
|
||||
/ (penalty_grad.norm() + 1.0e-20)
|
||||
)
|
||||
penalty_grad = penalty_grad * scale
|
||||
@ -1391,10 +1390,8 @@ class SwooshL(torch.nn.Module):
|
||||
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
|
||||
return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
|
||||
if not x.requires_grad:
|
||||
# return k2.swoosh_l_forward(x)
|
||||
return SwooshLForward(x)
|
||||
else:
|
||||
# return k2.swoosh_l(x)
|
||||
return SwooshLFunction.apply(x) # this support bf16
|
||||
|
||||
|
||||
@ -1895,11 +1892,32 @@ def _test_activation_dropout_and_linear():
|
||||
# storage of it.
|
||||
assert isclose(x1.grad, x2.grad)
|
||||
|
||||
def _test_swoosh_bf16():
    """Compare SwooshLFunction outputs for bfloat16 and float32 inputs.

    Runs ``SwooshLFunction.apply`` on a random bfloat16 tensor and on the
    same values upcast to float32, then prints two summed absolute errors:

    * ``diff_1``: float32 output versus bfloat16 output.
    * ``diff_2``: float32 output versus that same output rounded to
      bfloat16, i.e. the error attributable to the output cast alone.

    The numbers are only printed, not asserted; this is a manual sanity
    check rather than a pass/fail test.
    """
    x_bf16 = torch.randn(1, 100, 128).to(torch.bfloat16)
    # Upcast the already-rounded bf16 values so that both runs start from
    # exactly the same input values.
    x_fp32 = x_bf16.clone().to(torch.float32)

    # bf16 version
    y_bf16 = SwooshLFunction.apply(x_bf16)

    # fp32 version
    y_fp32 = SwooshLFunction.apply(x_fp32)

    # NOTE: the interactive `pdb.set_trace()` breakpoint that used to sit
    # here was removed; it halted every run of this module's __main__.
    diff_1 = torch.abs(y_fp32 - y_bf16).sum()
    diff_2 = torch.abs(y_fp32 - y_fp32.to(torch.bfloat16)).sum()

    print(diff_1, diff_2)
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
_test_swoosh_bf16()
|
||||
_test_piecewise_linear()
|
||||
_test_softmax()
|
||||
_test_whiten()
|
||||
|
Loading…
x
Reference in New Issue
Block a user