mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00

Commit af67140ad2 (parent a7854dddba): minor changes
@@ -1008,7 +1008,6 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
         try:
             with torch.enable_grad():
                 with torch.cuda.amp.autocast(enabled=False):
-                    dtype = x_orig.dtype
                     x_detached = x_orig.detach()
                     x_detached.requires_grad = True
@@ -1028,7 +1027,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
                     metric.backward()
                     penalty_grad = x_detached.grad
                     scale = float(w.grad_scale) * (
-                        x_grad.to(dtype).norm()
+                        x_grad.to(x_orig.dtype).norm()
                         / (penalty_grad.norm() + 1.0e-20)
                     )
                     penalty_grad = penalty_grad * scale
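The two hunks above are a small cleanup: the `dtype` local captured before the autocast block is dropped, and the scale computation reads `x_orig.dtype` directly when casting the incoming gradient before taking its norm. A minimal stand-alone sketch of that scaling step is below; the helper name and the example `grad_scale` value are illustrative, not part of the diff.

import torch

# Sketch of the gradient-scaling step in WhiteningPenaltyFunction.backward:
# the whitening-penalty gradient is rescaled so that its norm is a fixed
# fraction (grad_scale) of the norm of the incoming gradient, which is first
# cast to x_orig's dtype (the change shown in the hunk above).
def scale_penalty_grad(x_grad, penalty_grad, x_orig, grad_scale=0.02):
    # grad_scale=0.02 is only an example default, not taken from the diff.
    scale = float(grad_scale) * (
        x_grad.to(x_orig.dtype).norm() / (penalty_grad.norm() + 1.0e-20)
    )
    return penalty_grad * scale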
@@ -1391,10 +1390,8 @@ class SwooshL(torch.nn.Module):
             zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
             return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
         if not x.requires_grad:
-            # return k2.swoosh_l_forward(x)
             return SwooshLForward(x)
         else:
-            # return k2.swoosh_l(x)
             return SwooshLFunction.apply(x)  # this support bf16
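For reference, the activation computed here is SwooshL(x) = log(1 + exp(x - 4)) - 0.08 * x - 0.035, exactly as in the scripted branch of the hunk above. A self-contained sketch of just that formula follows; it uses torch.logaddexp and does not reproduce the custom bf16-aware autograd of SwooshLFunction.

import torch

# Reference formula for Swoosh-L, matching the scripted branch of
# SwooshL.forward shown above; logaddexp keeps it numerically stable.
def swoosh_l_reference(x: torch.Tensor) -> torch.Tensor:
    zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
    return torch.logaddexp(zero, x - 4.0) - 0.08 * x - 0.035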
@@ -1895,11 +1892,32 @@ def _test_activation_dropout_and_linear():
     # storage of it.
     assert isclose(x1.grad, x2.grad)


+def _test_swoosh_bf16():
+
+    x_bf16 = torch.randn(1,100,128).to(torch.bfloat16)
+    x_fp32 = x_bf16.clone().to(torch.float32)
+
+    # test bf16 version
+    y_bf16 = SwooshLFunction.apply(x_bf16)
+
+    # test fp32 version
+    y_fp32 = SwooshLFunction.apply(x_fp32)
+
+    import pdb; pdb.set_trace()
+
+    diff_1 = torch.abs(y_fp32 - y_bf16).sum()
+    diff_2 = torch.abs(y_fp32 - y_fp32.to(torch.bfloat16)).sum()
+
+    print(diff_1, diff_2)
+
+
 if __name__ == "__main__":
     logging.getLogger().setLevel(logging.INFO)
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
+    _test_swoosh_bf16()
     _test_piecewise_linear()
     _test_softmax()
     _test_whiten()
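The added _test_swoosh_bf16 above still contains a pdb.set_trace() breakpoint and only prints the two difference sums. A non-interactive variant is sketched below; it assumes SwooshLFunction is in scope (i.e. the test lives in, or imports from, the file this diff modifies), and the tolerance is an illustrative choice rather than a value from the diff.

import torch

def _test_swoosh_bf16_noninteractive():
    # Same setup as the added test: identical inputs in bf16 and fp32.
    x_bf16 = torch.randn(1, 100, 128).to(torch.bfloat16)
    x_fp32 = x_bf16.clone().to(torch.float32)

    # SwooshLFunction is assumed to be defined in (or imported from) the
    # module under test.
    y_bf16 = SwooshLFunction.apply(x_bf16)
    y_fp32 = SwooshLFunction.apply(x_fp32)

    # The bf16 path should agree with the fp32 path to within roughly bf16
    # rounding error; 0.1 is an illustrative elementwise tolerance.
    max_diff = (y_fp32 - y_bf16.to(torch.float32)).abs().max()
    assert max_diff < 0.1, max_diff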