This commit is contained in:
yfyeung 2025-05-11 17:23:19 +00:00
parent 9939c2b72d
commit c078772e59

View File

@ -32,6 +32,7 @@ torchrun --nproc_per_node 8 ./zipformer_llm_zh/train.py \
""" """
import argparse import argparse
import gc
import logging import logging
import os import os
import warnings import warnings
@ -625,6 +626,12 @@ def train_one_epoch(
The rank of the node in DDP training. If no DDP is used, it should
be set to 0.
"""
def free_gpu_cache():
    """Free host-side garbage and return cached CUDA memory to the driver.

    Runs the Python garbage collector first so that tensors with no
    remaining references are actually destroyed, then asks PyTorch to
    release its cached allocator blocks (no-op on CPU-only machines).
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
model.train()
model.encoder.eval()
if not params.unfreeze_llm:
@ -688,9 +695,6 @@ def train_one_epoch(
batch=batch,
is_training=True,
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
@ -700,9 +704,26 @@ def train_one_epoch(
model.backward(loss)
model.step()
except: # noqa # summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
except Exception as e:
logging.warning(f"Caught exception: {e}")
if (
"CUDA" not in str(e)
and "cuDNN error" not in str(e)
and "NCCL error" not in str(e)
):
display_and_save_batch(batch, params=params)
raise e
try:
loss = None
loss_info = None
except:
pass
free_gpu_cache()
if batch_idx % params.log_interval == 0:
try: