add batch shave mechanism

fix

fix
yfyeung 2025-05-12 16:49:42 +00:00
parent ea20ac208d
commit 06667e1f6d


@@ -362,6 +362,16 @@ def get_parser():
        help="Whether to use half precision training.",
    )
    parser.add_argument(
        "--shave-rate",
        type=float,
        default=0.1,
        help="""The fraction of utterances to drop from a batch when an OOM
        error occurs. If the OOM persists for the same batch, this factor
        is progressively multiplied by 1.5, capped at 0.5. Set to 0 to
        disable batch shaving.""",
    )
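
For reference, a minimal sketch of the retry schedule this flag implies, assuming the default value of 0.1 and the 1.5x growth with a 0.5 cap used later in this patch (illustrative only, not part of the recipe):

# Illustrative only: how the effective shave rate grows across
# successive OOM retries on the same batch (default --shave-rate=0.1).
shave_rate = 0.1
for retry in range(5):
    print(f"retry {retry}: drop {shave_rate * 100:.2f}% of the batch")
    shave_rate = min(shave_rate * 1.5, 0.5)
# Prints 10.00%, 15.00%, 22.50%, 33.75%, 50.00%.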
    parser.add_argument(
        "--use-aishell",
        type=str2bool,
@@ -627,6 +637,17 @@ def train_one_epoch(
      be set to 0.
    """
    def shave_batch(batch: dict, factor: float):
        """Drop the leading ``factor`` fraction of utterances from ``batch``.

        Returns False if shaving would leave the batch empty.
        """
        n_utt = len(batch["supervisions"]["text"])
        # Always drop at least one utterance.
        skip_point = max(1, int(factor * n_utt))
        if n_utt - skip_point <= 0:
            return False
        for key in batch["supervisions"].keys():
            batch["supervisions"][key] = batch["supervisions"][key][skip_point:]
        # Re-trim the padded inputs to the longest remaining utterance.
        max_len = max(batch["supervisions"]["num_frames"]).item()
        batch["inputs"] = batch["inputs"][skip_point:, :max_len]
        return True
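
To make the slicing concrete, here is a toy invocation, assuming shave_batch is in scope; the shapes and field values are assumptions for illustration (lhotse-style batches are typically sorted by descending duration, so the dropped head holds the longest utterances):

import torch

# Hypothetical batch of 4 utterances, padded to the longest (100 frames).
batch = {
    "inputs": torch.zeros(4, 100, 80),  # (num_utt, num_frames, feat_dim)
    "supervisions": {
        "text": ["a", "b", "c", "d"],
        "num_frames": torch.tensor([100, 90, 60, 50]),
    },
}
shave_batch(batch, factor=0.25)       # drops max(1, int(0.25 * 4)) = 1 utterance
print(batch["inputs"].shape)          # torch.Size([3, 90, 80])
print(batch["supervisions"]["text"])  # ['b', 'c', 'd']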
    def free_gpu_cache():
        gc.collect()
        if torch.cuda.is_available():
@@ -642,6 +663,7 @@ def train_one_epoch(
    for batch_idx, batch in enumerate(train_dl):
        params.batch_idx_train += 1
        batch_size = len(batch["supervisions"]["text"])
        if batch_idx % params.valid_interval == 0:
            logging.info("Computing validation loss")
            valid_info = compute_validation_loss(
@@ -663,6 +685,9 @@ def train_one_epoch(
            valid_info.write_summary(
                tb_writer, "train/valid_", params.batch_idx_train
            )
        shave_rate = params.shave_rate
        while True:
            try:
                loss, loss_info = compute_loss(
                    params=params,
@@ -671,21 +696,21 @@ def train_one_epoch(
                    batch=batch,
                    is_training=True,
                )
                # NOTE: We use reduction==sum and loss is computed over utterances
                # in the batch and there is no normalization to it so far.
                # deepspeed's backward() is different from torch's backward()
                # in that it does not accept a loss tensor as input.
                # It computes the loss internally.
                model.backward(loss)
                model.step()

                # summary stats
                tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info

                # finish this step
                break
            except Exception as e:
                logging.warning(f"Caught exception: {e}")
                if shave_rate <= 0 or (
                    "CUDA" not in str(e)
                    and "cuDNN error" not in str(e)
                    and "NCCL error" not in str(e)
@@ -693,14 +718,23 @@ def train_one_epoch(
                    display_and_save_batch(batch, params=params)
                    raise e
                loss = None
                loss_info = None
                free_gpu_cache()
                if shave_batch(batch, shave_rate):
                    logging.warning(
                        f"Epoch {params.cur_epoch}, "
                        f"batch {batch_idx}: batch reduced by {shave_rate * 100:.2f}%",
                    )
                    shave_rate = min(shave_rate * 1.5, 0.5)
                else:
                    raise RuntimeError(
                        f"Epoch {params.cur_epoch}, "
                        f"batch {batch_idx}: batch reduced to empty during retries"
                    )
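
Taken together, the new control flow is a shrink-and-retry loop around the training step. A condensed, self-contained sketch of the pattern, with a hypothetical train_step() standing in for compute_loss + backward + step:

# Sketch of the shrink-and-retry pattern above. train_step() is a
# hypothetical stand-in for compute_loss + backward + optimizer step;
# shave_batch/free_gpu_cache are the helpers defined earlier.
def run_with_shaving(batch, shave_rate, train_step, shave_batch, free_gpu_cache):
    while True:
        try:
            train_step(batch)
            break  # step succeeded; leave the retry loop
        except Exception as e:
            oom_like = any(
                s in str(e) for s in ("CUDA", "cuDNN error", "NCCL error")
            )
            if shave_rate <= 0 or not oom_like:
                raise  # shaving disabled, or not a GPU memory/runtime error
            free_gpu_cache()
            if not shave_batch(batch, shave_rate):
                raise RuntimeError("batch reduced to empty during retries")
            # Shave more aggressively on the next retry, capped at 50%.
            shave_rate = min(shave_rate * 1.5, 0.5)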
        if batch_idx % params.log_interval == 0:
            try:
                cur_lr = scheduler.get_last_lr()[0]