From a6f7814019e6dc9ae336a3d59f209f2638ab0c7f Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Tue, 10 May 2022 15:55:33 +0800
Subject: [PATCH] Add diagnostics code.

---
 egs/librispeech/ASR/transducer_lstm/train.py | 69 +++++++++++++++++++-
 1 file changed, 66 insertions(+), 3 deletions(-)

diff --git a/egs/librispeech/ASR/transducer_lstm/train.py b/egs/librispeech/ASR/transducer_lstm/train.py
index 3468b20fb..2d520f230 100755
--- a/egs/librispeech/ASR/transducer_lstm/train.py
+++ b/egs/librispeech/ASR/transducer_lstm/train.py
@@ -57,8 +57,8 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from encoder import LstmEncoder
 from decoder import Decoder
+from encoder import LstmEncoder
 from joiner import Joiner
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
@@ -76,7 +76,15 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.checkpoint import save_checkpoint_with_global_batch_idx
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    measure_gradient_norms,
+    measure_weight_norms,
+    optim_step_and_measure_param_change,
+    setup_logger,
+    str2bool,
+)
 
 LRSchedulerType = Union[
     torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
@@ -231,6 +239,14 @@ def get_parser():
         help="Accumulate stats on activations, print them and exit.",
     )
 
+    parser.add_argument(
+        "--log-diagnostics",
+        type=str2bool,
+        default=False,
+        help="True to also log parameter norm and "
+        "gradient norm to tensorboard.",
+    )
+
     parser.add_argument(
         "--save-every-n",
         type=int,
@@ -318,7 +334,7 @@ def get_params() -> AttributeDict:
         "batch_idx_train": 0,
         "log_interval": 50,
         "reset_interval": 200,
-        "valid_interval": 3000,  # For the 100h subset, use 800
+        "valid_interval": 3000,  # For the 100h subset, use 1600
         # parameters for encoder
         "feature_dim": 80,
         "subsampling_factor": 4,
@@ -645,6 +661,30 @@ def train_one_epoch(
 
     tot_loss = MetricsTracker()
 
+    def maybe_log_gradients(tag: str):
+        if (
+            params.log_diagnostics
+            and tb_writer is not None
+            and params.batch_idx_train % (params.log_interval * 5) == 0
+        ):
+            tb_writer.add_scalars(
+                tag,
+                measure_gradient_norms(model, norm="l2"),
+                global_step=params.batch_idx_train,
+            )
+
+    def maybe_log_weights(tag: str):
+        if (
+            params.log_diagnostics
+            and tb_writer is not None
+            and params.batch_idx_train % (params.log_interval * 5) == 0
+        ):
+            tb_writer.add_scalars(
+                tag,
+                measure_weight_norms(model, norm="l2"),
+                global_step=params.batch_idx_train,
+            )
+
     cur_batch_idx = params.get("cur_batch_idx", 0)
 
     for batch_idx, batch in enumerate(train_dl):
@@ -669,9 +709,32 @@ def train_one_epoch(
         # NOTE: We use reduction==sum and loss is computed over utterances
         # in the batch and there is no normalization to it so far.
         scaler.scale(loss).backward()
+
+        maybe_log_weights("train/param_norms")
+        maybe_log_gradients("train/grad_norms")
+
+        old_parameters = None
+        if (
+            params.log_diagnostics
+            and tb_writer is not None
+            and params.batch_idx_train % (params.log_interval * 5) == 0
+        ):
+            old_parameters = {
+                n: p.detach().clone() for n, p in model.named_parameters()
+            }
+
         scheduler.step_batch(params.batch_idx_train)
         scaler.step(optimizer)
         scaler.update()
+
+        if old_parameters is not None:
+            deltas = optim_step_and_measure_param_change(model, old_parameters)
+            tb_writer.add_scalars(
+                "train/relative_param_change_per_minibatch",
+                deltas,
+                global_step=params.batch_idx_train,
+            )
+
         optimizer.zero_grad()
 
         if params.print_diagnostics and batch_idx == 5:
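
Note on the imported helpers: measure_gradient_norms, measure_weight_norms, and
optim_step_and_measure_param_change are provided by icefall.utils and are not
shown in this diff. As a rough illustration only (an assumption about the shape
of their output, not the icefall implementation), the weight-norm helper returns
the kind of per-parameter scalar dictionary that tb_writer.add_scalars() expects:

    from typing import Dict

    import torch.nn as nn


    def sketch_weight_norms(model: nn.Module, norm: str = "l2") -> Dict[str, float]:
        # Illustrative sketch: map each named parameter to its norm so the
        # result can be passed to SummaryWriter.add_scalars(tag, dict, step).
        p = 2.0 if norm == "l2" else 1.0
        return {
            name: param.detach().norm(p=p).item()
            for name, param in model.named_parameters()
        }

With the patch applied, the extra logging is enabled by passing the new flag to
the training script, e.g. ./transducer_lstm/train.py --log-diagnostics true
(all other options unchanged). The diagnostics are written every
params.log_interval * 5 batches, i.e. every 250 batches with the default
log_interval of 50.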