mirror of https://github.com/k2-fsa/icefall.git
minor changes

commit a288d4123b (parent af67140ad2)
@@ -528,13 +528,6 @@ def get_parser():
         help="Whether to use half precision training.",
     )
 
-    parser.add_argument(
-        "--full-bf16",
-        type=str2bool,
-        default=True,
-        help="If enabled, use pure bf16 training without using autocast and grad scaling"
-    )
-
     add_model_arguments(parser)
 
     return parser
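The removed --full-bf16 switch overlapped with the existing half-precision flags, so after this change the script is driven by --use-fp16 and --use-bf16 alone. The definitions of those two flags are not part of this hunk; the sketch below only illustrates, under that assumption, how such str2bool options are typically declared in get_parser() (help strings and defaults here are stand-ins, not the recipe's exact wording).

import argparse

def str2bool(v: str) -> bool:
    # Stand-in for icefall's str2bool helper: parses "true"/"false"-style strings.
    return str(v).lower() in ("yes", "true", "t", "y", "1")

parser = argparse.ArgumentParser()
parser.add_argument(
    "--use-fp16",
    type=str2bool,
    default=False,
    help="Whether to use half precision training.",
)
parser.add_argument(
    "--use-bf16",
    type=str2bool,
    default=False,
    help="Whether to use bf16 training.",  # assumed wording; not shown in this diff
)
print(parser.parse_args(["--use-bf16", "true"]))  # Namespace(use_bf16=True, use_fp16=False)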
@@ -870,7 +863,7 @@ def compute_loss(
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
     feature = feature.to(device)
-    if params.full_bf16:
+    if params.use_bf16:
         feature = feature.to(torch.bfloat16)
 
     supervisions = batch["supervisions"]
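With pure bf16 training the model weights are cast to torch.bfloat16, so the input features have to be cast as well before the forward pass; a float32 input against bf16 weights typically raises a dtype-mismatch error. A minimal standalone sketch (dummy shapes and a dummy model, not the recipe's actual encoder):

import torch

use_bf16 = True
model = torch.nn.Linear(80, 4)             # stands in for the real acoustic model
if use_bf16:
    model = model.to(torch.bfloat16)        # weights in bf16, as done in run()

feature = torch.randn(8, 100, 80)           # (N, T, C) fbank features, float32
assert feature.ndim == 3
if use_bf16:
    feature = feature.to(torch.bfloat16)    # match the model's dtype

out = model(feature)
print(out.dtype)                            # torch.bfloat16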
@@ -1248,17 +1241,17 @@ def run(rank, world_size, args):
     )
 
     if params.use_fp16:
-        params.dtype = torch.float16 if not params.use_bf16 else torch.bfloat16
+        assert not params.use_bf16, "Only one of fp16 and bf16 can be used"
+        params.dtype = torch.float16
         params.use_autocast = True
-    else:
-        params.dtype = torch.float32
-        params.use_autocast = False
-    logging.info(f"Training using: {params.dtype}")
-    model.to(params.dtype)
-
-    if params.full_bf16:
-        assert params.use_bf16
+    elif params.use_bf16:
+        params.dtype = torch.bfloat16
         params.use_autocast = False  # use full bf16 training, no autocast and grad scaling
+        model.to(params.dtype)
+    else:
+        raise NotImplementedError("Either fp16 or bf16 must be enabled")
+
+    logging.info(f"Training using: {params.dtype}")
 
     model.to(device)
     if world_size > 1:
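After this change the precision selection is explicit: fp16 keeps autocast plus gradient scaling, bf16 casts the model weights and disables both, and running with neither flag now raises instead of silently falling back to fp32. Below is a rough, self-contained sketch of how params.dtype and params.use_autocast are typically consumed further down in the training loop; the loop itself is not part of this commit, and the names and dummy model are only stand-ins.

import torch
from torch.cuda.amp import GradScaler

# Values as the new elif branch would set them for --use-bf16=True.
dtype, use_autocast = torch.bfloat16, False

model = torch.nn.Linear(80, 4).to(dtype)    # stands in for the real model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler(enabled=use_autocast)   # pass-through when autocast is off (pure bf16)

x = torch.randn(8, 80).to(dtype)
with torch.autocast(device_type="cpu", dtype=dtype, enabled=use_autocast):
    loss = model(x).float().sum()

scaler.scale(loss).backward()               # plain backward() when scaling is disabled
scaler.step(optimizer)
scaler.update()
print(loss.dtype, use_autocast)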