Improve infinity-check (#1862)

1. Attach the inf-check hooks if the grad scale is getting too small.
2. Add try-catch to avoid OOM in the inf-check hooks.
3. Set warmup_start=0.1 to reduce chances of divergence
This commit is contained in:
Han Zhu 2025-01-09 15:05:38 +08:00 committed by GitHub
parent 8d602806c3
commit ab91112909
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 36 additions and 15 deletions

View File

@ -1165,23 +1165,34 @@ def train_one_epoch(
rank=rank,
)
if batch_idx % 100 == 0 and params.use_autocast:
# If the grad scale was less than 1, try increasing it. The _growth_interval
# of the grad scaler is configurable, but we can't configure it to have different
# behavior depending on the current grad scale.
if params.use_autocast:
cur_grad_scale = scaler._scale.item()
if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
scaler.update(cur_grad_scale * 2.0)
if cur_grad_scale < 0.01:
if not saved_bad_model:
save_bad_model(suffix="-first-warning")
saved_bad_model = True
if not params.inf_check:
register_inf_check_hooks(model)
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
raise_grad_scale_is_too_small_error(cur_grad_scale)
# If the grad scale was less than 1, try increasing it. The _growth_interval
# of the grad scaler is configurable, but we can't configure it to have different
# behavior depending on the current grad scale.
if (
batch_idx % 25 == 0
and cur_grad_scale < 2.0
or batch_idx % 100 == 0
and cur_grad_scale < 8.0
or batch_idx % 400 == 0
and cur_grad_scale < 32.0
):
scaler.update(cur_grad_scale * 2.0)
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())
cur_grad_scale = scaler._scale.item() if params.use_autocast else 1.0
@ -1335,7 +1346,7 @@ def run(rank, world_size, args):
clipping_scale=2.0,
)
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, warmup_start=0.1)
if checkpoints and "optimizer" in checkpoints:
logging.info("Loading optimizer state dict")

View File

@ -39,24 +39,34 @@ def register_inf_check_hooks(model: nn.Module) -> None:
# default param _name is a way to capture the current value of the variable "name".
def forward_hook(_module, _input, _output, _name=name):
if isinstance(_output, Tensor):
if not torch.isfinite(_output.to(torch.float32).sum()):
logging.warning(f"The sum of {_name}.output is not finite")
try:
if not torch.isfinite(_output.to(torch.float32).sum()):
logging.warning(f"The sum of {_name}.output is not finite")
except RuntimeError: # e.g. CUDA out of memory
pass
elif isinstance(_output, tuple):
for i, o in enumerate(_output):
if isinstance(o, tuple):
o = o[0]
if not isinstance(o, Tensor):
continue
if not torch.isfinite(o.to(torch.float32).sum()):
logging.warning(f"The sum of {_name}.output[{i}] is not finite")
try:
if not torch.isfinite(o.to(torch.float32).sum()):
logging.warning(
f"The sum of {_name}.output[{i}] is not finite"
)
except RuntimeError: # e.g. CUDA out of memory
pass
# default param _name is a way to capture the current value of the variable "name".
def backward_hook(_module, _input, _output, _name=name):
if isinstance(_output, Tensor):
if not torch.isfinite(_output.to(torch.float32).sum()):
logging.warning(
f"The sum of {_name}.grad is not finite" # ": {_output}"
)
try:
if not torch.isfinite(_output.to(torch.float32).sum()):
logging.warning(f"The sum of {_name}.grad is not finite")
except RuntimeError: # e.g. CUDA out of memory
pass
elif isinstance(_output, tuple):
for i, o in enumerate(_output):
if isinstance(o, tuple):