Repository (mirror): https://github.com/k2-fsa/icefall.git

minor changes

Commit a288d4123b (parent af67140ad2)
@@ -528,13 +528,6 @@ def get_parser():
         help="Whether to use half precision training.",
     )
 
-    parser.add_argument(
-        "--full-bf16",
-        type=str2bool,
-        default=True,
-        help="If enabled, use pure bf16 training without using autocast and grad scaling"
-    )
-
     add_model_arguments(parser)
 
     return parser
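This hunk drops the separate --full-bf16 switch; pure-bf16 behavior is folded into the existing --use-bf16 flag. As a point of reference only, here is a minimal sketch of how the surviving --use-fp16 / --use-bf16 options could be declared; the str2bool helper and the defaults shown are assumptions, since those definitions fall outside the hunk.

# Hypothetical sketch; flag defaults and the str2bool helper are assumptions,
# not lines from this commit.
import argparse


def str2bool(v):
    # Parse common truthy/falsy strings into a bool (argparse-friendly).
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {v!r}")


def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--use-fp16",
        type=str2bool,
        default=False,
        help="Whether to use half precision (fp16) training.",
    )
    parser.add_argument(
        "--use-bf16",
        type=str2bool,
        default=False,
        help="Whether to use pure bf16 training (no autocast, no grad scaling).",
    )
    return parser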
@@ -870,7 +863,7 @@ def compute_loss(
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
     feature = feature.to(device)
-    if params.full_bf16:
+    if params.use_bf16:
         feature = feature.to(torch.bfloat16)
 
     supervisions = batch["supervisions"]
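The cast is needed because, with pure bf16 training, the model weights are moved to bfloat16 (see the run() hunk below); a float32 input hitting a bf16 layer outside autocast raises a dtype-mismatch error. A small self-contained illustration, not the icefall code:

# Illustrative only; Linear(80, 4) and the random tensor are stand-ins.
import torch
import torch.nn as nn

model = nn.Linear(80, 4).to(torch.bfloat16)  # weights are bf16
feature = torch.randn(2, 100, 80)            # features arrive as float32

feature = feature.to(torch.bfloat16)         # mirrors the cast in compute_loss
out = model(feature)
print(out.dtype)                             # torch.bfloat16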
@@ -1248,17 +1241,17 @@ def run(rank, world_size, args):
     )
 
     if params.use_fp16:
-        params.dtype = torch.float16 if not params.use_bf16 else torch.bfloat16
+        assert not params.use_bf16, "Only one of fp16 and bf16 can be used"
+        params.dtype = torch.float16
         params.use_autocast = True
-    else:
-        params.dtype = torch.float32
-        params.use_autocast = False
-    logging.info(f"Training using: {params.dtype}")
-    model.to(params.dtype)
-
-    if params.full_bf16:
-        assert params.use_bf16
-
+    elif params.use_bf16:
+        params.dtype = torch.bfloat16
+        params.use_autocast = False  # use full bf16 training, no autocast and grad scaling
+        model.to(params.dtype)
+    else:
+        raise NotImplementedError("Either fp16 or bf16 must be enabled")
+
+    logging.info(f"Training using: {params.dtype}")
+
     model.to(device)
     if world_size > 1:
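For reference, a hedged sketch of the two training-step shapes this dtype logic selects between: fp16 with autocast plus a GradScaler, versus pure bf16 where the whole model is cast and no scaler is used. The model, optimizer, and batch below are stand-ins, not names from train.py.

# Sketch under stated assumptions; not the icefall training loop.
import torch
from torch.cuda.amp import GradScaler, autocast


def fp16_step(model, optimizer, scaler, batch):
    # fp16 path: params.use_autocast is True, so wrap the forward pass in
    # autocast and scale the loss before backward to avoid fp16 underflow.
    with autocast(dtype=torch.float16):
        loss = model(batch).mean()
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()


def bf16_step(model, optimizer, batch):
    # pure bf16 path: the model was moved to bfloat16 up front, so there is no
    # autocast context and no GradScaler; bf16 shares fp32's exponent range,
    # so loss scaling is unnecessary.
    loss = model(batch.to(torch.bfloat16)).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()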