minor changes: remove the --full-bf16 flag; --use-bf16 alone now selects full bf16 training, and fp16/bf16 become mutually exclusive

marcoyang 2024-08-12 11:30:21 +08:00
parent af67140ad2
commit a288d4123b


@@ -528,13 +528,6 @@ def get_parser():
         help="Whether to use half precision training.",
     )
 
-    parser.add_argument(
-        "--full-bf16",
-        type=str2bool,
-        default=True,
-        help="If enabled, use pure bf16 training without using autocast and grad scaling"
-    )
-
     add_model_arguments(parser)
 
     return parser
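With --full-bf16 gone, full bf16 behavior is implied by --use-bf16 alone. For context, a minimal sketch of how the two surviving precision flags are presumably declared elsewhere in this parser (the --use-fp16 help string is taken from the context line above; the --use-bf16 help text and both defaults are assumptions):

    import argparse
    from icefall.utils import str2bool  # boolean-string parser used by icefall recipes

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--use-fp16",
        type=str2bool,
        default=False,
        help="Whether to use half precision training.",
    )
    parser.add_argument(
        "--use-bf16",
        type=str2bool,
        default=False,
        help="Whether to use full bf16 training (no autocast, no grad scaling).",
    )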
@@ -870,7 +863,7 @@ def compute_loss(
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
     feature = feature.to(device)
-    if params.full_bf16:
+    if params.use_bf16:
         feature = feature.to(torch.bfloat16)
 
     supervisions = batch["supervisions"]
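This cast is needed because full bf16 training converts the model weights themselves to bfloat16, so float32 features from the dataloader would no longer match the weight dtype. A minimal illustration of the mismatch the cast avoids (layer and tensor shapes here are illustrative, not taken from this file):

    import torch

    linear = torch.nn.Linear(80, 512).to(torch.bfloat16)  # model converted to full bf16
    feature = torch.randn(8, 100, 80)                     # float32 features from the dataloader
    # linear(feature) would raise a RuntimeError (Float input vs BFloat16 weights);
    # casting first, as the diff does, makes the forward pass succeed:
    out = linear(feature.to(torch.bfloat16))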
@@ -1248,17 +1241,17 @@ def run(rank, world_size, args):
     )
 
     if params.use_fp16:
-        params.dtype = torch.float16 if not params.use_bf16 else torch.bfloat16
+        assert not params.use_bf16, "Only one of fp16 and bf16 can be used"
+        params.dtype = torch.float16
         params.use_autocast = True
-    else:
-        params.dtype = torch.float32
-        params.use_autocast = False
-    logging.info(f"Training using: {params.dtype}")
-    model.to(params.dtype)
-    if params.full_bf16:
-        assert params.use_bf16
+    elif params.use_bf16:
+        params.dtype = torch.bfloat16
         params.use_autocast = False  # use full bf16 training, no autocast and grad scaling
         model.to(params.dtype)
+    else:
+        raise NotImplementedError("Either fp16 or bf16 must be enabled")
+
+    logging.info(f"Training using: {params.dtype}")
 
     model.to(device)
     if world_size > 1:
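The restructured branch makes the two mixed-precision modes mutually exclusive: fp16 keeps autocast plus gradient scaling, while bf16 converts the model weights outright and disables both. A self-contained sketch of how params.dtype and params.use_autocast would typically be consumed in the training step (the loop, model, and optimizer are illustrative; only the branch logic mirrors this commit):

    import torch

    model = torch.nn.Linear(80, 10)
    use_fp16, use_bf16 = False, True  # the now mutually exclusive flags

    if use_fp16:
        assert not use_bf16, "Only one of fp16 and bf16 can be used"
        dtype, use_autocast = torch.float16, True
    elif use_bf16:
        dtype, use_autocast = torch.bfloat16, False
        model.to(dtype)  # full bf16: convert weights, no autocast or grad scaling
    else:
        raise NotImplementedError("Either fp16 or bf16 must be enabled")

    # GradScaler is a no-op when autocast is disabled, so the same step code
    # serves both the fp16 and the full-bf16 path.
    scaler = torch.cuda.amp.GradScaler(enabled=use_autocast)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    x = torch.randn(4, 80)
    if use_bf16:
        x = x.to(torch.bfloat16)  # mirrors the feature cast in compute_loss above
    with torch.cuda.amp.autocast(enabled=use_autocast, dtype=dtype):
        loss = model(x).float().sum()
    scaler.scale(loss).backward()  # scaling is skipped when the scaler is disabled
    scaler.step(optimizer)
    scaler.update()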