support AMP training with bf16

marcoyang 2024-08-08 10:13:30 +08:00
parent d1974befaa
commit a7854dddba
2 changed files with 22 additions and 8 deletions

scaling.py

@@ -297,7 +297,7 @@ class SoftmaxFunction(torch.autograd.Function):
         # (presumably) that op does not support float16, and autocast
         # is enabled.
         if torch.is_autocast_enabled():
-            ans = ans.to(torch.float16)
+            ans = ans.to(torch.get_autocast_gpu_dtype())
         ctx.save_for_backward(ans)
         ctx.x_dtype = x.dtype
         ctx.dim = dim
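
The same substitution recurs below in SwooshLFunction and SwooshRFunction: instead of hardcoding float16, the output is cast to whatever dtype CUDA autocast is currently configured with. A minimal sketch (not part of the commit) of how torch.get_autocast_gpu_dtype() tracks that configuration:

    import torch

    # get_autocast_gpu_dtype() reports the dtype configured for CUDA
    # autocast, so a cast to it follows the caller's chosen AMP dtype.
    with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
        assert torch.get_autocast_gpu_dtype() == torch.bfloat16
    with torch.cuda.amp.autocast(enabled=True):
        assert torch.get_autocast_gpu_dtype() == torch.float16  # the default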
@@ -1234,7 +1234,7 @@ class DoubleSwishFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x: Tensor) -> Tensor:
         requires_grad = x.requires_grad
-        if x.dtype == torch.float16:
+        if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
             x = x.to(torch.float32)

         s = torch.sigmoid(x - 1.0)
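
DoubleSwish (and, below, SwooshL and SwooshR) computes its custom forward in float32, so bfloat16 inputs must be upcast exactly like float16 ones. A self-contained sketch of the shared guard; _maybe_upcast is a hypothetical name, not the project's code:

    import torch

    # Hypothetical helper mirroring the guard added above: treat bfloat16
    # like float16 and run the activation math in float32.
    def _maybe_upcast(x: torch.Tensor) -> torch.Tensor:
        if x.dtype in (torch.float16, torch.bfloat16):
            x = x.to(torch.float32)
        return x

    x = torch.randn(4, dtype=torch.bfloat16)
    print(_maybe_upcast(x).dtype)  # torch.float32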
@@ -1346,7 +1346,7 @@ class SwooshLFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x: Tensor) -> Tensor:
         requires_grad = x.requires_grad
-        if x.dtype == torch.float16:
+        if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
             x = x.to(torch.float32)

         zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
@@ -1379,7 +1379,7 @@ class SwooshLFunction(torch.autograd.Function):
             d_int = d_scaled.to(torch.uint8)
             ctx.save_for_backward(d_int)
         if x.dtype == torch.float16 or torch.is_autocast_enabled():
-            y = y.to(torch.float16)
+            y = y.to(torch.get_autocast_gpu_dtype())
         return y

     @staticmethod
@@ -1425,7 +1425,7 @@ class SwooshRFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x: Tensor) -> Tensor:
         requires_grad = x.requires_grad
-        if x.dtype == torch.float16:
+        if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
             x = x.to(torch.float32)

         zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
@@ -1455,7 +1455,7 @@ class SwooshRFunction(torch.autograd.Function):
             d_int = d_scaled.to(torch.uint8)
             ctx.save_for_backward(d_int)
         if x.dtype == torch.float16 or torch.is_autocast_enabled():
-            y = y.to(torch.float16)
+            y = y.to(torch.get_autocast_gpu_dtype())
         return y

     @staticmethod

train.py

@@ -521,6 +521,13 @@ def get_parser():
         help="Whether to use half precision training.",
     )

+    parser.add_argument(
+        "--use-bf16",
+        type=str2bool,
+        default=False,
+        help="Whether to use bf16 in AMP.",
+    )
+
     add_model_arguments(parser)

     return parser
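
Because the new option goes through icefall's str2bool, it is passed with an explicit value, e.g. `--use-bf16 1`. A self-contained sketch of the option in isolation; the str2bool reimplementation here is an assumption for runnability, not the project's exact helper:

    import argparse

    # Assumed stand-in for icefall's str2bool helper.
    def str2bool(v: str) -> bool:
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        if v.lower() in ("no", "false", "f", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(f"Not a boolean value: {v}")

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--use-bf16",
        type=str2bool,
        default=False,
        help="Whether to use bf16 in AMP.",
    )
    print(parser.parse_args(["--use-bf16", "1"]).use_bf16)  # True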
@@ -1027,7 +1034,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])

         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch.cuda.amp.autocast(enabled=params.use_fp16, dtype=params.dtype):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1207,6 +1214,13 @@ def run(rank, world_size, args):
             params.ctc_loss_scale, params.attention_decoder_loss_scale
         )

+    if params.use_bf16:
+        assert torch.cuda.is_bf16_supported(), "Your GPU does not support bf16!"
+        params.dtype = torch.bfloat16
+    else:
+        params.dtype = torch.float16
+    logging.info(f"Using dtype={params.dtype}")
+
     logging.info(params)

     logging.info("About to create model")
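
Read together with the autocast change above, the semantics are: --use-fp16 still gates whether AMP is on at all, while --use-bf16 only selects the autocast dtype, so bf16 training needs both flags enabled. A minimal sketch of the combined behavior (flag values are assumed for illustration; in train.py they come from params.use_fp16 and params.use_bf16):

    import torch

    use_fp16, use_bf16 = True, True
    if use_bf16:
        # Mirrors the guard added in run(): bf16 autocast needs hardware support.
        assert torch.cuda.is_bf16_supported(), "Your GPU does not support bf16!"
    dtype = torch.bfloat16 if use_bf16 else torch.float16
    with torch.cuda.amp.autocast(enabled=use_fp16, dtype=dtype):
        pass  # compute_loss(...) runs under the selected AMP dtype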
@@ -1439,7 +1453,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch.cuda.amp.autocast(enabled=params.use_fp16, dtype=params.dtype):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,