From c078772e59e797754ec7c4e891f1f17aa2c82316 Mon Sep 17 00:00:00 2001
From: yfyeung
Date: Sun, 11 May 2025 17:23:19 +0000
Subject: [PATCH] skip OOM

---
 .../ASR_LLM/zipformer_llm_zh/train.py | 33 +++++++++++++++----
 1 file changed, 27 insertions(+), 6 deletions(-)

diff --git a/egs/speech_llm/ASR_LLM/zipformer_llm_zh/train.py b/egs/speech_llm/ASR_LLM/zipformer_llm_zh/train.py
index 5d47f128a..2ace5c532 100755
--- a/egs/speech_llm/ASR_LLM/zipformer_llm_zh/train.py
+++ b/egs/speech_llm/ASR_LLM/zipformer_llm_zh/train.py
@@ -32,6 +32,7 @@ torchrun --nproc_per_node 8 ./zipformer_llm_zh/train.py \
 """
 
 import argparse
+import gc
 import logging
 import os
 import warnings
@@ -625,6 +626,12 @@ def train_one_epoch(
         The rank of the node in DDP training. If no DDP is used, it should
         be set to 0.
     """
+
+    def free_gpu_cache():
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
     model.train()
     model.encoder.eval()
     if not params.unfreeze_llm:
@@ -688,9 +695,6 @@ def train_one_epoch(
                     batch=batch,
                     is_training=True,
                 )
-            # summary stats
-            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
-
             # NOTE: We use reduction==sum and loss is computed over utterances
             # in the batch and there is no normalization to it so far.
 
@@ -700,9 +704,26 @@ def train_one_epoch(
             model.backward(loss)
             model.step()
 
-        except:  # noqa
-            display_and_save_batch(batch, params=params)
-            raise
+            # summary stats
+            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+        except Exception as e:
+            logging.warning(f"Caught exception: {e}")
+            if (
+                "CUDA" not in str(e)
+                and "cuDNN error" not in str(e)
+                and "NCCL error" not in str(e)
+            ):
+                display_and_save_batch(batch, params=params)
+                raise e
+
+            try:
+                loss = None
+                loss_info = None
+            except:
+                pass
+
+            free_gpu_cache()
 
         if batch_idx % params.log_interval == 0:
             try:
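
Note (not part of the patch): the change above makes the training loop tolerate GPU-side failures such as CUDA out-of-memory instead of crashing the run. When the caught exception message mentions CUDA, cuDNN, or NCCL, the batch is skipped: references to the partially computed loss are dropped, the new free_gpu_cache() helper runs gc.collect() and torch.cuda.empty_cache(), and training continues; any other exception still saves the offending batch via display_and_save_batch() and is re-raised. The sketch below shows the same pattern in a self-contained, plain-PyTorch loop; the model, optimizer, criterion, and dataloader names are hypothetical placeholders, not the recipe's objects (the patched script instead drives model.backward()/model.step() on what appears to be a DeepSpeed-style engine).

```python
# Minimal sketch of the OOM-skip pattern introduced by the patch, written as a
# plain-PyTorch loop. All names here (train_one_epoch, criterion, train_dl,
# optimizer) are hypothetical placeholders, not the icefall recipe's objects.
import gc
import logging

import torch


def free_gpu_cache():
    # Drop unreachable Python objects first, then return cached blocks to the
    # CUDA allocator so the next batch can reuse the freed memory.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def train_one_epoch(model, optimizer, criterion, train_dl, device="cuda"):
    model.train()
    for batch_idx, (x, y) in enumerate(train_dl):
        try:
            x, y = x.to(device), y.to(device)
            loss = criterion(model(x), y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        except Exception as e:
            logging.warning(f"Caught exception at batch {batch_idx}: {e}")
            # Re-raise anything that is not a GPU / communication failure so
            # genuine bugs still stop the run.
            if (
                "CUDA" not in str(e)
                and "cuDNN error" not in str(e)
                and "NCCL error" not in str(e)
            ):
                raise
            # Skip this batch: drop references from the failed step, free the
            # cached GPU memory, and continue with the next batch.
            loss = None
            free_gpu_cache()
```

Matching on the exception message string is a pragmatic way to cover torch.cuda.OutOfMemoryError and related runtime or communication errors without enumerating exception types, at the cost of occasionally swallowing an unrelated error whose message happens to mention CUDA.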