Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-08 17:42:21 +00:00)
skip OOM

parent 9939c2b72d
commit c078772e59
@@ -32,6 +32,7 @@ torchrun --nproc_per_node 8 ./zipformer_llm_zh/train.py \
 """
 
 import argparse
+import gc
 import logging
 import os
 import warnings
@@ -625,6 +626,12 @@ def train_one_epoch(
         The rank of the node in DDP training. If no DDP is used, it should
         be set to 0.
     """
+
+    def free_gpu_cache():
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
     model.train()
     model.encoder.eval()
     if not params.unfreeze_llm:
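
The helper added here pairs Python's garbage collector with PyTorch's allocator cache release: gc.collect() first destroys unreferenced CUDA tensors, and torch.cuda.empty_cache() then returns their cached blocks to the driver. A minimal standalone sketch of the same idea (the function body mirrors the diff; the demo allocation around it is illustrative only):

    import gc

    import torch


    def free_gpu_cache():
        # Collect first so unreferenced CUDA tensors are destroyed;
        # only then can empty_cache() release their blocks.
        gc.collect()
        if torch.cuda.is_available():
            # Returns cached allocator blocks to the driver. Live tensors
            # are untouched; this only shrinks PyTorch's internal cache.
            torch.cuda.empty_cache()


    if torch.cuda.is_available():
        x = torch.randn(1024, 1024, device="cuda")  # illustrative allocation
        del x
        free_gpu_cache()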
@@ -688,9 +695,6 @@ def train_one_epoch(
                 batch=batch,
                 is_training=True,
             )
-            # summary stats
-            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
-
             # NOTE: We use reduction==sum and loss is computed over utterances
             # in the batch and there is no normalization to it so far.
 
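
This hunk deletes the running-stats update from its old position right after the forward pass; the next hunk re-adds it after model.step(), so stats are folded in only once the whole step has succeeded, and a skipped batch leaves them untouched. The update itself is a simple exponential decay; below, a float stand-in for icefall's MetricsTracker (the reset_interval value is illustrative):

    reset_interval = 200  # illustrative; the script keeps this in params
    tot_loss = 0.0
    for loss_info in (4.0, 3.5, 3.2):  # stand-in per-batch losses
        # Keep (1 - 1/reset_interval) of the accumulated stats and add the
        # current batch, so older batches decay away geometrically.
        tot_loss = tot_loss * (1 - 1 / reset_interval) + loss_info
    print(tot_loss)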
@@ -700,9 +704,26 @@ def train_one_epoch(
                 model.backward(loss)
                 model.step()
 
-        except:  # noqa
-            display_and_save_batch(batch, params=params)
-            raise
+            # summary stats
+            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+        except Exception as e:
+            logging.warning(f"Caught exception: {e}")
+            if (
+                "CUDA" not in str(e)
+                and "cuDNN error" not in str(e)
+                and "NCCL error" not in str(e)
+            ):
+                display_and_save_batch(batch, params=params)
+                raise e
+
+            try:
+                loss = None
+                loss_info = None
+            except:
+                pass
+
+            free_gpu_cache()
 
         if batch_idx % params.log_interval == 0:
             try:
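
Taken together, the commit swaps the bare re-raising except for one that treats CUDA, cuDNN, and NCCL failures as recoverable: log the error, drop references from the failed step, empty the cache, and continue with the next batch; any other exception still saves the batch for inspection and re-raises. A self-contained sketch of that pattern, not the icefall code itself; model, optimizer, dataloader, and compute_loss are hypothetical stand-ins for the script's own objects:

    import gc
    import logging

    import torch


    def free_gpu_cache():
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


    def train_one_epoch(model, optimizer, dataloader, compute_loss):
        for batch_idx, batch in enumerate(dataloader):
            try:
                loss = compute_loss(model, batch)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
            except Exception as e:  # noqa
                logging.warning(f"Caught exception: {e}")
                if (
                    "CUDA" not in str(e)
                    and "cuDNN error" not in str(e)
                    and "NCCL error" not in str(e)
                ):
                    raise  # a genuine bug: surface it instead of skipping
                # Drop references from the failed step so gc can reclaim
                # the tensors, then return cached blocks to the driver.
                loss = None
                free_gpu_cache()

One design note: binding a name to None cannot raise, so the try/except wrapped around the assignments in the committed hunk is defensive but effectively a no-op; the sketch assigns directly.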