From a7854dddbad4da962aee1f3c7c5d5713bcf58fb8 Mon Sep 17 00:00:00 2001
From: marcoyang
Date: Thu, 8 Aug 2024 10:13:30 +0800
Subject: [PATCH] support AMP training with bf16

---
 egs/librispeech/ASR/zipformer/scaling.py | 12 ++++++------
 egs/librispeech/ASR/zipformer/train.py   | 18 ++++++++++++++++--
 2 files changed, 22 insertions(+), 8 deletions(-)

diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py
index 164cc7bfd..2a40b8d64 100644
--- a/egs/librispeech/ASR/zipformer/scaling.py
+++ b/egs/librispeech/ASR/zipformer/scaling.py
@@ -297,7 +297,7 @@ class SoftmaxFunction(torch.autograd.Function):
         # (presumably) that op does not support float16, and autocast
         # is enabled.
         if torch.is_autocast_enabled():
-            ans = ans.to(torch.float16)
+            ans = ans.to(torch.get_autocast_gpu_dtype())
         ctx.save_for_backward(ans)
         ctx.x_dtype = x.dtype
         ctx.dim = dim
@@ -1234,7 +1234,7 @@ class DoubleSwishFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x: Tensor) -> Tensor:
         requires_grad = x.requires_grad
-        if x.dtype == torch.float16:
+        if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
             x = x.to(torch.float32)
 
         s = torch.sigmoid(x - 1.0)
@@ -1346,7 +1346,7 @@ class SwooshLFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x: Tensor) -> Tensor:
         requires_grad = x.requires_grad
-        if x.dtype == torch.float16:
+        if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
             x = x.to(torch.float32)
 
         zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
@@ -1379,7 +1379,7 @@ class SwooshLFunction(torch.autograd.Function):
                 d_int = d_scaled.to(torch.uint8)
                 ctx.save_for_backward(d_int)
                 if x.dtype == torch.float16 or torch.is_autocast_enabled():
-                    y = y.to(torch.float16)
+                    y = y.to(torch.get_autocast_gpu_dtype())
                 return y
 
     @staticmethod
@@ -1425,7 +1425,7 @@ class SwooshRFunction(torch.autograd.Function):
     def forward(ctx, x: Tensor) -> Tensor:
         requires_grad = x.requires_grad
 
-        if x.dtype == torch.float16:
+        if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
             x = x.to(torch.float32)
 
         zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
@@ -1455,7 +1455,7 @@ class SwooshRFunction(torch.autograd.Function):
                 d_int = d_scaled.to(torch.uint8)
                 ctx.save_for_backward(d_int)
                 if x.dtype == torch.float16 or torch.is_autocast_enabled():
-                    y = y.to(torch.float16)
+                    y = y.to(torch.get_autocast_gpu_dtype())
                 return y
 
     @staticmethod
diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py
index 9b6f4a93a..95e31e6e6 100755
--- a/egs/librispeech/ASR/zipformer/train.py
+++ b/egs/librispeech/ASR/zipformer/train.py
@@ -521,6 +521,13 @@ def get_parser():
         help="Whether to use half precision training.",
     )
 
+    parser.add_argument(
+        "--use-bf16",
+        type=str2bool,
+        default=False,
+        help="Whether to use bf16 in AMP.",
+    )
+
     add_model_arguments(parser)
 
     return parser
@@ -1027,7 +1034,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
 
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch.cuda.amp.autocast(enabled=params.use_fp16, dtype=params.dtype):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1207,6 +1214,13 @@ def run(rank, world_size, args):
         params.ctc_loss_scale, params.attention_decoder_loss_scale
     )
 
+    if params.use_bf16:
+        assert torch.cuda.is_bf16_supported(), "Your GPU does not support bf16!"
+        params.dtype = torch.bfloat16
+    else:
+        params.dtype = torch.float16
+    logging.info(f"Using dtype={params.dtype}")
+
     logging.info(params)
 
     logging.info("About to create model")
@@ -1439,7 +1453,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch.cuda.amp.autocast(enabled=params.use_fp16, dtype=params.dtype):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
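
Note (not part of the patch): a minimal sketch of how the two flags combine under this change. The select_amp_dtype helper below is hypothetical and only mirrors the logic added to run(); since autocast stays gated on --use-fp16, bf16 training requires enabling both --use-fp16 and --use-bf16.

import torch

def select_amp_dtype(use_fp16: bool, use_bf16: bool) -> torch.dtype:
    # Same selection the patch adds to run(): prefer bf16 when requested and
    # supported by the GPU, otherwise fall back to fp16.
    if use_bf16:
        assert torch.cuda.is_bf16_supported(), "Your GPU does not support bf16!"
        return torch.bfloat16
    return torch.float16

dtype = select_amp_dtype(use_fp16=True, use_bf16=True)
# Autocast is still enabled/disabled by --use-fp16; dtype only picks the AMP precision.
with torch.cuda.amp.autocast(enabled=True, dtype=dtype):
    pass  # forward pass / loss computation would go here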