mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-08 17:42:21 +00:00)
parent 08473a17aa
commit f721a2fd7a
@@ -609,21 +609,6 @@ def train_one_epoch(
                 global_step=params.batch_idx_train,
             )

-    def maybe_log_param_relative_changes():
-        if (
-            params.log_diagnostics
-            and tb_writer is not None
-            and params.batch_idx_train % (params.log_interval * 5) == 0
-        ):
-            deltas = optim_step_and_measure_param_change(model, optimizer)
-            tb_writer.add_scalars(
-                "train/relative_param_change_per_minibatch",
-                deltas,
-                global_step=params.batch_idx_train,
-            )
-        else:
-            optimizer.step()
-
     cur_batch_idx = params.get("cur_batch_idx", 0)

     for batch_idx, batch in enumerate(train_dl):
@@ -651,7 +636,26 @@ def train_one_epoch(

         maybe_log_weights("train/param_norms")
         maybe_log_gradients("train/grad_norms")
-        maybe_log_param_relative_changes()
+
+        old_parameters = None
+        if (
+            params.log_diagnostics
+            and tb_writer is not None
+            and params.batch_idx_train % (params.log_interval * 5) == 0
+        ):
+            old_parameters = {
+                n: p.detach().clone() for n, p in model.named_parameters()
+            }
+
+        optimizer.step()
+
+        if old_parameters is not None:
+            deltas = optim_step_and_measure_param_change(model, old_parameters)
+            tb_writer.add_scalars(
+                "train/relative_param_change_per_minibatch",
+                deltas,
+                global_step=params.batch_idx_train,
+            )

         optimizer.zero_grad()

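Taken together, the two hunks above replace the old maybe_log_param_relative_changes() helper (which performed the optimizer step itself) with an explicit snapshot, step, measure sequence inside the training loop. Below is a minimal, hedged sketch of that pattern; the toy model, optimizer, and loss are hypothetical stand-ins, and the import assumes the helper lives in icefall.utils, consistent with the second file in this diff:

    import torch
    import torch.nn as nn

    # Assumption: the helper is importable from icefall.utils; adjust if it
    # lives elsewhere in your checkout.
    from icefall.utils import optim_step_and_measure_param_change

    # Hypothetical toy model and objective, standing in for the real recipe.
    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()

    # 1) Snapshot the parameters before the update.
    old_parameters = {n: p.detach().clone() for n, p in model.named_parameters()}

    # 2) Apply the update; the helper no longer calls optimizer.step() itself.
    optimizer.step()

    # 3) Measure the relative change of each parameter, then log it.
    deltas = optim_step_and_measure_param_change(model, old_parameters)
    print(deltas)  # {"weight": ..., "bias": ...}: relative change per parameter

    optimizer.zero_grad()

Since the helper no longer performs the step, a mixed-precision caller would presumably take the same snapshot, call scaler.step(optimizer) and scaler.update() itself, and then measure; this is consistent with the GradScaler import being dropped from the utility module in the next hunk.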
@@ -25,15 +25,14 @@ from collections import defaultdict
 from contextlib import contextmanager
 from datetime import datetime
 from pathlib import Path
-from typing import Dict, Iterable, List, TextIO, Optional, Tuple, Union
+from typing import Dict, Iterable, List, TextIO, Tuple, Union

 import k2
 import k2.version
 import kaldialign
 import torch
+import torch.nn as nn
 import torch.distributed as dist
-from torch.cuda.amp import GradScaler
-import torch.nn as nn
 from torch.utils.tensorboard import SummaryWriter

 Pathlike = Union[str, Path]
@@ -758,11 +757,10 @@ def measure_gradient_norms(

 def optim_step_and_measure_param_change(
     model: nn.Module,
-    optimizer: torch.optim.Optimizer,
-    scaler: Optional[GradScaler] = None,
+    old_parameters: Dict[str, nn.parameter.Parameter],
 ) -> Dict[str, float]:
     """
-    Perform model weight update and measure the "relative change in parameters per minibatch."
+    Measure the "relative change in parameters per minibatch."
     It is understood as a ratio between the L2 norm of the difference between the original and updated parameters,
     and the L2 norm of the original parameter. It is given by the formula:

@@ -770,16 +768,31 @@ def optim_step_and_measure_param_change(
         \begin{aligned}
             \delta = \frac{\Vert\theta - \theta_{new}\Vert^2}{\Vert\theta\Vert^2}
         \end{aligned}
-    """
-    param_copy = {n: p.detach().clone() for n, p in model.named_parameters()}
-    if scaler:
-        scaler.step(optimizer)
-    else:
-        optimizer.step()
+
+    This function is supposed to be used as follows:
+
+    .. code-block:: python
+
+        old_parameters = {
+            n: p.detach().clone() for n, p in model.named_parameters()
+        }
+
+        optimizer.step()
+
+        deltas = optim_step_and_measure_param_change(model, old_parameters)
+
+    Args:
+      model: A torch.nn.Module instance.
+      old_parameters:
+        A Dict of named_parameters before optimizer.step().
+
+    Return:
+      A Dict containing the relative change for each parameter.
+    """
     relative_change = {}
     with torch.no_grad():
         for n, p_new in model.named_parameters():
-            p_orig = param_copy[n]
+            p_orig = old_parameters[n]
             delta = l2_norm(p_orig - p_new) / l2_norm(p_orig)
             relative_change[n] = delta.item()
     return relative_change
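As a quick sanity check of what the returned delta means, the snippet below evaluates the same ratio on a toy tensor. It assumes l2_norm is an ordinary Euclidean norm (for example torch.linalg.norm), matching the prose description and the division in the loop above; the repository's actual l2_norm helper is not shown in this diff.

    import torch

    theta = torch.tensor([3.0, 4.0])   # ||theta|| = 5
    theta_new = 1.1 * theta            # every entry changed by 10%

    # Same quantity as: l2_norm(p_orig - p_new) / l2_norm(p_orig)
    delta = torch.linalg.norm(theta - theta_new) / torch.linalg.norm(theta)
    print(delta.item())  # ~0.1, i.e. a 10% relative parameter change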