Add diagnostics code.

Fangjun Kuang 2022-05-10 15:55:33 +08:00
parent 1c9936898b
commit a6f7814019


@@ -57,8 +57,8 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from encoder import LstmEncoder
 from decoder import Decoder
+from encoder import LstmEncoder
 from joiner import Joiner
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
@@ -76,7 +76,15 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.checkpoint import save_checkpoint_with_global_batch_idx
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    measure_gradient_norms,
+    measure_weight_norms,
+    optim_step_and_measure_param_change,
+    setup_logger,
+    str2bool,
+)

 LRSchedulerType = Union[
     torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
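The three new imports come from icefall.utils; each returns a dict mapping a parameter name to a scalar that can be handed to TensorBoard. A rough sketch of what a weight-norm helper of this kind computes (a hypothetical reimplementation, not the icefall.utils source; the gradient variant does the same thing over p.grad):

from typing import Dict

import torch
import torch.nn as nn


def measure_weight_norms(model: nn.Module, norm: str = "l2") -> Dict[str, float]:
    # One scalar per parameter tensor: its l1, l2, or l-infinity norm.
    with torch.no_grad():
        norms = {}
        for name, p in model.named_parameters():
            if norm == "l1":
                val = p.abs().sum()
            elif norm == "l2":
                val = p.norm(2)
            else:  # "linf"
                val = p.abs().max()
            norms[name] = val.item()
        return norms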
@@ -231,6 +239,14 @@ def get_parser():
         help="Accumulate stats on activations, print them and exit.",
     )

+    parser.add_argument(
+        "--log-diagnostics",
+        type=str2bool,
+        default=False,
+        help="True to also log parameter norm and "
+        "gradient norm to tensorboard.",
+    )
+
     parser.add_argument(
         "--save-every-n",
         type=int,
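Because the flag is parsed with str2bool rather than a plain bool type, it should accept the usual textual spellings on the command line, e.g. --log-diagnostics true or --log-diagnostics 1. A minimal sketch of how such a converter typically behaves (an assumption about icefall.utils.str2bool, not copied from it):

import argparse


def str2bool(v) -> bool:
    # Accept booleans directly and common true/false spellings as strings.
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")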
@@ -318,7 +334,7 @@ def get_params() -> AttributeDict:
             "batch_idx_train": 0,
             "log_interval": 50,
             "reset_interval": 200,
-            "valid_interval": 3000,  # For the 100h subset, use 800
+            "valid_interval": 3000,  # For the 100h subset, use 1600
             # parameters for encoder
             "feature_dim": 80,
             "subsampling_factor": 4,
@@ -645,6 +661,30 @@ def train_one_epoch(
     tot_loss = MetricsTracker()

+    def maybe_log_gradients(tag: str):
+        if (
+            params.log_diagnostics
+            and tb_writer is not None
+            and params.batch_idx_train % (params.log_interval * 5) == 0
+        ):
+            tb_writer.add_scalars(
+                tag,
+                measure_gradient_norms(model, norm="l2"),
+                global_step=params.batch_idx_train,
+            )
+
+    def maybe_log_weights(tag: str):
+        if (
+            params.log_diagnostics
+            and tb_writer is not None
+            and params.batch_idx_train % (params.log_interval * 5) == 0
+        ):
+            tb_writer.add_scalars(
+                tag,
+                measure_weight_norms(model, norm="l2"),
+                global_step=params.batch_idx_train,
+            )
+
     cur_batch_idx = params.get("cur_batch_idx", 0)

     for batch_idx, batch in enumerate(train_dl):
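Both helpers only fire when log_diagnostics is enabled, a TensorBoard writer exists, and the batch index is a multiple of log_interval * 5 (every 250 batches with the default log_interval of 50). The dict they return goes to SummaryWriter.add_scalars, which writes one curve per dict entry under the given tag. A standalone illustration of that API (the parameter names and values here are made up):

from torch.utils.tensorboard import SummaryWriter

tb_writer = SummaryWriter("tb-demo")
# One chart group ("train/param_norms") containing one curve per parameter.
tb_writer.add_scalars(
    "train/param_norms",
    {"encoder.lstm.weight_ih_l0": 3.2, "joiner.output_linear.weight": 1.1},
    global_step=250,  # in train.py this is params.batch_idx_train
)
tb_writer.close()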
@@ -669,9 +709,32 @@ def train_one_epoch(
         # NOTE: We use reduction==sum and loss is computed over utterances
         # in the batch and there is no normalization to it so far.
         scaler.scale(loss).backward()
+
+        maybe_log_weights("train/param_norms")
+        maybe_log_gradients("train/grad_norms")
+
+        old_parameters = None
+        if (
+            params.log_diagnostics
+            and tb_writer is not None
+            and params.batch_idx_train % (params.log_interval * 5) == 0
+        ):
+            old_parameters = {
+                n: p.detach().clone() for n, p in model.named_parameters()
+            }
+
         scheduler.step_batch(params.batch_idx_train)
         scaler.step(optimizer)
         scaler.update()
+
+        if old_parameters is not None:
+            deltas = optim_step_and_measure_param_change(model, old_parameters)
+            tb_writer.add_scalars(
+                "train/relative_param_change_per_minibatch",
+                deltas,
+                global_step=params.batch_idx_train,
+            )
+
         optimizer.zero_grad()

         if params.print_diagnostics and batch_idx == 5:
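Here the parameters cloned before the optimizer step are compared with the updated ones, and the relative change of each parameter is logged per minibatch. A minimal sketch of what that measurement amounts to (a hypothetical reimplementation; the actual helper is optim_step_and_measure_param_change in icefall.utils):

from typing import Dict

import torch
import torch.nn as nn


def relative_param_change(
    model: nn.Module, old_parameters: Dict[str, torch.Tensor]
) -> Dict[str, float]:
    # ||p_new - p_old|| / ||p_old|| for every parameter, as plain floats.
    with torch.no_grad():
        deltas = {}
        for name, p_new in model.named_parameters():
            p_old = old_parameters[name]
            # Small epsilon guards against all-zero parameters.
            deltas[name] = ((p_new - p_old).norm(2) / (p_old.norm(2) + 1e-20)).item()
        return deltas

Unusually large values can point at a learning rate that is too aggressive for a layer, while values near zero suggest a layer that is barely being updated.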