Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-09 01:52:41 +00:00
Add function display_and_save_batch in wenetspeech/pruned_transducer_stateless2/train.py (#528)
* Add function display_and_save_batch in egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
* Modify function: display_and_save_batch
* Delete empty line in pruned_transducer_stateless2/train.py
* Modify code format
This commit is contained in:
parent 5c17255eec
commit 951b03f6d7
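In essence, the commit wraps the forward/backward step of train_one_epoch in a try/except so that the batch which triggered a failure (for example a CUDA OOM) is written to disk before the exception propagates, and the same helper is called from the pre-training OOM scan. A minimal, self-contained sketch of that pattern is shown below; dump_batch and training_step are hypothetical stand-ins for the real functions, and the actual diff follows.

    import logging

    import torch


    def dump_batch(batch: dict, exp_dir: str) -> None:
        # Hypothetical stand-in for display_and_save_batch(): persist the batch
        # so it can be inspected offline after a crash.
        filename = f"{exp_dir}/batch-debug.pt"
        logging.info(f"Saving batch to {filename}")
        torch.save(batch, filename)


    def training_step(compute_step, batch: dict, exp_dir: str = "exp") -> None:
        try:
            compute_step(batch)  # forward + backward + optimizer update
        except Exception:
            # Save the offending batch, then re-raise so training still stops.
            dump_batch(batch, exp_dir)
            raise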
@@ -701,25 +701,29 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
 
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
-            loss, loss_info = compute_loss(
-                params=params,
-                model=model,
-                graph_compiler=graph_compiler,
-                batch=batch,
-                is_training=True,
-                warmup=(params.batch_idx_train / params.model_warm_step),
-            )
-        # summary stats
-        tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, loss_info = compute_loss(
+                    params=params,
+                    model=model,
+                    graph_compiler=graph_compiler,
+                    batch=batch,
+                    is_training=True,
+                    warmup=(params.batch_idx_train / params.model_warm_step),
+                )
+            # summary stats
+            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
 
-        # NOTE: We use reduction==sum and loss is computed over utterances
-        # in the batch and there is no normalization to it so far.
-        scaler.scale(loss).backward()
-        scheduler.step_batch(params.batch_idx_train)
-        scaler.step(optimizer)
-        scaler.update()
-        optimizer.zero_grad()
+            # NOTE: We use reduction==sum and loss is computed over utterances
+            # in the batch and there is no normalization to it so far.
+            scaler.scale(loss).backward()
+            scheduler.step_batch(params.batch_idx_train)
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad()
+        except:  # noqa
+            display_and_save_batch(batch, params=params)
+            raise
 
         if params.print_diagnostics and batch_idx == 5:
             return
@@ -958,6 +962,35 @@ def run(rank, world_size, args):
         cleanup_dist()
 
 
+def display_and_save_batch(
+    batch: dict,
+    params: AttributeDict,
+) -> None:
+    """Display the batch statistics and save the batch into disk.
+
+    Args:
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      params:
+        Parameters for training. See :func:`get_params`.
+    """
+    from lhotse.utils import uuid4
+
+    filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+    logging.info(f"Saving batch to {filename}")
+    torch.save(batch, filename)
+
+    features = batch["inputs"]
+
+    logging.info(f"features shape: {features.shape}")
+
+    texts = batch["supervisions"]["text"]
+    num_tokens = sum(len(i) for i in texts)
+
+    logging.info(f"num tokens: {num_tokens}")
+
+
 def scan_pessimistic_batches_for_oom(
     model: nn.Module,
     train_dl: torch.utils.data.DataLoader,
@@ -998,6 +1031,7 @@ def scan_pessimistic_batches_for_oom(
                     f"Failing criterion: {criterion} "
                     f"(={crit_values[criterion]}) ..."
                 )
+            display_and_save_batch(batch, params=params)
             raise
 
 
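A batch dumped by display_and_save_batch can be inspected offline with plain PyTorch. The snippet below is a minimal sketch and not part of the commit; the file name is a placeholder for the actual batch-<uuid>.pt written under params.exp_dir, and the dict keys ("inputs" and "supervisions") follow lhotse's K2SpeechRecognitionDataset as used above.

    import torch

    # Load a batch saved by display_and_save_batch(); replace the file name
    # with the actual batch-<uuid>.pt found in the experiment directory.
    batch = torch.load("exp/batch-debug.pt")

    features = batch["inputs"]             # (N, T, C) acoustic features
    texts = batch["supervisions"]["text"]  # transcripts of the failing batch

    print(features.shape)
    print(texts[:3])
    print("num tokens:", sum(len(t) for t in texts))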