Add function display_and_save_batch in wenetspeech/pruned_transducer_stateless2/train.py (#528)

* Add function display_and_save_batch in egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py

* Modify function: display_and_save_batch

* Delete empty line in pruned_transducer_stateless2/train.py

* Modify code format
This commit is contained in:
yangsuxia 2022-08-13 11:09:54 +08:00 committed by GitHub
parent 5c17255eec
commit 951b03f6d7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -701,6 +701,7 @@ def train_one_epoch(
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.cuda.amp.autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
@ -720,6 +721,9 @@ def train_one_epoch(
scaler.step(optimizer) scaler.step(optimizer)
scaler.update() scaler.update()
optimizer.zero_grad() optimizer.zero_grad()
except: # noqa
display_and_save_batch(batch, params=params)
raise
if params.print_diagnostics and batch_idx == 5: if params.print_diagnostics and batch_idx == 5:
return return
@ -958,6 +962,35 @@ def run(rank, world_size, args):
cleanup_dist() cleanup_dist()
def display_and_save_batch(
batch: dict,
params: AttributeDict,
) -> None:
"""Display the batch statistics and save the batch into disk.
Args:
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
params:
Parameters for training. See :func:`get_params`.
"""
from lhotse.utils import uuid4
filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
logging.info(f"Saving batch to {filename}")
torch.save(batch, filename)
features = batch["inputs"]
logging.info(f"features shape: {features.shape}")
texts = batch["supervisions"]["text"]
num_tokens = sum(len(i) for i in texts)
logging.info(f"num tokens: {num_tokens}")
def scan_pessimistic_batches_for_oom( def scan_pessimistic_batches_for_oom(
model: nn.Module, model: nn.Module,
train_dl: torch.utils.data.DataLoader, train_dl: torch.utils.data.DataLoader,
@ -998,6 +1031,7 @@ def scan_pessimistic_batches_for_oom(
f"Failing criterion: {criterion} " f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..." f"(={crit_values[criterion]}) ..."
) )
display_and_save_batch(batch, params=params)
raise raise