mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 14:44:18 +00:00
Add diagnostics code.
This commit is contained in:
parent
1c9936898b
commit
a6f7814019
@ -57,8 +57,8 @@ import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from encoder import LstmEncoder
|
||||
from decoder import Decoder
|
||||
from encoder import LstmEncoder
|
||||
from joiner import Joiner
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.dataset.sampling.base import CutSampler
|
||||
@ -76,7 +76,15 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||
from icefall.checkpoint import save_checkpoint_with_global_batch_idx
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
measure_gradient_norms,
|
||||
measure_weight_norms,
|
||||
optim_step_and_measure_param_change,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[
|
||||
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
|
||||
@ -231,6 +239,14 @@ def get_parser():
|
||||
help="Accumulate stats on activations, print them and exit.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--log-diagnostics",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="True to also log parameter norm and "
|
||||
"gradient norm to tensorboard.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--save-every-n",
|
||||
type=int,
|
||||
@ -318,7 +334,7 @@ def get_params() -> AttributeDict:
|
||||
"batch_idx_train": 0,
|
||||
"log_interval": 50,
|
||||
"reset_interval": 200,
|
||||
"valid_interval": 3000, # For the 100h subset, use 800
|
||||
"valid_interval": 3000, # For the 100h subset, use 1600
|
||||
# parameters for encoder
|
||||
"feature_dim": 80,
|
||||
"subsampling_factor": 4,
|
||||
@ -645,6 +661,30 @@ def train_one_epoch(
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
def maybe_log_gradients(tag: str):
|
||||
if (
|
||||
params.log_diagnostics
|
||||
and tb_writer is not None
|
||||
and params.batch_idx_train % (params.log_interval * 5) == 0
|
||||
):
|
||||
tb_writer.add_scalars(
|
||||
tag,
|
||||
measure_gradient_norms(model, norm="l2"),
|
||||
global_step=params.batch_idx_train,
|
||||
)
|
||||
|
||||
def maybe_log_weights(tag: str):
|
||||
if (
|
||||
params.log_diagnostics
|
||||
and tb_writer is not None
|
||||
and params.batch_idx_train % (params.log_interval * 5) == 0
|
||||
):
|
||||
tb_writer.add_scalars(
|
||||
tag,
|
||||
measure_weight_norms(model, norm="l2"),
|
||||
global_step=params.batch_idx_train,
|
||||
)
|
||||
|
||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
@ -669,9 +709,32 @@ def train_one_epoch(
|
||||
# NOTE: We use reduction==sum and loss is computed over utterances
|
||||
# in the batch and there is no normalization to it so far.
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
maybe_log_weights("train/param_norms")
|
||||
maybe_log_gradients("train/grad_norms")
|
||||
|
||||
old_parameters = None
|
||||
if (
|
||||
params.log_diagnostics
|
||||
and tb_writer is not None
|
||||
and params.batch_idx_train % (params.log_interval * 5) == 0
|
||||
):
|
||||
old_parameters = {
|
||||
n: p.detach().clone() for n, p in model.named_parameters()
|
||||
}
|
||||
|
||||
scheduler.step_batch(params.batch_idx_train)
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
|
||||
if old_parameters is not None:
|
||||
deltas = optim_step_and_measure_param_change(model, old_parameters)
|
||||
tb_writer.add_scalars(
|
||||
"train/relative_param_change_per_minibatch",
|
||||
deltas,
|
||||
global_step=params.batch_idx_train,
|
||||
)
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
if params.print_diagnostics and batch_idx == 5:
|
||||
|
Loading…
x
Reference in New Issue
Block a user