mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
skip OOM
This commit is contained in:
parent
9939c2b72d
commit
c078772e59
@@ -32,6 +32,7 @@ torchrun --nproc_per_node 8 ./zipformer_llm_zh/train.py \
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import gc
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
@@ -625,6 +626,12 @@ def train_one_epoch(
|
|||||||
The rank of the node in DDP training. If no DDP is used, it should
|
The rank of the node in DDP training. If no DDP is used, it should
|
||||||
be set to 0.
|
be set to 0.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def free_gpu_cache():
|
||||||
|
gc.collect()
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
model.train()
|
model.train()
|
||||||
model.encoder.eval()
|
model.encoder.eval()
|
||||||
if not params.unfreeze_llm:
|
if not params.unfreeze_llm:
|
||||||
@@ -688,9 +695,6 @@ def train_one_epoch(
|
|||||||
batch=batch,
|
batch=batch,
|
||||||
is_training=True,
|
is_training=True,
|
||||||
)
|
)
|
||||||
# summary stats
|
|
||||||
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
|
|
||||||
|
|
||||||
# NOTE: We use reduction==sum and loss is computed over utterances
|
# NOTE: We use reduction==sum and loss is computed over utterances
|
||||||
# in the batch and there is no normalization to it so far.
|
# in the batch and there is no normalization to it so far.
|
||||||
|
|
||||||
@@ -700,9 +704,26 @@ def train_one_epoch(
|
|||||||
model.backward(loss)
|
model.backward(loss)
|
||||||
model.step()
|
model.step()
|
||||||
|
|
||||||
except: # noqa
|
# summary stats
|
||||||
|
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f"Caught exception: {e}")
|
||||||
|
if (
|
||||||
|
"CUDA" not in str(e)
|
||||||
|
and "cuDNN error" not in str(e)
|
||||||
|
and "NCCL error" not in str(e)
|
||||||
|
):
|
||||||
display_and_save_batch(batch, params=params)
|
display_and_save_batch(batch, params=params)
|
||||||
raise
|
raise e
|
||||||
|
|
||||||
|
try:
|
||||||
|
loss = None
|
||||||
|
loss_info = None
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
free_gpu_cache()
|
||||||
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
try:
|
try:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user