From 951b03f6d74ba11b9e79143685fc42fc6e0a7805 Mon Sep 17 00:00:00 2001 From: yangsuxia <34536059+yangsuxia@users.noreply.github.com> Date: Sat, 13 Aug 2022 11:09:54 +0800 Subject: [PATCH] Add function display_and_save_batch in wenetspeech/pruned_transducer_stateless2/train.py (#528) * Add function display_and_save_batch in egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py * Modify function: display_and_save_batch * Delete empty line in pruned_transducer_stateless2/train.py * Modify code format --- .../ASR/pruned_transducer_stateless2/train.py | 70 ++++++++++++++----- 1 file changed, 52 insertions(+), 18 deletions(-) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py index faf25eda1..5208dbefe 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py @@ -701,25 +701,29 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - warmup=(params.batch_idx_train / params.model_warm_step), - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=True, + warmup=(params.batch_idx_train / params.model_warm_step), + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params) + raise if params.print_diagnostics and batch_idx == 5: return @@ -958,6 +962,35 @@ def run(rank, world_size, args): cleanup_dist() +def display_and_save_batch( + batch: dict, + params: AttributeDict, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + texts = batch["supervisions"]["text"] + num_tokens = sum(len(i) for i in texts) + + logging.info(f"num tokens: {num_tokens}") + + def scan_pessimistic_batches_for_oom( model: nn.Module, train_dl: torch.utils.data.DataLoader, @@ -998,6 +1031,7 @@ def scan_pessimistic_batches_for_oom( f"Failing criterion: {criterion} " f"(={crit_values[criterion]}) ..." ) + display_and_save_batch(batch, params=params) raise