Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-09 10:02:22 +00:00
add batch shave mechanism

fix fix

parent ea20ac208d
commit 06667e1f6d
@@ -362,6 +362,16 @@ def get_parser():
         help="Whether to use half precision training.",
     )
 
+    parser.add_argument(
+        "--shave-rate",
+        type=float,
+        default=0.1,
+        help="""The fraction of utterances to drop from the batch when an
+        OOM occurs. If the OOM persists for the same batch, this factor is
+        progressively multiplied by 1.5. Set to 0 to disable.
+        """,
+    )
+
     parser.add_argument(
         "--use-aishell",
         type=str2bool,
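For intuition, here is how the shave rate grows across successive OOM retries on the same batch. This is a standalone sketch, not part of the diff (the helper name shave_schedule is made up for illustration); it assumes the default --shave-rate of 0.1 and the 1.5x growth with a 0.5 cap introduced in this commit:

    def shave_schedule(n_utt: int, rate: float = 0.1) -> list:
        """Batch size after each successive OOM retry, mirroring the diff:
        stop once shaving would leave the batch empty."""
        sizes = []
        while True:
            skip = max(1, int(rate * n_utt))
            if n_utt - skip <= 0:  # shave_batch() would return False here
                break
            n_utt -= skip
            sizes.append(n_utt)
            rate = min(rate * 1.5, 0.5)  # same growth rule as the diff
        return sizes

    print(shave_schedule(32))  # [29, 25, 20, 14, 7, 4, 2, 1]

So a 32-utterance batch survives at most 8 shaves before the retry loop gives up with a RuntimeError.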
@@ -627,6 +637,17 @@ def train_one_epoch(
       be set to 0.
     """
 
+    def shave_batch(batch: dict, factor: float):
+        n_utt = len(batch["supervisions"]["text"])
+        skip_point = max(1, int(factor * n_utt))
+        if n_utt - skip_point <= 0:
+            return False
+        for key in batch["supervisions"].keys():
+            batch["supervisions"][key] = batch["supervisions"][key][skip_point:]
+        max_len = max(batch["supervisions"]["num_frames"]).item()
+        batch["inputs"] = batch["inputs"][skip_point:, :max_len]
+        return True
+
     def free_gpu_cache():
         gc.collect()
         if torch.cuda.is_available():
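To see what shave_batch does, here is a toy example with made-up shapes. It assumes the helper is lifted out of train_one_epoch, and that utterances in the batch are ordered by decreasing num_frames, which is what the function implicitly relies on: dropping from the front removes the longest utterances, so "inputs" can be re-truncated to the new maximum length:

    import torch

    # Toy lhotse-style batch: 4 utterances, padded to the longest (10 frames).
    # Shapes and values are illustrative only.
    batch = {
        "inputs": torch.randn(4, 10, 80),  # (num_utt, num_frames, feat_dim)
        "supervisions": {
            "text": ["a", "b", "c", "d"],
            "num_frames": torch.tensor([10, 8, 6, 5]),  # longest first
        },
    }

    assert shave_batch(batch, factor=0.25)  # drops max(1, int(0.25 * 4)) = 1
    print(len(batch["supervisions"]["text"]))  # 3
    print(batch["inputs"].shape)               # torch.Size([3, 8, 80])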
@@ -642,6 +663,7 @@ def train_one_epoch(
     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
+
         if batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
             valid_info = compute_validation_loss(
@@ -663,44 +685,56 @@ def train_one_epoch(
             valid_info.write_summary(
                 tb_writer, "train/valid_", params.batch_idx_train
             )
-        try:
-            loss, loss_info = compute_loss(
-                params=params,
-                tokenizer=tokenizer,
-                model=model,
-                batch=batch,
-                is_training=True,
-            )
-            # NOTE: We use reduction==sum and loss is computed over utterances
-            # in the batch and there is no normalization to it so far.
-
-            # deepspeed's backward() is different from torch's backward()
-            # in that it does not accept a loss tensor as input.
-            # It computes the loss internally.
-            model.backward(loss)
-            model.step()
-
-            # summary stats
-            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
-
-        except Exception as e:
-            logging.warning(f"Caught exception: {e}")
-            if (
-                "CUDA" not in str(e)
-                and "cuDNN error" not in str(e)
-                and "NCCL error" not in str(e)
-            ):
-                display_and_save_batch(batch, params=params)
-            raise e
+        shave_rate = params.shave_rate
+        while True:
+            try:
+                loss, loss_info = compute_loss(
+                    params=params,
+                    tokenizer=tokenizer,
+                    model=model,
+                    batch=batch,
+                    is_training=True,
+                )
+                # NOTE: we use reduction==sum and loss is computed over utterances
+                # in the batch and there is no normalization to it so far.
+                # deepspeed's backward() is different from torch's backward()
+                # in that it does not accept a loss tensor as input.
+                # It computes the loss internally.
+                model.backward(loss)
+                model.step()
+
+                # summary stats
+                tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+                # finish this step
+                break
+            except Exception as e:
+                logging.warning(f"Caught exception: {e}")
+                if shave_rate <= 0 or (
+                    "CUDA" not in str(e)
+                    and "cuDNN error" not in str(e)
+                    and "NCCL error" not in str(e)
+                ):
+                    display_and_save_batch(batch, params=params)
+                    raise e
+
+                loss = None
+                loss_info = None
+
+                free_gpu_cache()
+
+                if shave_batch(batch, shave_rate):
+                    logging.warning(
+                        f"Epoch {params.cur_epoch}, "
+                        f"batch {batch_idx}: batch reduced by {shave_rate * 100:.2f}%",
+                    )
+                    shave_rate = min(shave_rate * 1.5, 0.5)
+                else:
+                    raise RuntimeError(
+                        f"Epoch {params.cur_epoch}, "
+                        f"batch {batch_idx}: batch reduced to empty in retry"
+                    )
 
         if batch_idx % params.log_interval == 0:
             try:
                 cur_lr = scheduler.get_last_lr()[0]
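Stripped of the icefall specifics, the new training step is a standard shrink-and-retry pattern for OOM recovery: catch the CUDA error, release references, empty the cache, shave the batch, and retry with a harsher rate. The hunk above is also cut off inside free_gpu_cache, which presumably finishes with torch.cuda.empty_cache(). A minimal generic sketch of the pattern (run_step and shrink are placeholder names, not icefall APIs):

    import gc
    import logging

    import torch

    def step_with_oom_retry(batch, run_step, shrink, rate: float = 0.1):
        """Run one training step, shaving `batch` and retrying on CUDA OOM.

        run_step(batch) does forward/backward/update; shrink(batch, rate)
        drops part of the batch in place and returns False once it would
        become empty. A sketch of the pattern, not icefall's actual API.
        """
        while True:
            try:
                return run_step(batch)
            except RuntimeError as e:
                if rate <= 0 or "CUDA" not in str(e):
                    raise  # not an OOM we can recover from by shaving
                logging.warning(f"OOM caught, shaving batch: {e}")
                gc.collect()  # drop dead references before emptying the cache
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                if not shrink(batch, rate):
                    raise RuntimeError("batch reduced to empty in retry") from e
                rate = min(rate * 1.5, 0.5)  # shave harder on the next retry

One caveat the diff handles explicitly: tensors referenced by local variables from the failed attempt keep GPU memory alive, which is presumably why it sets loss and loss_info to None before calling free_gpu_cache().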