From a288d4123b6e8e6a2d30c6c22f3c270f614f2571 Mon Sep 17 00:00:00 2001
From: marcoyang
Date: Mon, 12 Aug 2024 11:30:21 +0800
Subject: [PATCH] minor changes

---
 .../ASR/zipformer/train_full_bf16.py | 27 +++++++------------
 1 file changed, 10 insertions(+), 17 deletions(-)

diff --git a/egs/librispeech/ASR/zipformer/train_full_bf16.py b/egs/librispeech/ASR/zipformer/train_full_bf16.py
index b4f9372b5..64722b309 100644
--- a/egs/librispeech/ASR/zipformer/train_full_bf16.py
+++ b/egs/librispeech/ASR/zipformer/train_full_bf16.py
@@ -528,13 +528,6 @@ def get_parser():
         help="Whether to use half precision training.",
     )
 
-    parser.add_argument(
-        "--full-bf16",
-        type=str2bool,
-        default=True,
-        help="If enabled, use pure bf16 training without using autocast and grad scaling"
-    )
-
     add_model_arguments(parser)
 
     return parser
@@ -870,7 +863,7 @@ def compute_loss(
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
     feature = feature.to(device)
-    if params.full_bf16:
+    if params.use_bf16:
         feature = feature.to(torch.bfloat16)
 
     supervisions = batch["supervisions"]
@@ -1248,17 +1241,17 @@ def run(rank, world_size, args):
     )
 
     if params.use_fp16:
-        params.dtype = torch.float16 if not params.use_bf16 else torch.bfloat16
+        assert not params.use_bf16, "Only one of fp16 and bf16 can be used"
+        params.dtype = torch.float16
         params.use_autocast = True
-    else:
-        params.dtype = torch.float32
-        params.use_autocast = False
-    logging.info(f"Training using: {params.dtype}")
-    model.to(params.dtype)
-
-    if params.full_bf16:
-        assert params.use_bf16
+    elif params.use_bf16:
+        params.dtype = torch.bfloat16
         params.use_autocast = False  # use full bf16 training, no autocast and grad scaling
+        model.to(params.dtype)
+    else:
+        raise NotImplementedError("Either fp16 or bf16 must be enabled")
+
+    logging.info(f"Training using: {params.dtype}")
 
     model.to(device)
     if world_size > 1:
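
The resulting logic in run() picks exactly one regime: fp16 keeps the weights in fp32 and relies on autocast plus gradient scaling, while bf16 casts the model (and, in compute_loss, the features) to bfloat16 and trains without autocast or a GradScaler. A minimal sketch of that contrast follows; it is not taken from train_full_bf16.py: the nn.Linear model, SGD optimizer, and dummy (N, T, C) batch are placeholders, and only the use_fp16/use_bf16/use_autocast flags mirror the patch.

import torch
import torch.nn as nn

# Illustrative stand-ins; the real script builds a Zipformer and reads
# (N, T, C) fbank features from Lhotse cuts.
model = nn.Linear(80, 500)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
features = torch.randn(8, 100, 80)   # (N, T, C) dummy batch
targets = torch.randn(8, 100, 500)

use_fp16 = False  # mirrors --use-fp16
use_bf16 = True   # mirrors --use-bf16

if use_fp16:
    assert not use_bf16, "Only one of fp16 and bf16 can be used"
    dtype, use_autocast = torch.float16, True
elif use_bf16:
    dtype, use_autocast = torch.bfloat16, False
    model = model.to(dtype)  # full bf16: cast the weights once
else:
    raise NotImplementedError("Either fp16 or bf16 must be enabled")

# Only needed (and only active) on the fp16/autocast path.
scaler = torch.cuda.amp.GradScaler(enabled=use_autocast)

for _ in range(2):
    optimizer.zero_grad()
    if use_autocast:
        # fp16 path: weights stay fp32, autocast casts per-op, and the
        # scaler guards against fp16 gradient underflow.  (This branch
        # assumes the model and tensors live on a CUDA device.)
        with torch.autocast(device_type="cuda", dtype=dtype):
            loss = (model(features) - targets).pow(2).mean()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        # full bf16 path: cast the inputs to match the bf16 weights and
        # run a plain step, with no autocast and no grad scaling.
        loss = (model(features.to(dtype)) - targets.to(dtype)).pow(2).mean()
        loss.backward()
        optimizer.step()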