Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 18:24:18 +00:00)
support AMP training with bf16

Commit a7854dddba (parent d1974befaa)
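The diff touches two areas: the scaling/activation utilities (SoftmaxFunction, DoubleSwishFunction, SwooshLFunction, SwooshRFunction), which stop hardcoding float16 under autocast, and a recipe training script (get_parser, train_one_epoch, run, scan_pessimistic_batches_for_oom), which gains a --use-bf16 flag and passes the chosen dtype to torch.cuda.amp.autocast. As a big-picture reference, here is a minimal, self-contained sketch of AMP training with a selectable float16/bfloat16 dtype; it is not icefall code, the model, shapes and hyper-parameters are placeholders, and a CUDA device is assumed.

import torch
from torch import nn

use_bf16 = True
if use_bf16:
    # Same guard as the run() hunk further down.
    assert torch.cuda.is_bf16_supported(), "This GPU does not support bf16"
amp_dtype = torch.bfloat16 if use_bf16 else torch.float16

model = nn.Linear(80, 500).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
# Loss scaling is only really needed for float16 (narrow exponent range);
# bfloat16 shares float32's exponent range, so the scaler can be disabled.
scaler = torch.cuda.amp.GradScaler(enabled=(amp_dtype == torch.float16))

x = torch.randn(16, 80, device="cuda")
target = torch.randn(16, 500, device="cuda")

with torch.cuda.amp.autocast(enabled=True, dtype=amp_dtype):
    loss = nn.functional.mse_loss(model(x), target)

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()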
@@ -297,7 +297,7 @@ class SoftmaxFunction(torch.autograd.Function):
         # (presumably) that op does not support float16, and autocast
         # is enabled.
         if torch.is_autocast_enabled():
-            ans = ans.to(torch.float16)
+            ans = ans.to(torch.get_autocast_gpu_dtype())
         ctx.save_for_backward(ans)
         ctx.x_dtype = x.dtype
         ctx.dim = dim
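The changed line is the heart of the patch: instead of assuming AMP always means float16, the softmax output is cast to whatever dtype autocast is configured with. A small illustrative check of the standard PyTorch call involved (assuming a CUDA-enabled build; nothing here is icefall-specific):

import torch

# With plain autocast the GPU autocast dtype defaults to float16 ...
with torch.cuda.amp.autocast(enabled=True):
    print(torch.get_autocast_gpu_dtype())  # torch.float16

# ... but inside a bfloat16 autocast region the same query returns bfloat16,
# so the cast in SoftmaxFunction.forward now follows the active AMP dtype.
with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
    print(torch.get_autocast_gpu_dtype())  # torch.bfloat16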
@@ -1234,7 +1234,7 @@ class DoubleSwishFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x: Tensor) -> Tensor:
         requires_grad = x.requires_grad
-        if x.dtype == torch.float16:
+        if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
             x = x.to(torch.float32)

         s = torch.sigmoid(x - 1.0)
@@ -1346,7 +1346,7 @@ class SwooshLFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x: Tensor) -> Tensor:
         requires_grad = x.requires_grad
-        if x.dtype == torch.float16:
+        if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
             x = x.to(torch.float32)

         zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
@@ -1379,7 +1379,7 @@ class SwooshLFunction(torch.autograd.Function):
             d_int = d_scaled.to(torch.uint8)
             ctx.save_for_backward(d_int)
         if x.dtype == torch.float16 or torch.is_autocast_enabled():
-            y = y.to(torch.float16)
+            y = y.to(torch.get_autocast_gpu_dtype())
         return y

     @staticmethod
@@ -1425,7 +1425,7 @@ class SwooshRFunction(torch.autograd.Function):
     def forward(ctx, x: Tensor) -> Tensor:
         requires_grad = x.requires_grad

-        if x.dtype == torch.float16:
+        if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
             x = x.to(torch.float32)

         zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
@@ -1455,7 +1455,7 @@ class SwooshRFunction(torch.autograd.Function):
             d_int = d_scaled.to(torch.uint8)
             ctx.save_for_backward(d_int)
         if x.dtype == torch.float16 or torch.is_autocast_enabled():
-            y = y.to(torch.float16)
+            y = y.to(torch.get_autocast_gpu_dtype())
         return y

     @staticmethod
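All four activation hunks above apply the same two-step dtype discipline: upcast a float16/bfloat16 input to float32 so the pointwise math keeps precision, then hand the result back in the dtype autocast is currently using. A compact illustrative restatement follows; the math is a toy placeholder, not the actual DoubleSwish/Swoosh formulas, and only the dtype handling mirrors the patched code.

import torch


def toy_activation(x: torch.Tensor) -> torch.Tensor:
    # 1) Compute in float32 even if AMP handed us a reduced-precision tensor.
    if x.dtype in (torch.float16, torch.bfloat16):
        x = x.to(torch.float32)
    y = x * torch.sigmoid(x - 1.0)
    # 2) Return the result in the active autocast dtype (float16 by default,
    #    bfloat16 under bf16 AMP) so downstream ops stay on the AMP path.
    if torch.is_autocast_enabled():
        y = y.to(torch.get_autocast_gpu_dtype())
    return y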
@@ -521,6 +521,13 @@ def get_parser():
         help="Whether to use half precision training.",
     )

+    parser.add_argument(
+        "--use-bf16",
+        type=str2bool,
+        default=False,
+        help="Whether to use bf16 in AMP.",
+    )
+
     add_model_arguments(parser)

     return parser
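The new flag mirrors the existing --use-fp16 option. Note that in the train_one_epoch hunk below autocast is still switched on by --use-fp16; --use-bf16 only selects the dtype used inside AMP, so bf16 training needs both flags enabled. A trimmed-down, hypothetical stand-in for how the two flags combine (icefall's real parser uses its str2bool helper; store_true keeps this sketch self-contained):

import argparse

import torch

parser = argparse.ArgumentParser()
parser.add_argument("--use-fp16", action="store_true")
parser.add_argument("--use-bf16", action="store_true")
args = parser.parse_args(["--use-fp16", "--use-bf16"])

# --use-fp16 gates AMP on/off; --use-bf16 only picks the dtype inside AMP.
amp_enabled = args.use_fp16
amp_dtype = torch.bfloat16 if args.use_bf16 else torch.float16
print(amp_enabled, amp_dtype)  # True torch.bfloat16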
@@ -1027,7 +1034,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])

         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch.cuda.amp.autocast(enabled=params.use_fp16, dtype=params.dtype):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
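For completeness, a tiny demonstration (CUDA device required; standard PyTorch, not icefall code) that the dtype argument only matters when autocast is actually enabled, which is why --use-bf16 has no effect unless --use-fp16 is also on:

import torch

a = torch.randn(8, 8, device="cuda")
b = torch.randn(8, 8, device="cuda")

# enabled=False: the region runs in the tensors' own dtype; dtype is ignored.
with torch.cuda.amp.autocast(enabled=False, dtype=torch.bfloat16):
    print((a @ b).dtype)  # torch.float32

# enabled=True: matmuls now run in the requested reduced precision.
with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
    print((a @ b).dtype)  # torch.bfloat16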
@@ -1207,6 +1214,13 @@ def run(rank, world_size, args):
             params.ctc_loss_scale, params.attention_decoder_loss_scale
         )

+    if params.use_bf16:
+        assert torch.cuda.is_bf16_supported(), f"Your GPU does not support bf16!"
+        params.dtype = torch.bfloat16
+    else:
+        params.dtype = torch.float16
+    logging.info(f"Using dtype={params.dtype}")
+
     logging.info(params)

     logging.info("About to create model")
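The assert fails fast on GPUs without bfloat16 support (roughly, NVIDIA GPUs older than the Ampere generation). A softer variant, shown purely as an illustration rather than what the patch does, would fall back to float16 instead of aborting:

import torch

if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    amp_dtype = torch.bfloat16
else:
    # Fall back instead of asserting; the patch itself prefers a hard failure.
    amp_dtype = torch.float16
print(f"Using dtype={amp_dtype}")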
@@ -1439,7 +1453,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch.cuda.amp.autocast(enabled=params.use_fp16, dtype=params.dtype):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,