minor changes

marcoyang 2024-08-08 10:54:44 +08:00
parent a7854dddba
commit af67140ad2


@@ -1008,7 +1008,6 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
         try:
             with torch.enable_grad():
                 with torch.cuda.amp.autocast(enabled=False):
-                    dtype = x_orig.dtype
                     x_detached = x_orig.detach()
                     x_detached.requires_grad = True
@@ -1028,7 +1027,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
                         metric.backward()
                         penalty_grad = x_detached.grad
                         scale = float(w.grad_scale) * (
-                            x_grad.to(dtype).norm()
+                            x_grad.to(x_orig.dtype).norm()
                             / (penalty_grad.norm() + 1.0e-20)
                         )
                         penalty_grad = penalty_grad * scale
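
The two hunks above are one change: instead of caching dtype = x_orig.dtype before entering the autocast-disabled region, backward() now reads x_orig.dtype inline at the point where it rescales the whitening-penalty gradient so that its norm is a fixed fraction (w.grad_scale) of the incoming gradient's norm. A minimal sketch of that rescaling pattern, assuming nothing beyond plain PyTorch (add_scaled_penalty and its grad_scale default are illustrative stand-ins, not the real _whitening_metric machinery):

import torch

def add_scaled_penalty(x_grad: torch.Tensor,
                       penalty_grad: torch.Tensor,
                       grad_scale: float = 0.02) -> torch.Tensor:
    # Rescale the penalty gradient so its norm is grad_scale times the
    # norm of the incoming gradient; both norms are taken in the
    # penalty's dtype (float32 under autocast(enabled=False)) so the
    # ratio is computed in a single precision.  The epsilon guards
    # against a vanishing penalty gradient.
    scale = grad_scale * (
        x_grad.to(penalty_grad.dtype).norm()
        / (penalty_grad.norm() + 1.0e-20)
    )
    return x_grad + (penalty_grad * scale).to(x_grad.dtype)

x_grad = torch.randn(4, 8, dtype=torch.bfloat16)
penalty_grad = torch.randn(4, 8)  # float32, as produced with autocast disabled
print(add_scaled_penalty(x_grad, penalty_grad).dtype)  # torch.bfloat16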
@@ -1391,10 +1390,8 @@ class SwooshL(torch.nn.Module):
             zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
             return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
         if not x.requires_grad:
-            # return k2.swoosh_l_forward(x)
             return SwooshLForward(x)
         else:
-            # return k2.swoosh_l(x)
             return SwooshLFunction.apply(x)  # this supports bf16
@@ -1895,11 +1892,32 @@ def _test_activation_dropout_and_linear():
     # storage of it.
     assert isclose(x1.grad, x2.grad)


+def _test_swoosh_bf16():
+    x_bf16 = torch.randn(1, 100, 128).to(torch.bfloat16)
+    x_fp32 = x_bf16.clone().to(torch.float32)
+    # test bf16 version
+    y_bf16 = SwooshLFunction.apply(x_bf16)
+    # test fp32 version
+    y_fp32 = SwooshLFunction.apply(x_fp32)
+    diff_1 = torch.abs(y_fp32 - y_bf16).sum()  # bf16 path vs. fp32 path
+    diff_2 = torch.abs(y_fp32 - y_fp32.to(torch.bfloat16)).sum()  # bf16 rounding floor
+    print(diff_1, diff_2)
+
+
 if __name__ == "__main__":
     logging.getLogger().setLevel(logging.INFO)
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
+    _test_swoosh_bf16()
     _test_piecewise_linear()
     _test_softmax()
     _test_whiten()
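
The new _test_swoosh_bf16 only prints the two aggregate deviations rather than asserting anything. A stricter variant might check that the bf16 path stays within a tolerance of the fp32 result; the sketch below uses an arbitrary atol chosen for bf16's roughly three decimal digits of precision, and check_swoosh_bf16 is a hypothetical helper, not part of the commit:

import torch

def check_swoosh_bf16(swoosh_fn) -> None:
    x_bf16 = torch.randn(1, 100, 128).to(torch.bfloat16)
    x_fp32 = x_bf16.to(torch.float32)
    y_bf16 = swoosh_fn(x_bf16).to(torch.float32)
    y_fp32 = swoosh_fn(x_fp32)
    # Illustrative tolerance: bf16 has an 8-bit mantissa, so elementwise
    # error around 1e-2 on unit-scale values is expected.
    torch.testing.assert_close(y_bf16, y_fp32, rtol=0.0, atol=0.05)

# e.g. check_swoosh_bf16(SwooshLFunction.apply)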