add batch shave mechanism

fix

fix
This commit is contained in:
yfyeung 2025-05-12 16:49:42 +00:00
parent ea20ac208d
commit 06667e1f6d

View File

@ -362,6 +362,16 @@ def get_parser():
help="Whether to use half precision training.",
)
parser.add_argument(
"--shave-rate",
type=float,
default=0.1,
help="""The factor to reduce the batch when an OOM occurs.
If OOM persists for the same batch, this factor will be
progressively multiplied by 1.5. Set to 0 to disable.
""",
)
parser.add_argument(
"--use-aishell",
type=str2bool,
@ -627,6 +637,17 @@ def train_one_epoch(
be set to 0.
"""
def shave_batch(batch: dict, factor: float):
n_utt = len(batch["supervisions"]["text"])
skip_point = max(1, int(factor * n_utt))
if n_utt - skip_point <= 0:
return False
for key in batch["supervisions"].keys():
batch["supervisions"][key] = batch["supervisions"][key][skip_point:]
max_len = max(batch["supervisions"]["num_frames"]).item()
batch["inputs"] = batch["inputs"][skip_point:, :max_len]
return True
def free_gpu_cache():
gc.collect()
if torch.cuda.is_available():
@ -642,6 +663,7 @@ def train_one_epoch(
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
if batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
@ -663,44 +685,56 @@ def train_one_epoch(
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
try:
loss, loss_info = compute_loss(
params=params,
tokenizer=tokenizer,
model=model,
batch=batch,
is_training=True,
)
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
# deepspeed's backward() is different from torch's backward()
# in that it does not accept a loss tensor as input.
# It computes the loss internally.
model.backward(loss)
model.step()
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
except Exception as e:
logging.warning(f"Caught exception: {e}")
if (
"CUDA" not in str(e)
and "cuDNN error" not in str(e)
and "NCCL error" not in str(e)
):
display_and_save_batch(batch, params=params)
raise e
shave_rate = params.shave_rate
while True:
try:
loss = None
loss_info = None
except:
pass
loss, loss_info = compute_loss(
params=params,
tokenizer=tokenizer,
model=model,
batch=batch,
is_training=True,
)
# NOTE: we use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
# deepspeed's backward() is different from torch's backward()
model.backward(loss)
model.step()
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
# finish this step
break
except Exception as e:
logging.warning(f"Caught exception: {e}")
if shave_rate <= 0 or (
"CUDA" not in str(e)
and "cuDNN error" not in str(e)
and "NCCL error" not in str(e)
):
display_and_save_batch(batch, params=params)
raise e
loss = None
loss_info = None
free_gpu_cache()
if shave_batch(batch, shave_rate):
logging.warning(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}: {shave_rate * 100:.2f}% batch reduced",
)
shave_rate = min(shave_rate * 1.5, 0.5)
else:
raise RuntimeError(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}: batch reduced to empty in retry"
)
if batch_idx % params.log_interval == 0:
try:
cur_lr = scheduler.get_last_lr()[0]