Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Save batch to disk on OOM. (#343)

* Save batch to disk on OOM.
* minor fixes
* Fixes after review.
* Fix style issues.

commit e1c3e98980 (parent 9ddbc681e7)
@@ -156,15 +156,16 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate. This value should not need to be changed.",
+        help="The initial learning rate. This value should not need to "
+        "be changed.",
     )
 
     parser.add_argument(
         "--lr-batches",
         type=float,
         default=5000,
-        help="""Number of steps that affects how rapidly the learning rate decreases.
-        We suggest not to change this.""",
+        help="""Number of steps that affects how rapidly the learning rate
+        decreases. We suggest not to change this.""",
     )
 
     parser.add_argument(
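The help-string change above relies on Python's implicit concatenation of adjacent string literals: the two shorter literals compile to exactly the original message, so only the source line length changes. A minimal standalone sketch of the technique, reusing the option from the diff:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--initial-lr",
    type=float,
    default=0.003,
    # Adjacent string literals are joined at compile time, so this is a
    # single help string even though the source respects the line limit.
    help="The initial learning rate. This value should not need to "
    "be changed.",
)

print(parser.parse_args([]).initial_lr)  # -> 0.003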
@@ -670,6 +671,7 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
 
+        try:
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
@@ -689,6 +691,9 @@ def train_one_epoch(
             scaler.step(optimizer)
             scaler.update()
             optimizer.zero_grad()
+        except:  # noqa
+            display_and_save_batch(batch, params=params, sp=sp)
+            raise
 
         if params.print_diagnostics and batch_idx == 5:
             return
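The two hunks above wrap the forward/backward pass in a try/except so that any failure, typically a CUDA out-of-memory error, first persists the offending batch via display_and_save_batch() (defined below) and then re-raises. A minimal self-contained sketch of the same pattern, with compute_loss and save_batch as hypothetical stand-ins for the icefall versions:

import uuid

import torch
import torch.nn as nn


def compute_loss(model: nn.Module, batch: dict) -> torch.Tensor:
    # Hypothetical stand-in for icefall's compute_loss().
    return model(batch["inputs"]).sum()


def save_batch(batch: dict, exp_dir: str = "exp") -> None:
    # Simplified stand-in for display_and_save_batch().
    torch.save(batch, f"{exp_dir}/batch-{uuid.uuid4()}.pt")


def train_step(model, batch, optimizer, scaler, use_fp16=True):
    try:
        with torch.cuda.amp.autocast(enabled=use_fp16):
            loss = compute_loss(model, batch)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
    except:  # noqa: E722 -- deliberately broad: save the batch, then re-raise
        save_batch(batch)
        raise

The bare except mirrors the commit: the batch is saved for any exception, not just OOM, and the re-raise preserves the original traceback.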
@@ -933,6 +938,38 @@ def run(rank, world_size, args):
         cleanup_dist()
 
 
+def display_and_save_batch(
+    batch: dict,
+    params: AttributeDict,
+    sp: spm.SentencePieceProcessor,
+) -> None:
+    """Display the batch statistics and save the batch into disk.
+
+    Args:
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      params:
+        Parameters for training. See :func:`get_params`.
+      sp:
+        The BPE model.
+    """
+    from lhotse.utils import uuid4
+
+    filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+    logging.info(f"Saving batch to {filename}")
+    torch.save(batch, filename)
+
+    supervisions = batch["supervisions"]
+    features = batch["inputs"]
+
+    logging.info(f"features shape: {features.shape}")
+
+    y = sp.encode(supervisions["text"], out_type=int)
+    num_tokens = sum(len(i) for i in y)
+    logging.info(f"num tokens: {num_tokens}")
+
+
 def scan_pessimistic_batches_for_oom(
     model: nn.Module,
     train_dl: torch.utils.data.DataLoader,
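Once a batch has been written out, it can be reloaded with torch.load() for offline inspection, e.g. to check whether an unusually long utterance caused the OOM. A sketch, assuming a hypothetical filename (substitute the path that was logged):

import torch

batch = torch.load("exp/batch-<uuid>.pt")  # hypothetical path from the log

features = batch["inputs"]            # padded feature tensor, (N, T, C)
supervisions = batch["supervisions"]  # texts plus per-utterance frame info

print(features.shape)
print(supervisions["text"][:3])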
@@ -973,6 +1010,7 @@ def scan_pessimistic_batches_for_oom(
                 f"Failing criterion: {criterion} "
                 f"(={crit_values[criterion]}) ..."
             )
+            display_and_save_batch(batch, params=params, sp=sp)
             raise
 
 
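scan_pessimistic_batches_for_oom() applies the same save-then-raise idea proactively: before real training starts it runs forward/backward on the most memory-hungry batches the sampler can produce, so an OOM surfaces immediately (with the batch saved and the failing criterion logged) instead of hours into training. A rough sketch of that idea, with all non-torch names hypothetical and save_batch as in the earlier sketch:

import torch


def scan_for_oom(model, batches, save_batch, use_fp16=True):
    # `batches` would be the batches built from the longest utterances,
    # i.e. the worst case the sampler can emit.
    scaler = torch.cuda.amp.GradScaler(enabled=use_fp16)
    for batch in batches:
        try:
            with torch.cuda.amp.autocast(enabled=use_fp16):
                loss = model(batch["inputs"]).sum()  # stand-in loss
            scaler.scale(loss).backward()
            model.zero_grad()
        except RuntimeError as e:
            if "CUDA out of memory" in str(e):
                save_batch(batch)  # persist the failing batch for inspection
            raise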