mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Improve infinity-check (#1862)
1. Attach the inf-check hooks if the grad scale is getting too small. 2. Add try-catch to avoid OOM in the inf-check hooks. 3. Set warmup_start=0.1 to reduce chances of divergence
This commit is contained in:
parent
8d602806c3
commit
ab91112909
@ -1165,23 +1165,34 @@ def train_one_epoch(
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
if batch_idx % 100 == 0 and params.use_autocast:
|
||||
# If the grad scale was less than 1, try increasing it. The _growth_interval
|
||||
# of the grad scaler is configurable, but we can't configure it to have different
|
||||
# behavior depending on the current grad scale.
|
||||
if params.use_autocast:
|
||||
cur_grad_scale = scaler._scale.item()
|
||||
|
||||
if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
|
||||
scaler.update(cur_grad_scale * 2.0)
|
||||
if cur_grad_scale < 0.01:
|
||||
if not saved_bad_model:
|
||||
save_bad_model(suffix="-first-warning")
|
||||
saved_bad_model = True
|
||||
if not params.inf_check:
|
||||
register_inf_check_hooks(model)
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
save_bad_model()
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
# If the grad scale was less than 1, try increasing it. The _growth_interval
|
||||
# of the grad scaler is configurable, but we can't configure it to have different
|
||||
# behavior depending on the current grad scale.
|
||||
if (
|
||||
batch_idx % 25 == 0
|
||||
and cur_grad_scale < 2.0
|
||||
or batch_idx % 100 == 0
|
||||
and cur_grad_scale < 8.0
|
||||
or batch_idx % 400 == 0
|
||||
and cur_grad_scale < 32.0
|
||||
):
|
||||
scaler.update(cur_grad_scale * 2.0)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = max(scheduler.get_last_lr())
|
||||
cur_grad_scale = scaler._scale.item() if params.use_autocast else 1.0
|
||||
@ -1335,7 +1346,7 @@ def run(rank, world_size, args):
|
||||
clipping_scale=2.0,
|
||||
)
|
||||
|
||||
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
|
||||
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, warmup_start=0.1)
|
||||
|
||||
if checkpoints and "optimizer" in checkpoints:
|
||||
logging.info("Loading optimizer state dict")
|
||||
|
@ -39,24 +39,34 @@ def register_inf_check_hooks(model: nn.Module) -> None:
|
||||
# default param _name is a way to capture the current value of the variable "name".
|
||||
def forward_hook(_module, _input, _output, _name=name):
|
||||
if isinstance(_output, Tensor):
|
||||
if not torch.isfinite(_output.to(torch.float32).sum()):
|
||||
logging.warning(f"The sum of {_name}.output is not finite")
|
||||
try:
|
||||
if not torch.isfinite(_output.to(torch.float32).sum()):
|
||||
logging.warning(f"The sum of {_name}.output is not finite")
|
||||
except RuntimeError: # e.g. CUDA out of memory
|
||||
pass
|
||||
elif isinstance(_output, tuple):
|
||||
for i, o in enumerate(_output):
|
||||
if isinstance(o, tuple):
|
||||
o = o[0]
|
||||
if not isinstance(o, Tensor):
|
||||
continue
|
||||
if not torch.isfinite(o.to(torch.float32).sum()):
|
||||
logging.warning(f"The sum of {_name}.output[{i}] is not finite")
|
||||
try:
|
||||
if not torch.isfinite(o.to(torch.float32).sum()):
|
||||
logging.warning(
|
||||
f"The sum of {_name}.output[{i}] is not finite"
|
||||
)
|
||||
except RuntimeError: # e.g. CUDA out of memory
|
||||
pass
|
||||
|
||||
# default param _name is a way to capture the current value of the variable "name".
|
||||
def backward_hook(_module, _input, _output, _name=name):
|
||||
if isinstance(_output, Tensor):
|
||||
if not torch.isfinite(_output.to(torch.float32).sum()):
|
||||
logging.warning(
|
||||
f"The sum of {_name}.grad is not finite" # ": {_output}"
|
||||
)
|
||||
try:
|
||||
if not torch.isfinite(_output.to(torch.float32).sum()):
|
||||
logging.warning(f"The sum of {_name}.grad is not finite")
|
||||
except RuntimeError: # e.g. CUDA out of memory
|
||||
pass
|
||||
|
||||
elif isinstance(_output, tuple):
|
||||
for i, o in enumerate(_output):
|
||||
if isinstance(o, tuple):
|
||||
|
Loading…
x
Reference in New Issue
Block a user