mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-07 08:04:18 +00:00
Add diagnostics code.
This commit is contained in:
parent 1c9936898b
commit a6f7814019
@@ -57,8 +57,8 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from encoder import LstmEncoder
 from decoder import Decoder
+from encoder import LstmEncoder
 from joiner import Joiner
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
@@ -76,7 +76,15 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.checkpoint import save_checkpoint_with_global_batch_idx
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    measure_gradient_norms,
+    measure_weight_norms,
+    optim_step_and_measure_param_change,
+    setup_logger,
+    str2bool,
+)

 LRSchedulerType = Union[
     torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
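The three helpers newly imported from icefall.utils drive the diagnostics added further down. As a rough, non-authoritative sketch of what per-parameter norm measurement of this kind involves (not icefall's actual implementation), measure_weight_norms and measure_gradient_norms can be thought of as returning a dict mapping each parameter name to a scalar norm, which is the shape that TensorBoard's add_scalars expects:

# Illustrative sketch only; not icefall's actual implementation of these
# helpers.  The assumption is that they return {parameter_name: norm} so the
# result can be passed directly to SummaryWriter.add_scalars.
from typing import Dict

import torch
import torch.nn as nn


def sketch_measure_weight_norms(model: nn.Module, norm: str = "l2") -> Dict[str, float]:
    # "l2" -> p=2, anything else treated as l1 in this sketch
    p = 2 if norm == "l2" else 1
    with torch.no_grad():
        return {
            name: param.norm(p=p).item()
            for name, param in model.named_parameters()
        }


def sketch_measure_gradient_norms(model: nn.Module, norm: str = "l2") -> Dict[str, float]:
    p = 2 if norm == "l2" else 1
    with torch.no_grad():
        return {
            name: param.grad.norm(p=p).item()
            for name, param in model.named_parameters()
            if param.grad is not None
        }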
@@ -231,6 +239,14 @@ def get_parser():
         help="Accumulate stats on activations, print them and exit.",
     )

+    parser.add_argument(
+        "--log-diagnostics",
+        type=str2bool,
+        default=False,
+        help="True to also log parameter norm and "
+        "gradient norm to tensorboard.",
+    )
+
     parser.add_argument(
         "--save-every-n",
         type=int,
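Because the new --log-diagnostics option is declared with type=str2bool, it can be passed on the command line as --log-diagnostics=True or --log-diagnostics=False. A minimal sketch of a str2bool-style converter (icefall ships its own in icefall/utils.py, which may differ in detail):

# Sketch of a str2bool-style argparse converter; icefall's own helper
# may differ in the exact strings it accepts.
import argparse


def str2bool_sketch(value: str) -> bool:
    if value.lower() in ("true", "t", "yes", "y", "1"):
        return True
    if value.lower() in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument("--log-diagnostics", type=str2bool_sketch, default=False)
args = parser.parse_args(["--log-diagnostics=True"])
assert args.log_diagnostics is True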
@@ -318,7 +334,7 @@ def get_params() -> AttributeDict:
             "batch_idx_train": 0,
             "log_interval": 50,
             "reset_interval": 200,
-            "valid_interval": 3000,  # For the 100h subset, use 800
+            "valid_interval": 3000,  # For the 100h subset, use 1600
             # parameters for encoder
             "feature_dim": 80,
             "subsampling_factor": 4,
@@ -645,6 +661,30 @@ def train_one_epoch(

     tot_loss = MetricsTracker()

+    def maybe_log_gradients(tag: str):
+        if (
+            params.log_diagnostics
+            and tb_writer is not None
+            and params.batch_idx_train % (params.log_interval * 5) == 0
+        ):
+            tb_writer.add_scalars(
+                tag,
+                measure_gradient_norms(model, norm="l2"),
+                global_step=params.batch_idx_train,
+            )
+
+    def maybe_log_weights(tag: str):
+        if (
+            params.log_diagnostics
+            and tb_writer is not None
+            and params.batch_idx_train % (params.log_interval * 5) == 0
+        ):
+            tb_writer.add_scalars(
+                tag,
+                measure_weight_norms(model, norm="l2"),
+                global_step=params.batch_idx_train,
+            )
+
     cur_batch_idx = params.get("cur_batch_idx", 0)

     for batch_idx, batch in enumerate(train_dl):
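The two maybe_log_* helpers above write only when --log-diagnostics is enabled, a TensorBoard writer exists, and the global batch index is a multiple of log_interval * 5 (every 250 batches with the default log_interval of 50). A self-contained usage sketch of SummaryWriter.add_scalars with a dict of named scalars, mirroring those calls; the tiny model and the log directory here are placeholders, not part of the diff:

# Usage sketch of SummaryWriter.add_scalars with a dict of named scalars.
# The model and the "tb_example" log directory are made up for illustration.
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

model = nn.Linear(4, 2)
loss = model(torch.randn(3, 4)).sum()
loss.backward()

grad_norms = {
    name: param.grad.norm(p=2).item()
    for name, param in model.named_parameters()
    if param.grad is not None
}

writer = SummaryWriter(log_dir="tb_example")
# One curve per parameter is written under the "train/grad_norms" group.
writer.add_scalars("train/grad_norms", grad_norms, global_step=0)
writer.close()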
@@ -669,9 +709,32 @@ def train_one_epoch(
         # NOTE: We use reduction==sum and loss is computed over utterances
         # in the batch and there is no normalization to it so far.
         scaler.scale(loss).backward()
+
+        maybe_log_weights("train/param_norms")
+        maybe_log_gradients("train/grad_norms")
+
+        old_parameters = None
+        if (
+            params.log_diagnostics
+            and tb_writer is not None
+            and params.batch_idx_train % (params.log_interval * 5) == 0
+        ):
+            old_parameters = {
+                n: p.detach().clone() for n, p in model.named_parameters()
+            }
+
         scheduler.step_batch(params.batch_idx_train)
         scaler.step(optimizer)
         scaler.update()
+
+        if old_parameters is not None:
+            deltas = optim_step_and_measure_param_change(model, old_parameters)
+            tb_writer.add_scalars(
+                "train/relative_param_change_per_minibatch",
+                deltas,
+                global_step=params.batch_idx_train,
+            )
+
         optimizer.zero_grad()

         if params.print_diagnostics and batch_idx == 5:
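The old_parameters snapshot taken before scheduler.step_batch and scaler.step is what allows the relative parameter change to be measured once the optimizer has actually updated the weights. A hedged sketch of the kind of quantity such a helper could report per parameter, ||p_new - p_old|| / ||p_old||, under the assumption that it returns a dict keyed by parameter name; the actual optim_step_and_measure_param_change in icefall.utils may be implemented differently:

# Sketch of the quantity being logged: relative parameter change per tensor,
# computed against a pre-step snapshot like the one built in the diff above.
# Not the actual optim_step_and_measure_param_change from icefall.utils.
from typing import Dict

import torch
import torch.nn as nn


def sketch_relative_param_change(
    model: nn.Module, old_parameters: Dict[str, torch.Tensor]
) -> Dict[str, float]:
    deltas = {}
    with torch.no_grad():
        for name, param in model.named_parameters():
            old = old_parameters[name]
            denom = old.norm(p=2).clamp_min(1e-20)
            deltas[name] = ((param - old).norm(p=2) / denom).item()
    return deltas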