yfyeung 2025-05-11 17:23:19 +00:00
parent 9939c2b72d
commit c078772e59


@@ -32,6 +32,7 @@ torchrun --nproc_per_node 8 ./zipformer_llm_zh/train.py \
 """

 import argparse
+import gc
 import logging
 import os
 import warnings
@@ -625,6 +626,12 @@ def train_one_epoch(
         The rank of the node in DDP training. If no DDP is used, it should
         be set to 0.
     """
+
+    def free_gpu_cache():
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
     model.train()
     model.encoder.eval()
     if not params.unfreeze_llm:
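The new helper pairs Python garbage collection with PyTorch's allocator cache release. torch.cuda.empty_cache() only returns blocks that the caching allocator holds for dead tensors, so collecting Python garbage first matters. A standalone sketch of that behavior (illustrative only, not part of this commit):

import gc

import torch

if torch.cuda.is_available():
    x = torch.randn(1024, 1024, device="cuda")  # reserve ~4 MB via the caching allocator
    del x          # the tensor dies, but its block stays cached inside PyTorch
    gc.collect()   # also clear anything kept alive only by reference cycles
    print(torch.cuda.memory_allocated())  # 0: no live tensors
    print(torch.cuda.memory_reserved())   # > 0: the cache still holds the block
    torch.cuda.empty_cache()              # hand cached blocks back to the driver
    print(torch.cuda.memory_reserved())   # now (close to) 0

Calling this once per batch trades some allocator overhead for lower peak fragmentation, which is presumably the motivation when a large LLM barely fits in memory.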
@@ -688,9 +695,6 @@ def train_one_epoch(
                     batch=batch,
                     is_training=True,
                 )
-            # summary stats
-            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
-
             # NOTE: We use reduction==sum and loss is computed over utterances
             # in the batch and there is no normalization to it so far.
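The summary-stats update (removed here and re-added after model.step() in the next hunk) is an exponentially decayed running sum: each step multiplies the old statistics by 1 - 1/reset_interval, so a batch seen k steps ago carries weight (1 - 1/reset_interval)**k and roughly the most recent reset_interval batches dominate. A scalar sketch of the same recurrence (in the real code loss_info is a stats-tracking object, not a float; the value of reset_interval below is hypothetical):

reset_interval = 200
decay = 1 - 1 / reset_interval

tot = 0.0
for batch_loss in [4.0, 3.5, 3.2, 3.0]:
    tot = tot * decay + batch_loss  # same shape as the tot_loss update above

# after k further steps a batch's weight has shrunk to decay**k;
# e.g. 0.995**200 ~= 0.37, so old batches fade rather than reset abruptly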
@@ -700,9 +704,26 @@ def train_one_epoch(
             model.backward(loss)
             model.step()
-        except:  # noqa
-            display_and_save_batch(batch, params=params)
-            raise
+
+            # summary stats
+            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+        except Exception as e:
+            logging.warning(f"Caught exception: {e}")
+            if (
+                "CUDA" not in str(e)
+                and "cuDNN error" not in str(e)
+                and "NCCL error" not in str(e)
+            ):
+                display_and_save_batch(batch, params=params)
+            raise e
+
+        try:
+            loss = None
+            loss_info = None
+        except:
+            pass
+        free_gpu_cache()
+
         if batch_idx % params.log_interval == 0:
             try:
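Taken together, the new error path does three things: it logs the exception, it skips display_and_save_batch for device-level failures (after a CUDA, cuDNN, or NCCL error the GPU context is typically unusable, so trying to save the batch would often just raise again), and it still re-raises. The trailing block then drops the loss/loss_info references so the tensors and autograd state they pin can be collected before free_gpu_cache() runs; the try/except: pass around two plain assignments is purely defensive and cannot actually fail. A condensed sketch of the same recovery pattern, with run_step and save_debug_batch as hypothetical stand-ins for compute_loss/model.step and display_and_save_batch, and the cleanup expressed via finally:

import gc
import logging

import torch


def free_gpu_cache() -> None:
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def guarded_step(run_step, save_debug_batch, batch):
    loss = None
    try:
        loss = run_step(batch)  # stand-in for forward/backward/step
    except Exception as e:
        logging.warning(f"Caught exception: {e}")
        # Only save the batch for host-side errors; device-level failures
        # (CUDA OOM, cuDNN, NCCL) usually leave the GPU context unusable.
        if all(s not in str(e) for s in ("CUDA", "cuDNN error", "NCCL error")):
            save_debug_batch(batch)
        raise
    finally:
        loss = None       # drop tensor references so memory can be reclaimed
        free_gpu_cache()  # then release cached blocks back to the driver

Note that the commit also narrows the old bare except: to except Exception as e, which stops swallowing KeyboardInterrupt and SystemExit along the way.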