Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-09 01:52:41 +00:00

Merge branch 'k2-fsa:master' into dev/k2ssl

Commit 54d0a2b499
@@ -1165,23 +1165,34 @@ def train_one_epoch(
                 rank=rank,
             )
 
-        if batch_idx % 100 == 0 and params.use_autocast:
-            # If the grad scale was less than 1, try increasing it. The _growth_interval
-            # of the grad scaler is configurable, but we can't configure it to have different
-            # behavior depending on the current grad scale.
+        if params.use_autocast:
             cur_grad_scale = scaler._scale.item()
 
-            if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
-                scaler.update(cur_grad_scale * 2.0)
             if cur_grad_scale < 0.01:
                 if not saved_bad_model:
                     save_bad_model(suffix="-first-warning")
                     saved_bad_model = True
+                if not params.inf_check:
+                    register_inf_check_hooks(model)
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
 
                 if cur_grad_scale < 1.0e-05:
                     save_bad_model()
                     raise_grad_scale_is_too_small_error(cur_grad_scale)
 
+            # If the grad scale was less than 1, try increasing it. The _growth_interval
+            # of the grad scaler is configurable, but we can't configure it to have different
+            # behavior depending on the current grad scale.
+            if (
+                batch_idx % 25 == 0
+                and cur_grad_scale < 2.0
+                or batch_idx % 100 == 0
+                and cur_grad_scale < 8.0
+                or batch_idx % 400 == 0
+                and cur_grad_scale < 32.0
+            ):
+                scaler.update(cur_grad_scale * 2.0)
+
         if batch_idx % params.log_interval == 0:
             cur_lr = max(scheduler.get_last_lr())
             cur_grad_scale = scaler._scale.item() if params.use_autocast else 1.0
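The new condition checks the scale whenever autocast is enabled and grows it on a graded schedule: every 25 batches while it is below 2.0, every 100 while below 8.0, and every 400 while below 32.0. Because `and` binds tighter than `or` in Python, the multi-line condition groups into exactly those three interval/threshold pairs. A minimal sketch of the same schedule as a standalone helper; the helper name is hypothetical and not part of the patch:

    # Sketch only: the growth schedule above, factored into one function.
    # `scaler` is a torch.cuda.amp.GradScaler whose internal scale the training
    # loop already reads via scaler._scale, as in the diff.
    def maybe_grow_grad_scale(scaler, batch_idx: int) -> None:
        cur_grad_scale = scaler._scale.item()
        # (interval, threshold): grow more often while the scale is still very small.
        schedule = [(25, 2.0), (100, 8.0), (400, 32.0)]
        if any(
            batch_idx % interval == 0 and cur_grad_scale < threshold
            for interval, threshold in schedule
        ):
            scaler.update(cur_grad_scale * 2.0)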
@@ -1335,7 +1346,7 @@ def run(rank, world_size, args):
         clipping_scale=2.0,
     )
 
-    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, warmup_start=0.1)
 
     if checkpoints and "optimizer" in checkpoints:
         logging.info("Loading optimizer state dict")
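The Eden change only passes warmup_start=0.1, so the learning rate starts from a smaller fraction of its base value during warmup. A rough sketch of the warmup behaviour, assuming a linear ramp from warmup_start to 1.0 over a fixed number of warmup batches (the 500-batch default is an assumption, not taken from this patch, and Eden's batch/epoch decay factors are not reproduced here):

    # Hedged illustration of the warmup factor only.
    def warmup_factor(
        batch: float, warmup_batches: float = 500.0, warmup_start: float = 0.1
    ) -> float:
        # Linear ramp from warmup_start up to 1.0, then flat.
        if batch >= warmup_batches:
            return 1.0
        return warmup_start + (1.0 - warmup_start) * (batch / warmup_batches)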
@@ -39,24 +39,34 @@ def register_inf_check_hooks(model: nn.Module) -> None:
         # default param _name is a way to capture the current value of the variable "name".
         def forward_hook(_module, _input, _output, _name=name):
             if isinstance(_output, Tensor):
-                if not torch.isfinite(_output.to(torch.float32).sum()):
-                    logging.warning(f"The sum of {_name}.output is not finite")
+                try:
+                    if not torch.isfinite(_output.to(torch.float32).sum()):
+                        logging.warning(f"The sum of {_name}.output is not finite")
+                except RuntimeError:  # e.g. CUDA out of memory
+                    pass
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
                     if isinstance(o, tuple):
                         o = o[0]
                     if not isinstance(o, Tensor):
                         continue
-                    if not torch.isfinite(o.to(torch.float32).sum()):
-                        logging.warning(f"The sum of {_name}.output[{i}] is not finite")
+                    try:
+                        if not torch.isfinite(o.to(torch.float32).sum()):
+                            logging.warning(
+                                f"The sum of {_name}.output[{i}] is not finite"
+                            )
+                    except RuntimeError:  # e.g. CUDA out of memory
+                        pass
 
         # default param _name is a way to capture the current value of the variable "name".
         def backward_hook(_module, _input, _output, _name=name):
             if isinstance(_output, Tensor):
-                if not torch.isfinite(_output.to(torch.float32).sum()):
-                    logging.warning(
-                        f"The sum of {_name}.grad is not finite"  # ": {_output}"
-                    )
+                try:
+                    if not torch.isfinite(_output.to(torch.float32).sum()):
+                        logging.warning(f"The sum of {_name}.grad is not finite")
+                except RuntimeError:  # e.g. CUDA out of memory
+                    pass
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
                     if isinstance(o, tuple):
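The hook changes wrap each finiteness check in try/except RuntimeError: converting a large activation or gradient to float32 and summing it allocates memory and can itself fail (for example with CUDA out of memory), and a diagnostic hook should not bring down the run. The same guarded check, pulled out as a self-contained helper for illustration only (the helper name is hypothetical, not part of the patch):

    import logging

    import torch
    from torch import Tensor

    # Sketch of the guarded finiteness check used by both the forward and
    # backward hooks in the diff, factored into one function.
    def warn_if_not_finite(t: Tensor, label: str) -> None:
        try:
            if not torch.isfinite(t.to(torch.float32).sum()):
                logging.warning(f"The sum of {label} is not finite")
        except RuntimeError:  # e.g. CUDA out of memory while summing
            pass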