mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 01:52:41 +00:00)
parent: 08473a17aa
commit: f721a2fd7a
@@ -609,21 +609,6 @@ def train_one_epoch(
                 global_step=params.batch_idx_train,
             )
 
-    def maybe_log_param_relative_changes():
-        if (
-            params.log_diagnostics
-            and tb_writer is not None
-            and params.batch_idx_train % (params.log_interval * 5) == 0
-        ):
-            deltas = optim_step_and_measure_param_change(model, optimizer)
-            tb_writer.add_scalars(
-                "train/relative_param_change_per_minibatch",
-                deltas,
-                global_step=params.batch_idx_train,
-            )
-        else:
-            optimizer.step()
-
     cur_batch_idx = params.get("cur_batch_idx", 0)
 
     for batch_idx, batch in enumerate(train_dl):
@@ -651,7 +636,26 @@ def train_one_epoch(
 
         maybe_log_weights("train/param_norms")
         maybe_log_gradients("train/grad_norms")
-        maybe_log_param_relative_changes()
+
+        old_parameters = None
+        if (
+            params.log_diagnostics
+            and tb_writer is not None
+            and params.batch_idx_train % (params.log_interval * 5) == 0
+        ):
+            old_parameters = {
+                n: p.detach().clone() for n, p in model.named_parameters()
+            }
+
+        optimizer.step()
+
+        if old_parameters is not None:
+            deltas = optim_step_and_measure_param_change(model, old_parameters)
+            tb_writer.add_scalars(
+                "train/relative_param_change_per_minibatch",
+                deltas,
+                global_step=params.batch_idx_train,
+            )
 
         optimizer.zero_grad()
 
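Note: the new flow snapshots the parameters, lets the optimizer step, and only then measures the relative change. A minimal standalone sketch of that pattern, using a toy model and optimizer and inlining the measurement (the real code calls icefall's optim_step_and_measure_param_change):

import torch
import torch.nn as nn

# Toy stand-ins for the real model/optimizer used in train_one_epoch.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Dummy forward/backward pass to populate gradients.
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()

# 1) Snapshot the parameters before the update (same dict comprehension as in the diff).
old_parameters = {n: p.detach().clone() for n, p in model.named_parameters()}

# 2) Perform the weight update.
optimizer.step()

# 3) Measure the relative change per parameter; this inlines what
#    optim_step_and_measure_param_change() computes.
deltas = {}
with torch.no_grad():
    for n, p_new in model.named_parameters():
        p_orig = old_parameters[n]
        deltas[n] = (torch.linalg.norm(p_orig - p_new) / torch.linalg.norm(p_orig)).item()

print(deltas)
optimizer.zero_grad()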
@@ -25,15 +25,14 @@ from collections import defaultdict
 from contextlib import contextmanager
 from datetime import datetime
 from pathlib import Path
-from typing import Dict, Iterable, List, TextIO, Optional, Tuple, Union
+from typing import Dict, Iterable, List, TextIO, Tuple, Union
 
 import k2
 import k2.version
 import kaldialign
 import torch
-import torch.nn as nn
 import torch.distributed as dist
-from torch.cuda.amp import GradScaler
+import torch.nn as nn
 from torch.utils.tensorboard import SummaryWriter
 
 Pathlike = Union[str, Path]
@@ -758,11 +757,10 @@ def measure_gradient_norms(
 
 def optim_step_and_measure_param_change(
     model: nn.Module,
-    optimizer: torch.optim.Optimizer,
-    scaler: Optional[GradScaler] = None,
+    old_parameters: Dict[str, nn.parameter.Parameter],
 ) -> Dict[str, float]:
     """
-    Perform model weight update and measure the "relative change in parameters per minibatch."
+    Measure the "relative change in parameters per minibatch."
     It is understood as a ratio between the L2 norm of the difference between original and updated parameters,
     and the L2 norm of the original parameter. It is given by the formula:
 
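Dropping the `optimizer`/`scaler` arguments means the helper no longer performs the weight update itself; a mixed-precision caller simply drives the step through its own GradScaler and measures afterwards. A hedged sketch of that calling convention (toy model/optimizer of my own; only the scaler handling mirrors what the removed arguments used to do):

import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler

# The scaler is enabled only when CUDA is available, so this runs the same
# control flow on CPU (pass-through) and GPU (real gradient scaling).
device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(4, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler(enabled=torch.cuda.is_available())

loss = model(torch.randn(8, 4, device=device)).pow(2).mean()
scaler.scale(loss).backward()

# Snapshot before the update...
old_parameters = {n: p.detach().clone() for n, p in model.named_parameters()}

# ...then let the caller drive the step. This takes over the removed helper's
# old `if scaler: scaler.step(optimizer) else: optimizer.step()` branch.
scaler.step(optimizer)
scaler.update()

# With the parameters now updated in place, the new-style measurement call is:
#   deltas = optim_step_and_measure_param_change(model, old_parameters)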
@@ -770,16 +768,31 @@ def optim_step_and_measure_param_change(
         \begin{aligned}
             \delta = \frac{\Vert\theta - \theta_{new}\Vert^2}{\Vert\theta\Vert^2}
         \end{aligned}
+
+    This function is supposed to be used as follows:
+
+    .. code-block:: python
+
+        old_parameters = {
+            n: p.detach().clone() for n, p in model.named_parameters()
+        }
+
+        optimizer.step()
+
+        deltas = optim_step_and_measure_param_change(model, old_parameters)
+
+    Args:
+      model: A torch.nn.Module instance.
+      old_parameters:
+        A Dict of named_parameters before optimizer.step().
+
+    Return:
+      A Dict containing the relative change for each parameter.
     """
-    param_copy = {n: p.detach().clone() for n, p in model.named_parameters()}
-    if scaler:
-        scaler.step(optimizer)
-    else:
-        optimizer.step()
     relative_change = {}
     with torch.no_grad():
         for n, p_new in model.named_parameters():
-            p_orig = param_copy[n]
+            p_orig = old_parameters[n]
             delta = l2_norm(p_orig - p_new) / l2_norm(p_orig)
             relative_change[n] = delta.item()
     return relative_change
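As a quick sanity check on the delta formula, a hand-computable example; the local `l2_norm` below is an assumed stand-in (plain L2/Frobenius norm) for the helper defined elsewhere in this module, and the ratio is computed exactly as the loop body above does:

import torch

def l2_norm(x: torch.Tensor) -> torch.Tensor:
    # Assumed stand-in for the module's l2_norm helper: the L2 (Frobenius) norm.
    return torch.linalg.norm(x)

# A single "parameter" before and after shrinking every coordinate by 1%.
theta = torch.full((3,), 10.0)   # ||theta||             = 10 * sqrt(3)
theta_new = theta * 0.99         # ||theta - theta_new|| = 0.1 * sqrt(3)

delta = l2_norm(theta - theta_new) / l2_norm(theta)
print(round(delta.item(), 4))    # 0.01, i.e. a 1% relative change per minibatch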