mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
Save batch to disk on OOM. (#343)
* Save batch to disk on OOM. * minor fixes * Fixes after review. * Fix style issues.
This commit is contained in:
parent
9ddbc681e7
commit
e1c3e98980
@ -156,15 +156,16 @@ def get_parser():
|
||||
"--initial-lr",
|
||||
type=float,
|
||||
default=0.003,
|
||||
help="The initial learning rate. This value should not need to be changed.",
|
||||
help="The initial learning rate. This value should not need to "
|
||||
"be changed.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lr-batches",
|
||||
type=float,
|
||||
default=5000,
|
||||
help="""Number of steps that affects how rapidly the learning rate decreases.
|
||||
We suggest not to change this.""",
|
||||
help="""Number of steps that affects how rapidly the learning rate
|
||||
decreases. We suggest not to change this.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -670,25 +671,29 @@ def train_one_epoch(
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
batch=batch,
|
||||
is_training=True,
|
||||
warmup=(params.batch_idx_train / params.model_warm_step),
|
||||
)
|
||||
# summary stats
|
||||
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
batch=batch,
|
||||
is_training=True,
|
||||
warmup=(params.batch_idx_train / params.model_warm_step),
|
||||
)
|
||||
# summary stats
|
||||
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
|
||||
|
||||
# NOTE: We use reduction==sum and loss is computed over utterances
|
||||
# in the batch and there is no normalization to it so far.
|
||||
scaler.scale(loss).backward()
|
||||
scheduler.step_batch(params.batch_idx_train)
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
optimizer.zero_grad()
|
||||
# NOTE: We use reduction==sum and loss is computed over utterances
|
||||
# in the batch and there is no normalization to it so far.
|
||||
scaler.scale(loss).backward()
|
||||
scheduler.step_batch(params.batch_idx_train)
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
optimizer.zero_grad()
|
||||
except: # noqa
|
||||
display_and_save_batch(batch, params=params, sp=sp)
|
||||
raise
|
||||
|
||||
if params.print_diagnostics and batch_idx == 5:
|
||||
return
|
||||
@ -933,6 +938,38 @@ def run(rank, world_size, args):
|
||||
cleanup_dist()
|
||||
|
||||
|
||||
def display_and_save_batch(
|
||||
batch: dict,
|
||||
params: AttributeDict,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
) -> None:
|
||||
"""Display the batch statistics and save the batch into disk.
|
||||
|
||||
Args:
|
||||
batch:
|
||||
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
|
||||
for the content in it.
|
||||
params:
|
||||
Parameters for training. See :func:`get_params`.
|
||||
sp:
|
||||
The BPE model.
|
||||
"""
|
||||
from lhotse.utils import uuid4
|
||||
|
||||
filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
|
||||
logging.info(f"Saving batch to {filename}")
|
||||
torch.save(batch, filename)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
features = batch["inputs"]
|
||||
|
||||
logging.info(f"features shape: {features.shape}")
|
||||
|
||||
y = sp.encode(supervisions["text"], out_type=int)
|
||||
num_tokens = sum(len(i) for i in y)
|
||||
logging.info(f"num tokens: {num_tokens}")
|
||||
|
||||
|
||||
def scan_pessimistic_batches_for_oom(
|
||||
model: nn.Module,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
@ -973,6 +1010,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
f"Failing criterion: {criterion} "
|
||||
f"(={crit_values[criterion]}) ..."
|
||||
)
|
||||
display_and_save_batch(batch, params=params, sp=sp)
|
||||
raise
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user