Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-08 17:42:21 +00:00)
Incorporate some latest changes to optim.py (#1359)

* init commit
* black formatted
* isort formatted
parent 23913f6afd
commit 9e5a5d7839
@@ -22,7 +22,7 @@ from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 from lhotse.utils import fix_random_seed
-from torch import Tensor
+from torch import Tensor, nn
 from torch.optim import Optimizer
 
 
@@ -116,7 +116,7 @@ class BatchedOptimizer(Optimizer):
 
         yield tuples  # <-- calling code will do the actual optimization here!
 
-        for (stacked_params, _state, _names), batch in zip(tuples, batches):
+        for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
             for i, p in enumerate(batch):  # batch is list of Parameter
                 p.copy_(stacked_params[i])
 
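
For readers skimming the diff: the loop touched above is the tail of BatchedOptimizer's batching scheme, where parameters of identical shape are stacked into one tensor, the caller runs the update on the stacked tensor, and the result is copied back into the individual parameters. A minimal, self-contained sketch of that stack, update, copy-back pattern in plain PyTorch (illustrative only, not the actual batched_params implementation):

    import torch

    # Three parameters that share a shape, as BatchedOptimizer would group them.
    params = [torch.nn.Parameter(torch.randn(4)) for _ in range(3)]

    with torch.no_grad():
        stacked = torch.stack([p.detach() for p in params])  # shape (3, 4); dim 0 is the stacking dim
        stacked -= 0.01 * torch.randn_like(stacked)  # stand-in for the real optimizer update
        for i, p in enumerate(params):
            p.copy_(stacked[i])  # same copy-back loop as in the hunk above
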
@@ -181,6 +181,7 @@ class ScaledAdam(BatchedOptimizer):
         size_update_period=4,
         clipping_update_period=100,
     ):
+
         defaults = dict(
             lr=lr,
             clipping_scale=clipping_scale,
@@ -326,7 +327,9 @@ class ScaledAdam(BatchedOptimizer):
         batch = True
 
         for group, group_params_names in zip(self.param_groups, self.parameters_names):
+
             with self.batched_params(group["params"], group_params_names) as batches:
+
                 # batches is list of pairs (stacked_param, state). stacked_param is like
                 # a regular parameter, and will have a .grad, but the 1st dim corresponds to
                 # a stacking dim, it is not a real dim.
@@ -423,16 +426,19 @@ class ScaledAdam(BatchedOptimizer):
             # parameters' state won't have been initialized yet.
             return 1.0
         clipping_update_period = group["clipping_update_period"]
+        scalar_lr_scale = group["scalar_lr_scale"]
 
         tot_sumsq = torch.tensor(0.0, device=first_p.device)
-        for p, state, param_names in tuples:
+        for (p, state, param_names) in tuples:
             grad = p.grad
             if grad.is_sparse:
                 raise RuntimeError(
                     "ScaledAdam optimizer does not support sparse gradients"
                 )
             if p.numel() == p.shape[0]:  # a batch of scalars
-                tot_sumsq += (grad**2).sum()  # sum() to change shape [1] to []
+                tot_sumsq += (grad**2).sum() * (
+                    scalar_lr_scale**2
+                )  # sum() to change shape [1] to []
             else:
                 tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()
 
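
The new scalar_lr_scale**2 factor above makes the squared-gradient contribution of scalar parameters comparable to the RMS-weighted contribution of tensor parameters when the total clipping norm is accumulated. A standalone sketch of that accumulation, with made-up tensors in place of the real parameter batches and their state (illustrative, not the optimizer's internals):

    import torch

    scalar_lr_scale = 0.1  # plays the role of group["scalar_lr_scale"]

    # (grad, param_rms) pairs; param_rms is None for a batch of scalar parameters.
    fake_tuples = [
        (torch.randn(8), None),                          # a batch of 8 scalars
        (torch.randn(2, 16, 16), torch.rand(2, 1, 1)),   # a stacked tensor batch and its RMS
    ]

    tot_sumsq = torch.tensor(0.0)
    for grad, param_rms in fake_tuples:
        if param_rms is None:  # scalars: weight by scalar_lr_scale**2, as in the new code
            tot_sumsq += (grad**2).sum() * (scalar_lr_scale**2)
        else:                  # tensors: weight by the stored parameter RMS
            tot_sumsq += ((grad * param_rms) ** 2).sum()

    tot_norm = tot_sumsq.sqrt()
    print(f"weighted gradient norm: {tot_norm:.3e}")
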
@@ -443,64 +449,72 @@ class ScaledAdam(BatchedOptimizer):
             )
         first_state["model_norms"][step % clipping_update_period] = tot_norm
 
-        if step % clipping_update_period == 0:
+        irregular_estimate_steps = [
+            i for i in [10, 20, 40] if i < clipping_update_period
+        ]
+        if step % clipping_update_period == 0 or step in irregular_estimate_steps:
             # Print some stats.
             # We don't reach here if step == 0 because we would have returned
             # above.
             sorted_norms = first_state["model_norms"].sort()[0].to("cpu")
+            if step in irregular_estimate_steps:
+                sorted_norms = sorted_norms[-step:]
+            num_norms = sorted_norms.numel()
             quartiles = []
             for n in range(0, 5):
-                index = min(
-                    clipping_update_period - 1, (clipping_update_period // 4) * n
-                )
+                index = min(num_norms - 1, (num_norms // 4) * n)
                 quartiles.append(sorted_norms[index].item())
 
             median = quartiles[2]
             threshold = clipping_scale * median
+            if step in irregular_estimate_steps:
+                # use larger thresholds on first few steps of estimating threshold,
+                # as norm may be changing rapidly.
+                threshold = threshold * 2.0
             first_state["model_norm_threshold"] = threshold
             percent_clipped = (
-                first_state["num_clipped"] * 100.0 / clipping_update_period
+                first_state["num_clipped"] * 100.0 / num_norms
                 if "num_clipped" in first_state
                 else 0.0
             )
             first_state["num_clipped"] = 0
             quartiles = " ".join(["%.3e" % x for x in quartiles])
-            logging.info(
+            logging.warn(
                 f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
                 f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
             )
 
-        if step < clipping_update_period:
-            return 1.0  # We have not yet estimated a norm to clip to.
-        else:
-            try:
-                model_norm_threshold = first_state["model_norm_threshold"]
-            except KeyError:
-                logging.info(
-                    "Warning: model_norm_threshold not in state: possibly "
-                    "you changed config when restarting, adding clipping_scale option?"
-                )
-                return 1.0
-            ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
-            if ans < 1.0:
-                first_state["num_clipped"] += 1
-            if ans < 0.1:
-                logging.warn(
-                    f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
-                )
-                if self.show_dominant_parameters:
-                    assert p.shape[0] == len(param_names)
-                    self._show_gradient_dominating_parameter(tuples, tot_sumsq)
-            if ans != ans:  # e.g. ans is nan
-                ans = 0.0
-            if ans == 0.0:
-                for p, state, param_names in tuples:
-                    p.grad.zero_()  # get rid of infinity()
-
-            return ans
+        try:
+            model_norm_threshold = first_state["model_norm_threshold"]
+        except KeyError:
+            return 1.0  # threshold has not yet been set.
+
+        ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
+        if ans != ans:  # e.g. ans is nan
+            ans = 0.0
+        if ans < 1.0:
+            first_state["num_clipped"] += 1
+        if ans < 0.1:
+            logging.warn(
+                f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
+            )
+            if self.show_dominant_parameters:
+                assert p.shape[0] == len(param_names)
+                self._show_gradient_dominating_parameter(
+                    tuples, tot_sumsq, group["scalar_lr_scale"]
+                )
+
+        if ans == 0.0:
+            for (p, state, param_names) in tuples:
+                p.grad.zero_()  # get rid of infinity()
+
+        return ans
 
     def _show_gradient_dominating_parameter(
-        self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor
+        self,
+        tuples: List[Tuple[Tensor, dict, List[str]]],
+        tot_sumsq: Tensor,
+        scalar_lr_scale: float,
     ):
         """
         Show information of parameter which dominates tot_sumsq.
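
The core of the change above is that the clipping threshold is now also estimated at a few early, irregular steps (10, 20, 40) using only the norms collected so far, with a 2x safety margin while the estimate is still noisy. A self-contained sketch of that estimation loop, using a random stand-in for the gradient norm (names and constants mirror the diff, but this is not the optimizer's actual state handling):

    import torch

    clipping_scale = 2.0
    clipping_update_period = 100
    irregular_estimate_steps = [i for i in [10, 20, 40] if i < clipping_update_period]

    model_norms = torch.zeros(clipping_update_period)
    for step in range(1, 201):
        tot_norm = 1.0 + torch.rand(1).item()  # stand-in for the real gradient norm
        model_norms[step % clipping_update_period] = tot_norm

        if step % clipping_update_period == 0 or step in irregular_estimate_steps:
            sorted_norms = model_norms.sort()[0]
            if step in irregular_estimate_steps:
                sorted_norms = sorted_norms[-step:]  # only `step` entries are populated so far
            num_norms = sorted_norms.numel()
            quartiles = [
                sorted_norms[min(num_norms - 1, (num_norms // 4) * n)].item() for n in range(5)
            ]
            threshold = clipping_scale * quartiles[2]  # clipping_scale * median
            if step in irregular_estimate_steps:
                threshold *= 2.0  # be permissive while the estimate is still noisy
            print(f"step {step}: threshold {threshold:.3f}")
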
@@ -516,29 +530,30 @@ class ScaledAdam(BatchedOptimizer):
                 from tuples, we still pass it to save some time.
         """
         all_sumsq_orig = {}
-        for p, state, batch_param_names in tuples:
+        for (p, state, batch_param_names) in tuples:
             # p is a stacked batch parameters.
             batch_grad = p.grad
             if p.numel() == p.shape[0]:  # a batch of scalars
-                batch_sumsq_orig = batch_grad**2
                 # Dummy values used by following `zip` statement.
-                batch_rms_orig = torch.ones(p.shape[0])
+                batch_rms_orig = torch.full(
+                    p.shape, scalar_lr_scale, device=batch_grad.device
+                )
             else:
                 batch_rms_orig = state["param_rms"]
-                batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum(
+            batch_sumsq_orig = (batch_grad * batch_rms_orig) ** 2
+            if batch_grad.ndim > 1:
+                # need to guard it with if-statement because sum() sums over
+                # all dims if dim == ().
+                batch_sumsq_orig = batch_sumsq_orig.sum(
                     dim=list(range(1, batch_grad.ndim))
                 )
 
             for name, sumsq_orig, rms, grad in zip(
                 batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
             ):
                 proportion_orig = sumsq_orig / tot_sumsq
                 all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
 
-        assert torch.isclose(
-            sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
-            torch.tensor(1.0),
-        )
         sorted_by_proportion = {
             k: v
             for k, v in sorted(
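
The bookkeeping above ends by reporting which parameter contributes the largest share of the weighted gradient mass. A toy version of that "who dominates tot_sumsq" computation, with invented parameter names and values:

    import torch

    # Weighted squared-gradient mass per parameter, as a fraction of the total.
    sumsq_per_param = {
        "encoder.layer0.weight": torch.tensor(4.0),
        "encoder.layer1.weight": torch.tensor(1.0),
        "decoder.bias": torch.tensor(0.5),
    }
    tot_sumsq = sum(sumsq_per_param.values())

    proportions = {name: (s / tot_sumsq).item() for name, s in sumsq_per_param.items()}
    dominant = max(proportions, key=proportions.get)
    print(f"{dominant} dominates with proportion {proportions[dominant]:.2f}")  # ~0.73
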
@@ -552,7 +567,7 @@ class ScaledAdam(BatchedOptimizer):
             dominant_rms,
             dominant_grad,
         ) = sorted_by_proportion[dominant_param_name]
-        logging.info(
+        logging.warn(
            f"Parameter dominating tot_sumsq {dominant_param_name}"
            f" with proportion {dominant_proportion:.2f},"
            f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
@@ -826,7 +841,7 @@ class LRScheduler(object):
     def print_lr(self, is_verbose, group, lr):
         """Display the current learning rate."""
         if is_verbose:
-            logging.info(
+            logging.warn(
                 f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
                 f" of group {group} to {lr:.4e}."
             )
@@ -841,8 +856,14 @@ class Eden(LRScheduler):
     where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches
     and then stays constant at 1.
 
-    E.g. suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam
+    If you don't have the concept of epochs, or one epoch takes a very long time,
+    you can replace the notion of 'epoch' with some measure of the amount of data
+    processed, e.g. hours of data or frames of data, with 'lr_epochs' being set to
+    some measure representing "quite a lot of data": say, one fifth or one third
+    of an entire training run, but it doesn't matter much. You could also use
+    Eden2 which has only the notion of batches.
+
+    We suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam
 
     Args:
         optimizer: the optimizer to change the learning rates on
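
To make the new docstring paragraph concrete: if you define one "epoch" as, say, 200 hours of audio, you advance the scheduler's epoch counter whenever another 200 hours have been processed and pick lr_epochs as a handful of such units. A hedged sketch of that bookkeeping; the Eden(lr_batches=..., lr_epochs=...) arguments and the step_batch/step_epoch calls are my reading of icefall's LRScheduler API, so check them against your local optim.py:

    import torch
    # from optim import Eden, ScaledAdam  # assuming this file is importable as `optim`

    model = torch.nn.Linear(100, 100)
    optimizer = ScaledAdam(model.parameters(), lr=0.04)  # base_lr suggested in the docstring
    scheduler = Eden(optimizer, lr_batches=5000, lr_epochs=4)  # lr_epochs counted in "200-hour" units

    hours_seen = 0.0
    for batch_idx in range(1000):
        # ... forward / backward / optimizer.step() would go here ...
        scheduler.step_batch(batch_idx)
        hours_seen += 0.5  # pretend each batch carries half an hour of audio
        scheduler.step_epoch(int(hours_seen // 200))  # advance the pseudo-epoch every 200 hours
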
@@ -888,6 +909,56 @@ class Eden(LRScheduler):
         return [x * factor * warmup_factor for x in self.base_lrs]
 
 
+class Eden2(LRScheduler):
+    """
+    Eden2 scheduler, simpler than Eden because it does not use the notion of epoch,
+    only batches.
+
+    The basic formula (before warmup) is:
+      lr = base_lr * ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.5) * warmup
+
+    where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches
+    and then stays constant at 1.
+
+
+    E.g. suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam
+
+    Args:
+        optimizer: the optimizer to change the learning rates on
+        lr_batches: the number of batches after which we start significantly
+            decreasing the learning rate, suggest 5000.
+    """
+
+    def __init__(
+        self,
+        optimizer: Optimizer,
+        lr_batches: Union[int, float],
+        warmup_batches: Union[int, float] = 500.0,
+        warmup_start: float = 0.5,
+        verbose: bool = False,
+    ):
+        super().__init__(optimizer, verbose)
+        self.lr_batches = lr_batches
+        self.warmup_batches = warmup_batches
+
+        assert 0.0 <= warmup_start <= 1.0, warmup_start
+        self.warmup_start = warmup_start
+
+    def get_lr(self):
+        factor = (
+            (self.batch**2 + self.lr_batches**2) / self.lr_batches**2
+        ) ** -0.5
+        warmup_factor = (
+            1.0
+            if self.batch >= self.warmup_batches
+            else self.warmup_start
+            + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
+            # else 0.5 + 0.5 * (self.batch / self.warmup_batches)
+        )
+
+        return [x * factor * warmup_factor for x in self.base_lrs]
+
+
 def _test_eden():
     m = torch.nn.Linear(100, 100)
     optim = ScaledAdam(m.parameters(), lr=0.03)
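
As a quick sanity check of the Eden2 formula introduced above, here is a standalone recomputation in plain Python. It mirrors get_lr from the hunk rather than calling the class, and the example numbers are just the docstring's suggested defaults:

    base_lr = 0.04        # suggested in the docstring when used with ScaledAdam
    lr_batches = 5000     # suggested in the docstring
    warmup_batches = 500.0
    warmup_start = 0.5

    def eden2_lr(batch: int) -> float:
        # Mirrors Eden2.get_lr from the diff above.
        factor = ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.5
        warmup = (
            1.0
            if batch >= warmup_batches
            else warmup_start + (1.0 - warmup_start) * (batch / warmup_batches)
        )
        return base_lr * factor * warmup

    print(eden2_lr(250))    # mid-warmup: ~0.030
    print(eden2_lr(500))    # warmup just finished: ~0.0398
    print(eden2_lr(20000))  # well into the decay: ~0.0097
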
|
Loading…
x
Reference in New Issue
Block a user