add batch shave mechanism

When a CUDA/cuDNN/NCCL error (typically an OOM) is raised for a batch,
drop a fraction of its utterances and retry the step instead of failing
the whole run; the fraction is controlled by the new --shave-rate option.

commit 06667e1f6d (parent ea20ac208d)
@@ -362,6 +362,16 @@ def get_parser():
         help="Whether to use half precision training.",
     )
 
+    parser.add_argument(
+        "--shave-rate",
+        type=float,
+        default=0.1,
+        help="""The factor by which to reduce the batch when an OOM occurs.
+        If the OOM persists for the same batch, this factor is
+        progressively multiplied by 1.5. Set to 0 to disable.
+        """,
+    )
+
     parser.add_argument(
         "--use-aishell",
         type=str2bool,
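The help text describes a geometric back-off on the batch size. A minimal
sketch of that schedule (the batch size of 32 and the retry count are
assumptions for illustration; the 0.5 cap comes from the training loop below):

# Illustration of the --shave-rate schedule: each retry on the same batch
# drops a fraction of the utterances, and the fraction grows by 1.5x,
# capped at 0.5 (the cap applied in train_one_epoch below).
shave_rate = 0.1  # the option's default
n_utt = 32        # assumed batch size, for illustration only

for retry in range(4):  # pretend the OOM persists for four retries
    dropped = max(1, int(shave_rate * n_utt))
    n_utt -= dropped
    print(f"retry {retry}: dropped {dropped}, {n_utt} utterances left")
    shave_rate = min(shave_rate * 1.5, 0.5)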
@@ -627,6 +637,17 @@ def train_one_epoch(
         be set to 0.
     """
 
+    def shave_batch(batch: dict, factor: float):
+        n_utt = len(batch["supervisions"]["text"])
+        skip_point = max(1, int(factor * n_utt))
+        if n_utt - skip_point <= 0:
+            return False
+        for key in batch["supervisions"].keys():
+            batch["supervisions"][key] = batch["supervisions"][key][skip_point:]
+        max_len = max(batch["supervisions"]["num_frames"]).item()
+        batch["inputs"] = batch["inputs"][skip_point:, :max_len]
+        return True
+
     def free_gpu_cache():
         gc.collect()
         if torch.cuda.is_available():
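shave_batch drops utterances from the front of the batch and re-trims the
padded inputs to the new longest utterance, which appears to assume the
sampler orders utterances by decreasing duration, so the most memory-hungry
ones go first. A self-contained toy run (the tensor shapes and values here
are assumptions for illustration, not from the commit):

import torch

# Toy batch shaped like the one shave_batch expects: "inputs" is
# (num_utts, num_frames, feature_dim), and every entry under
# "supervisions" is indexable per utterance, longest utterance first.
batch = {
    "inputs": torch.randn(4, 100, 80),
    "supervisions": {
        "text": ["a", "b", "c", "d"],
        "num_frames": torch.tensor([100, 90, 80, 70]),
    },
}

def shave_batch(batch: dict, factor: float):
    n_utt = len(batch["supervisions"]["text"])
    skip_point = max(1, int(factor * n_utt))
    if n_utt - skip_point <= 0:
        return False
    for key in batch["supervisions"].keys():
        batch["supervisions"][key] = batch["supervisions"][key][skip_point:]
    max_len = max(batch["supervisions"]["num_frames"]).item()
    batch["inputs"] = batch["inputs"][skip_point:, :max_len]
    return True

assert shave_batch(batch, 0.25)       # drops 1 of the 4 utterances
print(batch["inputs"].shape)          # torch.Size([3, 90, 80])
print(batch["supervisions"]["text"])  # ['b', 'c', 'd']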
@@ -642,6 +663,7 @@ def train_one_epoch(
     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
+
         if batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
             valid_info = compute_validation_loss(
@@ -663,44 +685,56 @@ def train_one_epoch(
                 valid_info.write_summary(
                     tb_writer, "train/valid_", params.batch_idx_train
                 )
-        try:
-            loss, loss_info = compute_loss(
-                params=params,
-                tokenizer=tokenizer,
-                model=model,
-                batch=batch,
-                is_training=True,
-            )
-            # NOTE: We use reduction==sum and loss is computed over utterances
-            # in the batch and there is no normalization to it so far.
-
-            # deepspeed's backward() is different from torch's backward()
-            # in that it does not accept a loss tensor as input.
-            # It computes the loss internally.
-            model.backward(loss)
-            model.step()
-
-            # summary stats
-            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
-
-        except Exception as e:
-            logging.warning(f"Caught exception: {e}")
-            if (
-                "CUDA" not in str(e)
-                and "cuDNN error" not in str(e)
-                and "NCCL error" not in str(e)
-            ):
-                display_and_save_batch(batch, params=params)
-                raise e
-
-            try:
-                loss = None
-                loss_info = None
-            except:
-                pass
-
-            free_gpu_cache()
+        shave_rate = params.shave_rate
+        while True:
+            try:
+                loss, loss_info = compute_loss(
+                    params=params,
+                    tokenizer=tokenizer,
+                    model=model,
+                    batch=batch,
+                    is_training=True,
+                )
+                # NOTE: we use reduction==sum and loss is computed over utterances
+                # in the batch and there is no normalization to it so far.
+                # deepspeed's backward() is different from torch's backward()
+                model.backward(loss)
+                model.step()
+
+                # summary stats
+                tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+                # finish this step
+                break
+            except Exception as e:
+                logging.warning(f"Caught exception: {e}")
+                if shave_rate <= 0 or (
+                    "CUDA" not in str(e)
+                    and "cuDNN error" not in str(e)
+                    and "NCCL error" not in str(e)
+                ):
+                    display_and_save_batch(batch, params=params)
+                    raise e
+
+                loss = None
+                loss_info = None
+
+                free_gpu_cache()
+
+                if shave_batch(batch, shave_rate):
+                    logging.warning(
+                        f"Epoch {params.cur_epoch}, "
+                        f"batch {batch_idx}: {shave_rate * 100:.2f}% batch reduced",
+                    )
+                    shave_rate = min(shave_rate * 1.5, 0.5)
+                else:
+                    raise RuntimeError(
+                        f"Epoch {params.cur_epoch}, "
+                        f"batch {batch_idx}: batch reduced to empty in retry"
+                    )
 
         if batch_idx % params.log_interval == 0:
             try:
                 cur_lr = scheduler.get_last_lr()[0]
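The step logic is now a shrink-and-retry loop: on a GPU-related failure it
clears the cache, shaves the batch, and tries again with a larger shave
factor, while a non-GPU error, a disabled shave rate, or an emptied batch
still aborts. A minimal self-contained sketch of the same control flow
(run_step and the simplified "CUDA" check are placeholders, not the
commit's exact code):

import logging

def step_with_shave(batch, run_step, shave_batch, shave_rate: float):
    # run_step stands in for compute_loss + model.backward + model.step;
    # shave_batch is the helper added above.
    while True:
        try:
            run_step(batch)
            return  # the step succeeded, possibly on a shaved batch
        except Exception as e:
            logging.warning(f"Caught exception: {e}")
            # Retry only GPU-related failures, and only if shaving is enabled.
            if shave_rate <= 0 or "CUDA" not in str(e):
                raise
            if not shave_batch(batch, shave_rate):
                raise RuntimeError("batch reduced to empty in retry")
            shave_rate = min(shave_rate * 1.5, 0.5)  # back off more aggressively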