mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Incorporate some latest changes to optim.py
(#1359)
* init commit * black formatted * isort formatted
This commit is contained in:
parent
23913f6afd
commit
9e5a5d7839
@ -22,7 +22,7 @@ from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch import Tensor
|
||||
from torch import Tensor, nn
|
||||
from torch.optim import Optimizer
|
||||
|
||||
|
||||
@ -116,7 +116,7 @@ class BatchedOptimizer(Optimizer):
|
||||
|
||||
yield tuples # <-- calling code will do the actual optimization here!
|
||||
|
||||
for (stacked_params, _state, _names), batch in zip(tuples, batches):
|
||||
for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
|
||||
for i, p in enumerate(batch): # batch is list of Parameter
|
||||
p.copy_(stacked_params[i])
|
||||
|
||||
@ -181,6 +181,7 @@ class ScaledAdam(BatchedOptimizer):
|
||||
size_update_period=4,
|
||||
clipping_update_period=100,
|
||||
):
|
||||
|
||||
defaults = dict(
|
||||
lr=lr,
|
||||
clipping_scale=clipping_scale,
|
||||
@ -326,7 +327,9 @@ class ScaledAdam(BatchedOptimizer):
|
||||
batch = True
|
||||
|
||||
for group, group_params_names in zip(self.param_groups, self.parameters_names):
|
||||
|
||||
with self.batched_params(group["params"], group_params_names) as batches:
|
||||
|
||||
# batches is list of pairs (stacked_param, state). stacked_param is like
|
||||
# a regular parameter, and will have a .grad, but the 1st dim corresponds to
|
||||
# a stacking dim, it is not a real dim.
|
||||
@ -423,16 +426,19 @@ class ScaledAdam(BatchedOptimizer):
|
||||
# parameters' state won't have been initialized yet.
|
||||
return 1.0
|
||||
clipping_update_period = group["clipping_update_period"]
|
||||
scalar_lr_scale = group["scalar_lr_scale"]
|
||||
|
||||
tot_sumsq = torch.tensor(0.0, device=first_p.device)
|
||||
for p, state, param_names in tuples:
|
||||
for (p, state, param_names) in tuples:
|
||||
grad = p.grad
|
||||
if grad.is_sparse:
|
||||
raise RuntimeError(
|
||||
"ScaledAdam optimizer does not support sparse gradients"
|
||||
)
|
||||
if p.numel() == p.shape[0]: # a batch of scalars
|
||||
tot_sumsq += (grad**2).sum() # sum() to change shape [1] to []
|
||||
tot_sumsq += (grad**2).sum() * (
|
||||
scalar_lr_scale**2
|
||||
) # sum() to change shape [1] to []
|
||||
else:
|
||||
tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()
|
||||
|
||||
@ -443,64 +449,72 @@ class ScaledAdam(BatchedOptimizer):
|
||||
)
|
||||
first_state["model_norms"][step % clipping_update_period] = tot_norm
|
||||
|
||||
if step % clipping_update_period == 0:
|
||||
irregular_estimate_steps = [
|
||||
i for i in [10, 20, 40] if i < clipping_update_period
|
||||
]
|
||||
if step % clipping_update_period == 0 or step in irregular_estimate_steps:
|
||||
# Print some stats.
|
||||
# We don't reach here if step == 0 because we would have returned
|
||||
# above.
|
||||
sorted_norms = first_state["model_norms"].sort()[0].to("cpu")
|
||||
if step in irregular_estimate_steps:
|
||||
sorted_norms = sorted_norms[-step:]
|
||||
num_norms = sorted_norms.numel()
|
||||
quartiles = []
|
||||
for n in range(0, 5):
|
||||
index = min(
|
||||
clipping_update_period - 1, (clipping_update_period // 4) * n
|
||||
)
|
||||
index = min(num_norms - 1, (num_norms // 4) * n)
|
||||
quartiles.append(sorted_norms[index].item())
|
||||
|
||||
median = quartiles[2]
|
||||
threshold = clipping_scale * median
|
||||
if step in irregular_estimate_steps:
|
||||
# use larger thresholds on first few steps of estimating threshold,
|
||||
# as norm may be changing rapidly.
|
||||
threshold = threshold * 2.0
|
||||
first_state["model_norm_threshold"] = threshold
|
||||
percent_clipped = (
|
||||
first_state["num_clipped"] * 100.0 / clipping_update_period
|
||||
first_state["num_clipped"] * 100.0 / num_norms
|
||||
if "num_clipped" in first_state
|
||||
else 0.0
|
||||
)
|
||||
first_state["num_clipped"] = 0
|
||||
quartiles = " ".join(["%.3e" % x for x in quartiles])
|
||||
logging.info(
|
||||
logging.warn(
|
||||
f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
|
||||
f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
|
||||
)
|
||||
|
||||
if step < clipping_update_period:
|
||||
return 1.0 # We have not yet estimated a norm to clip to.
|
||||
else:
|
||||
try:
|
||||
model_norm_threshold = first_state["model_norm_threshold"]
|
||||
except KeyError:
|
||||
logging.info(
|
||||
"Warning: model_norm_threshold not in state: possibly "
|
||||
"you changed config when restarting, adding clipping_scale option?"
|
||||
)
|
||||
return 1.0
|
||||
ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
|
||||
if ans < 1.0:
|
||||
first_state["num_clipped"] += 1
|
||||
if ans < 0.1:
|
||||
logging.warn(
|
||||
f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
|
||||
)
|
||||
if self.show_dominant_parameters:
|
||||
assert p.shape[0] == len(param_names)
|
||||
self._show_gradient_dominating_parameter(tuples, tot_sumsq)
|
||||
if ans != ans: # e.g. ans is nan
|
||||
ans = 0.0
|
||||
if ans == 0.0:
|
||||
for p, state, param_names in tuples:
|
||||
p.grad.zero_() # get rid of infinity()
|
||||
try:
|
||||
model_norm_threshold = first_state["model_norm_threshold"]
|
||||
except KeyError:
|
||||
return 1.0 # threshold has not yet been set.
|
||||
|
||||
return ans
|
||||
ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
|
||||
if ans != ans: # e.g. ans is nan
|
||||
ans = 0.0
|
||||
if ans < 1.0:
|
||||
first_state["num_clipped"] += 1
|
||||
if ans < 0.1:
|
||||
logging.warn(
|
||||
f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
|
||||
)
|
||||
if self.show_dominant_parameters:
|
||||
assert p.shape[0] == len(param_names)
|
||||
self._show_gradient_dominating_parameter(
|
||||
tuples, tot_sumsq, group["scalar_lr_scale"]
|
||||
)
|
||||
|
||||
if ans == 0.0:
|
||||
for (p, state, param_names) in tuples:
|
||||
p.grad.zero_() # get rid of infinity()
|
||||
|
||||
return ans
|
||||
|
||||
def _show_gradient_dominating_parameter(
|
||||
self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor
|
||||
self,
|
||||
tuples: List[Tuple[Tensor, dict, List[str]]],
|
||||
tot_sumsq: Tensor,
|
||||
scalar_lr_scale: float,
|
||||
):
|
||||
"""
|
||||
Show information of parameter which dominates tot_sumsq.
|
||||
@ -516,29 +530,30 @@ class ScaledAdam(BatchedOptimizer):
|
||||
from tuples, we still pass it to save some time.
|
||||
"""
|
||||
all_sumsq_orig = {}
|
||||
for p, state, batch_param_names in tuples:
|
||||
for (p, state, batch_param_names) in tuples:
|
||||
# p is a stacked batch parameters.
|
||||
batch_grad = p.grad
|
||||
if p.numel() == p.shape[0]: # a batch of scalars
|
||||
batch_sumsq_orig = batch_grad**2
|
||||
# Dummy values used by following `zip` statement.
|
||||
batch_rms_orig = torch.ones(p.shape[0])
|
||||
batch_rms_orig = torch.full(
|
||||
p.shape, scalar_lr_scale, device=batch_grad.device
|
||||
)
|
||||
else:
|
||||
batch_rms_orig = state["param_rms"]
|
||||
batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum(
|
||||
batch_sumsq_orig = (batch_grad * batch_rms_orig) ** 2
|
||||
if batch_grad.ndim > 1:
|
||||
# need to guard it with if-statement because sum() sums over
|
||||
# all dims if dim == ().
|
||||
batch_sumsq_orig = batch_sumsq_orig.sum(
|
||||
dim=list(range(1, batch_grad.ndim))
|
||||
)
|
||||
|
||||
for name, sumsq_orig, rms, grad in zip(
|
||||
batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
|
||||
):
|
||||
|
||||
proportion_orig = sumsq_orig / tot_sumsq
|
||||
all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
|
||||
|
||||
assert torch.isclose(
|
||||
sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
|
||||
torch.tensor(1.0),
|
||||
)
|
||||
sorted_by_proportion = {
|
||||
k: v
|
||||
for k, v in sorted(
|
||||
@ -552,7 +567,7 @@ class ScaledAdam(BatchedOptimizer):
|
||||
dominant_rms,
|
||||
dominant_grad,
|
||||
) = sorted_by_proportion[dominant_param_name]
|
||||
logging.info(
|
||||
logging.warn(
|
||||
f"Parameter dominating tot_sumsq {dominant_param_name}"
|
||||
f" with proportion {dominant_proportion:.2f},"
|
||||
f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
|
||||
@ -826,7 +841,7 @@ class LRScheduler(object):
|
||||
def print_lr(self, is_verbose, group, lr):
|
||||
"""Display the current learning rate."""
|
||||
if is_verbose:
|
||||
logging.info(
|
||||
logging.warn(
|
||||
f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
|
||||
f" of group {group} to {lr:.4e}."
|
||||
)
|
||||
@ -841,8 +856,14 @@ class Eden(LRScheduler):
|
||||
where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches
|
||||
and then stays constant at 1.
|
||||
|
||||
If you don't have the concept of epochs, or one epoch takes a very long time,
|
||||
you can replace the notion of 'epoch' with some measure of the amount of data
|
||||
processed, e.g. hours of data or frames of data, with 'lr_epochs' being set to
|
||||
some measure representing "quite a lot of data": say, one fifth or one third
|
||||
of an entire training run, but it doesn't matter much. You could also use
|
||||
Eden2 which has only the notion of batches.
|
||||
|
||||
E.g. suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam
|
||||
We suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam
|
||||
|
||||
Args:
|
||||
optimizer: the optimizer to change the learning rates on
|
||||
@ -888,6 +909,56 @@ class Eden(LRScheduler):
|
||||
return [x * factor * warmup_factor for x in self.base_lrs]
|
||||
|
||||
|
||||
class Eden2(LRScheduler):
|
||||
"""
|
||||
Eden2 scheduler, simpler than Eden because it does not use the notion of epoch,
|
||||
only batches.
|
||||
|
||||
The basic formula (before warmup) is:
|
||||
lr = base_lr * ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.5) * warmup
|
||||
|
||||
where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches
|
||||
and then stays constant at 1.
|
||||
|
||||
|
||||
E.g. suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam
|
||||
|
||||
Args:
|
||||
optimizer: the optimizer to change the learning rates on
|
||||
lr_batches: the number of batches after which we start significantly
|
||||
decreasing the learning rate, suggest 5000.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
optimizer: Optimizer,
|
||||
lr_batches: Union[int, float],
|
||||
warmup_batches: Union[int, float] = 500.0,
|
||||
warmup_start: float = 0.5,
|
||||
verbose: bool = False,
|
||||
):
|
||||
super().__init__(optimizer, verbose)
|
||||
self.lr_batches = lr_batches
|
||||
self.warmup_batches = warmup_batches
|
||||
|
||||
assert 0.0 <= warmup_start <= 1.0, warmup_start
|
||||
self.warmup_start = warmup_start
|
||||
|
||||
def get_lr(self):
|
||||
factor = (
|
||||
(self.batch**2 + self.lr_batches**2) / self.lr_batches**2
|
||||
) ** -0.5
|
||||
warmup_factor = (
|
||||
1.0
|
||||
if self.batch >= self.warmup_batches
|
||||
else self.warmup_start
|
||||
+ (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
|
||||
# else 0.5 + 0.5 * (self.batch / self.warmup_batches)
|
||||
)
|
||||
|
||||
return [x * factor * warmup_factor for x in self.base_lrs]
|
||||
|
||||
|
||||
def _test_eden():
|
||||
m = torch.nn.Linear(100, 100)
|
||||
optim = ScaledAdam(m.parameters(), lr=0.03)
|
||||
|
Loading…
x
Reference in New Issue
Block a user