From 46b9be31cc485b7f458f9547c240684a90ee44d1 Mon Sep 17 00:00:00 2001 From: yfyeung Date: Mon, 28 Apr 2025 04:37:57 -0700 Subject: [PATCH] remove batch shaving --- egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py index 9e1646808..63893ef0b 100755 --- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py @@ -71,7 +71,6 @@ from icefall.env import get_env_info from icefall.utils import ( AttributeDict, MetricsTracker, - filter_uneven_sized_batch, setup_logger, str2bool, ) @@ -278,13 +277,6 @@ def compute_loss( Returns: Return a tuple of two elements. The first element is the loss tensor. """ - # For the uneven-sized batch, the total duration after padding would possibly - # cause OOM. Hence, for each batch, which is sorted descendingly by length, - # we simply drop the last few shortest samples, so that the retained total frames - # (after padding) would not exceed `allowed_max_frames`: - # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`, - # where `max_frames = max_duration * 1000 // frame_shift_ms`. - # We set allowed_excess_duration_ratio=0.1. def preprocess( messages, @@ -375,10 +367,6 @@ def compute_loss( text = text.replace("?", "") return text - max_frames = params.max_duration * 1000 // params.frame_shift_ms - allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio)) - batch = filter_uneven_sized_batch(batch, allowed_max_frames) - device = next(model.parameters()).device feature = batch["inputs"]