Add diagnostics code.

Fangjun Kuang 2022-05-10 15:55:33 +08:00
parent 1c9936898b
commit a6f7814019


@@ -57,8 +57,8 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
-from encoder import LstmEncoder
from decoder import Decoder
+from encoder import LstmEncoder
from joiner import Joiner
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
@@ -76,7 +76,15 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import save_checkpoint_with_global_batch_idx
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    measure_gradient_norms,
+    measure_weight_norms,
+    optim_step_and_measure_param_change,
+    setup_logger,
+    str2bool,
+)
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
@@ -231,6 +239,14 @@ def get_parser():
help="Accumulate stats on activations, print them and exit.",
)
parser.add_argument(
"--log-diagnostics",
type=str2bool,
default=False,
help="True to also log parameter norm and "
"gradient norm to tensorboard.",
)
parser.add_argument(
"--save-every-n",
type=int,
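
For context, the new flag is parsed with icefall's str2bool type, so explicit values such as --log-diagnostics True or --log-diagnostics False can be given on the command line. Below is a minimal standalone sketch of that pattern; the local str2bool here is an assumption that mirrors the helper's usual behaviour, not icefall's actual code.

import argparse

def str2bool(v: str) -> bool:
    # Map common textual spellings onto a boolean value.
    if v.lower() in ("yes", "true", "t", "1"):
        return True
    if v.lower() in ("no", "false", "f", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {v!r}")

parser = argparse.ArgumentParser()
parser.add_argument(
    "--log-diagnostics",
    type=str2bool,
    default=False,
    help="True to also log parameter norm and gradient norm to tensorboard.",
)

print(parser.parse_args(["--log-diagnostics", "True"]).log_diagnostics)  # True
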
@@ -318,7 +334,7 @@ def get_params() -> AttributeDict:
"batch_idx_train": 0,
"log_interval": 50,
"reset_interval": 200,
"valid_interval": 3000, # For the 100h subset, use 800
"valid_interval": 3000, # For the 100h subset, use 1600
# parameters for encoder
"feature_dim": 80,
"subsampling_factor": 4,
@@ -645,6 +661,30 @@ def train_one_epoch(
tot_loss = MetricsTracker()
def maybe_log_gradients(tag: str):
if (
params.log_diagnostics
and tb_writer is not None
and params.batch_idx_train % (params.log_interval * 5) == 0
):
tb_writer.add_scalars(
tag,
measure_gradient_norms(model, norm="l2"),
global_step=params.batch_idx_train,
)
def maybe_log_weights(tag: str):
if (
params.log_diagnostics
and tb_writer is not None
and params.batch_idx_train % (params.log_interval * 5) == 0
):
tb_writer.add_scalars(
tag,
measure_weight_norms(model, norm="l2"),
global_step=params.batch_idx_train,
)
cur_batch_idx = params.get("cur_batch_idx", 0)
for batch_idx, batch in enumerate(train_dl):
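
The two helpers above call measure_gradient_norms and measure_weight_norms from icefall.utils, which, as used here, return a dict mapping each parameter name to a scalar norm that tb_writer.add_scalars can log in one call. The following is a rough standalone approximation of that idea, not icefall's implementation; the function names l2_weight_norms and l2_gradient_norms are invented for this sketch.

from typing import Dict

import torch
import torch.nn as nn

def l2_weight_norms(model: nn.Module) -> Dict[str, float]:
    # L2 norm of each parameter tensor, keyed by parameter name.
    with torch.no_grad():
        return {name: param.norm(2).item() for name, param in model.named_parameters()}

def l2_gradient_norms(model: nn.Module) -> Dict[str, float]:
    # L2 norm of each parameter's gradient; parameters that currently
    # have no gradient are skipped.
    with torch.no_grad():
        return {
            name: param.grad.norm(2).item()
            for name, param in model.named_parameters()
            if param.grad is not None
        }

# The resulting dicts have the shape expected by
# SummaryWriter.add_scalars(tag, value_dict, global_step).
model = nn.Linear(4, 2)
model(torch.randn(3, 4)).sum().backward()
print(l2_weight_norms(model))
print(l2_gradient_norms(model))
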
@@ -669,9 +709,32 @@
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
scaler.scale(loss).backward()
maybe_log_weights("train/param_norms")
maybe_log_gradients("train/grad_norms")
old_parameters = None
if (
params.log_diagnostics
and tb_writer is not None
and params.batch_idx_train % (params.log_interval * 5) == 0
):
old_parameters = {
n: p.detach().clone() for n, p in model.named_parameters()
}
scheduler.step_batch(params.batch_idx_train)
scaler.step(optimizer)
scaler.update()
if old_parameters is not None:
deltas = optim_step_and_measure_param_change(model, old_parameters)
tb_writer.add_scalars(
"train/relative_param_change_per_minibatch",
deltas,
global_step=params.batch_idx_train,
)
optimizer.zero_grad()
if params.print_diagnostics and batch_idx == 5:
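
A note on the last block: old_parameters snapshots every parameter before the optimizer step, and optim_step_and_measure_param_change is then given the model and that snapshot so the relative movement of each parameter can be logged under train/relative_param_change_per_minibatch. The sketch below is an assumption about what that measurement amounts to, not the helper's actual code; relative_param_change is a hypothetical name.

from typing import Dict

import torch
import torch.nn as nn

def relative_param_change(
    model: nn.Module, old_parameters: Dict[str, torch.Tensor]
) -> Dict[str, float]:
    # ||p_new - p_old||_2 / ||p_old||_2 for every named parameter,
    # computed against a snapshot taken before optimizer.step().
    deltas = {}
    with torch.no_grad():
        for name, p_new in model.named_parameters():
            p_old = old_parameters[name]
            deltas[name] = ((p_new - p_old).norm() / (p_old.norm() + 1e-12)).item()
    return deltas

# Usage mirroring the training loop above: snapshot, step, then measure.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model(torch.randn(3, 4)).sum().backward()
old = {n: p.detach().clone() for n, p in model.named_parameters()}
optimizer.step()
print(relative_param_change(model, old))
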