From 06667e1f6d8eb45d30e883988e61106ae55e97ad Mon Sep 17 00:00:00 2001
From: yfyeung
Date: Mon, 12 May 2025 16:49:42 +0000
Subject: [PATCH] add batch shave mechanism

When a CUDA/cuDNN/NCCL OOM error is raised for a batch, progressively
drop utterances from that batch and retry instead of crashing. The
reduction factor is controlled by --shave-rate (0 disables the
mechanism) and grows by 1.5x on every consecutive failure, capped at
0.5.
---
 .../ASR_LLM/zipformer_llm_zh/train.py | 100 ++++++++++++------
 1 file changed, 67 insertions(+), 33 deletions(-)

diff --git a/egs/speech_llm/ASR_LLM/zipformer_llm_zh/train.py b/egs/speech_llm/ASR_LLM/zipformer_llm_zh/train.py
index eaae4a33e..441b4f266 100755
--- a/egs/speech_llm/ASR_LLM/zipformer_llm_zh/train.py
+++ b/egs/speech_llm/ASR_LLM/zipformer_llm_zh/train.py
@@ -362,6 +362,16 @@ def get_parser():
         help="Whether to use half precision training.",
     )
 
+    parser.add_argument(
+        "--shave-rate",
+        type=float,
+        default=0.1,
+        help="""The factor to reduce the batch when an OOM occurs.
+        If OOM persists for the same batch, this factor will be
+        progressively multiplied by 1.5. Set to 0 to disable.
+        """,
+    )
+
     parser.add_argument(
         "--use-aishell",
         type=str2bool,
@@ -627,6 +637,17 @@ def train_one_epoch(
         be set to 0.
     """
 
+    def shave_batch(batch: dict, factor: float):
+        n_utt = len(batch["supervisions"]["text"])
+        skip_point = max(1, int(factor * n_utt))
+        if n_utt - skip_point <= 0:
+            return False
+        for key in batch["supervisions"].keys():
+            batch["supervisions"][key] = batch["supervisions"][key][skip_point:]
+        max_len = max(batch["supervisions"]["num_frames"]).item()
+        batch["inputs"] = batch["inputs"][skip_point:, :max_len]
+        return True
+
     def free_gpu_cache():
         gc.collect()
         if torch.cuda.is_available():
@@ -642,6 +663,7 @@ def train_one_epoch(
     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
+
         if batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
             valid_info = compute_validation_loss(
@@ -663,44 +685,56 @@ def train_one_epoch(
                 valid_info.write_summary(
                     tb_writer, "train/valid_", params.batch_idx_train
                 )
 
-        try:
-            loss, loss_info = compute_loss(
-                params=params,
-                tokenizer=tokenizer,
-                model=model,
-                batch=batch,
-                is_training=True,
-            )
-            # NOTE: We use reduction==sum and loss is computed over utterances
-            # in the batch and there is no normalization to it so far.
-
-            # deepspeed's backward() is different from torch's backward()
-            # in that it does not accept a loss tensor as input.
-            # It computes the loss internally.
-            model.backward(loss)
-            model.step()
-
-            # summary stats
-            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
-
-        except Exception as e:
-            logging.warning(f"Caught exception: {e}")
-            if (
-                "CUDA" not in str(e)
-                and "cuDNN error" not in str(e)
-                and "NCCL error" not in str(e)
-            ):
-                display_and_save_batch(batch, params=params)
-                raise e
+        shave_rate = params.shave_rate
+        while True:
             try:
-                loss = None
-                loss_info = None
-            except:
-                pass
+                loss, loss_info = compute_loss(
+                    params=params,
+                    tokenizer=tokenizer,
+                    model=model,
+                    batch=batch,
+                    is_training=True,
+                )
+
+                # NOTE: we use reduction==sum and loss is computed over utterances
+                # in the batch and there is no normalization to it so far.
+
+                # deepspeed's backward() is different from torch's backward()
+                model.backward(loss)
+                model.step()
+
+                # summary stats
+                tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+                # finish this step
+                break
+            except Exception as e:
+                logging.warning(f"Caught exception: {e}")
+                if shave_rate <= 0 or (
+                    "CUDA" not in str(e)
+                    and "cuDNN error" not in str(e)
+                    and "NCCL error" not in str(e)
+                ):
+                    display_and_save_batch(batch, params=params)
+                    raise e
+
+                loss = None
+                loss_info = None
 
             free_gpu_cache()
 
+            if shave_batch(batch, shave_rate):
+                logging.warning(
+                    f"Epoch {params.cur_epoch}, "
+                    f"batch {batch_idx}: {shave_rate * 100:.2f}% batch reduced",
+                )
+                shave_rate = min(shave_rate * 1.5, 0.5)
+            else:
+                raise RuntimeError(
+                    f"Epoch {params.cur_epoch}, "
+                    f"batch {batch_idx}: batch reduced to empty in retry"
+                )
+
         if batch_idx % params.log_interval == 0:
             try:
                 cur_lr = scheduler.get_last_lr()[0]
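
For illustration (not part of the patch), here is a minimal standalone sketch of what shave_batch and the retry schedule do to a toy batch. The batch layout is an assumption: lhotse-style collation where "inputs" is an (N, T, F) tensor padded to the longest utterance and utterances are sorted longest-first, so slicing from the front drops the longest, most memory-hungry ones; the texts and frame counts below are made-up values.

    # Illustrative sketch only -- not part of the patch.
    import torch

    def shave_batch(batch: dict, factor: float) -> bool:
        # Same logic as the helper added in the patch: drop the first
        # (longest) `factor` fraction of utterances in place, then re-trim
        # the time padding of "inputs" to the new longest utterance.
        n_utt = len(batch["supervisions"]["text"])
        skip_point = max(1, int(factor * n_utt))
        if n_utt - skip_point <= 0:
            return False
        for key in batch["supervisions"].keys():
            batch["supervisions"][key] = batch["supervisions"][key][skip_point:]
        max_len = max(batch["supervisions"]["num_frames"]).item()
        batch["inputs"] = batch["inputs"][skip_point:, :max_len]
        return True

    batch = {
        "inputs": torch.zeros(8, 100, 80),  # 8 utterances, 100 frames, 80-dim features
        "supervisions": {
            "text": [f"utt-{i}" for i in range(8)],
            "num_frames": torch.tensor([100, 90, 80, 70, 60, 50, 40, 30]),
        },
    }

    # Simulate repeated OOMs on this batch: shave, then grow the rate by
    # 1.5x capped at 0.5, exactly as the retry loop in the patch does.
    shave_rate = 0.1
    while shave_batch(batch, shave_rate):
        n, t = batch["inputs"].shape[:2]
        print(f"kept {n} utterances, padded to {t} frames")
        shave_rate = min(shave_rate * 1.5, 0.5)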
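And a back-of-envelope check of the schedule itself (again illustrative, ignoring the integer flooring and the max(1, ...) floor in shave_batch): the fraction of the original batch that survives k consecutive OOM retries with the defaults above.

    # Surviving fraction after k OOM retries: rate 0.1, x1.5 per retry, cap 0.5.
    rate, kept = 0.1, 1.0
    for k in range(1, 7):
        kept *= 1.0 - rate
        rate = min(rate * 1.5, 0.5)
        print(f"after {k} retries: ~{kept:.0%} of the utterances remain")

A persistently failing batch thus shrinks to well under half its size within four or five retries; once shave_batch can no longer drop anything, the RuntimeError in the retry loop surfaces instead of looping forever.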