mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Merge debugging changes to optimizer.
This commit is contained in:
parent
b546ac866c
commit
28cac1c2dc
@ -14,17 +14,17 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import List, Optional, Union, Tuple, List
|
||||
from lhotse.utils import fix_random_seed
|
||||
import torch
|
||||
from scaling import ActivationBalancer
|
||||
import contextlib
|
||||
import logging
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from lhotse.utils import fix_random_seed
|
||||
from scaling import ActivationBalancer
|
||||
from torch import Tensor
|
||||
from torch.optim import Optimizer
|
||||
import logging
|
||||
import contextlib
|
||||
|
||||
|
||||
|
||||
class BatchedOptimizer(Optimizer):
|
||||
@ -37,13 +37,12 @@ class BatchedOptimizer(Optimizer):
|
||||
Args:
|
||||
params:
|
||||
"""
|
||||
|
||||
def __init__(self, params, defaults):
|
||||
super(BatchedOptimizer, self).__init__(params, defaults)
|
||||
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def batched_params(self, param_group):
|
||||
def batched_params(self, param_group, group_params_names):
|
||||
"""
|
||||
This function returns (technically, yields) a list of
|
||||
tuples (p, state), where
|
||||
@ -65,48 +64,64 @@ class BatchedOptimizer(Optimizer):
|
||||
you can do:
|
||||
<code>
|
||||
with self.batched_params(group["params"]) as batches:
|
||||
for p, state in batches:
|
||||
for p, state, p_names in batches:
|
||||
...
|
||||
</code>
|
||||
|
||||
Args:
|
||||
group: a parameter group, which is a list of parameters; should be
|
||||
one of self.groups.
|
||||
one of self.param_groups.
|
||||
group_params_names: name for each parameter in group,
|
||||
which is List[str].
|
||||
"""
|
||||
batches = defaultdict(list) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
|
||||
batches = defaultdict(
|
||||
list
|
||||
) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
|
||||
batches_names = defaultdict(
|
||||
list
|
||||
) # `batches_names` maps from tuple (dtype_as_str,*shape) to list of str
|
||||
|
||||
for p in param_group:
|
||||
assert len(param_group) == len(group_params_names)
|
||||
for p, named_p in zip(param_group, group_params_names):
|
||||
key = (str(p.dtype), *p.shape)
|
||||
batches[key].append(p)
|
||||
batches_names[key].append(named_p)
|
||||
|
||||
batches_names_keys = list(batches_names.keys())
|
||||
sorted_idx = sorted(
|
||||
range(len(batches_names)), key=lambda i: batches_names_keys[i]
|
||||
)
|
||||
batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx]
|
||||
batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]
|
||||
|
||||
stacked_params_dict = dict()
|
||||
|
||||
# turn batches into a list, in deterministic order.
|
||||
batches = [ batches[key] for key in sorted(batches.keys()) ]
|
||||
# pairs will contain pairs of (stacked_param, state), one for each batch
|
||||
# in `batches`.
|
||||
pairs = []
|
||||
# tuples will contain tuples of (stacked_param, state, stacked_params_names),
|
||||
# one for each batch in `batches`.
|
||||
tuples = []
|
||||
|
||||
for batch in batches:
|
||||
for batch, batch_names in zip(batches, batches_names):
|
||||
p = batch[0]
|
||||
# we arbitrarily store the state in the
|
||||
# state corresponding to the 1st parameter in the
|
||||
# group. class Optimizer will take care of saving/loading state.
|
||||
state = self.state[p]
|
||||
p_stacked = torch.stack(batch)
|
||||
grad = torch.stack([torch.zeros_like(p) if p.grad is None else p.grad for p in batch ])
|
||||
grad = torch.stack(
|
||||
[torch.zeros_like(p) if p.grad is None else p.grad for p in batch]
|
||||
)
|
||||
p_stacked.grad = grad
|
||||
stacked_params_dict[key] = p_stacked
|
||||
pairs.append((p_stacked, state))
|
||||
tuples.append((p_stacked, state, batch_names))
|
||||
|
||||
yield pairs # <-- calling code will do the actual optimization here!
|
||||
yield tuples # <-- calling code will do the actual optimization here!
|
||||
|
||||
for ((stacked_params, _state), batch) in zip(pairs, batches):
|
||||
for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
|
||||
for i, p in enumerate(batch): # batch is list of Parameter
|
||||
p.copy_(stacked_params[i])
|
||||
|
||||
|
||||
|
||||
class ScaledAdam(BatchedOptimizer):
|
||||
"""
|
||||
Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
|
||||
@ -149,6 +164,7 @@ class ScaledAdam(BatchedOptimizer):
|
||||
in the update.
|
||||
clipping_update_period: if clipping_scale is specified, this is the period
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
@ -162,9 +178,15 @@ class ScaledAdam(BatchedOptimizer):
|
||||
scalar_max=10.0,
|
||||
size_update_period=4,
|
||||
clipping_update_period=100,
|
||||
parameters_names=None,
|
||||
show_dominant_parameters=True,
|
||||
):
|
||||
|
||||
|
||||
assert parameters_names is not None, (
|
||||
"Please prepare parameters_names,"
|
||||
"which is a List[List[str]]. Each List[str] is for a group"
|
||||
"and each str is for a parameter"
|
||||
)
|
||||
defaults = dict(
|
||||
lr=lr,
|
||||
clipping_scale=clipping_scale,
|
||||
@ -179,11 +201,13 @@ class ScaledAdam(BatchedOptimizer):
|
||||
)
|
||||
|
||||
super(ScaledAdam, self).__init__(params, defaults)
|
||||
assert len(self.param_groups) == len(parameters_names)
|
||||
self.parameters_names = parameters_names
|
||||
self.show_dominant_parameters = show_dominant_parameters
|
||||
|
||||
def __setstate__(self, state):
|
||||
super(ScaledAdam, self).__setstate__(state)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self, closure=None):
|
||||
"""Performs a single optimization step.
|
||||
@ -198,20 +222,23 @@ class ScaledAdam(BatchedOptimizer):
|
||||
loss = closure()
|
||||
|
||||
batch = True
|
||||
for group in self.param_groups:
|
||||
|
||||
with self.batched_params(group["params"]) as batches:
|
||||
for group, group_params_names in zip(self.param_groups, self.parameters_names):
|
||||
|
||||
with self.batched_params(group["params"], group_params_names) as batches:
|
||||
|
||||
# batches is list of pairs (stacked_param, state). stacked_param is like
|
||||
# a regular parameter, and will have a .grad, but the 1st dim corresponds to
|
||||
# a stacking dim, it is not a real dim.
|
||||
|
||||
if len(batches[0][1]) == 0: # if len(first state) == 0: not yet initialized
|
||||
if (
|
||||
len(batches[0][1]) == 0
|
||||
): # if len(first state) == 0: not yet initialized
|
||||
clipping_scale = 1
|
||||
else:
|
||||
clipping_scale = self._get_clipping_scale(group, batches)
|
||||
|
||||
for p, state in batches:
|
||||
for p, state, _ in batches:
|
||||
# Perform optimization step.
|
||||
# grad is not going to be None, we handled that when creating the batches.
|
||||
grad = p.grad
|
||||
@ -225,13 +252,9 @@ class ScaledAdam(BatchedOptimizer):
|
||||
|
||||
self._step_one_batch(group, p, state, clipping_scale)
|
||||
|
||||
|
||||
return loss
|
||||
|
||||
def _init_state(self,
|
||||
group: dict,
|
||||
p: Tensor,
|
||||
state: dict):
|
||||
def _init_state(self, group: dict, p: Tensor, state: dict):
|
||||
"""
|
||||
Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p
|
||||
is actually the batch dimension, corresponding to batched-together
|
||||
@ -247,7 +270,7 @@ class ScaledAdam(BatchedOptimizer):
|
||||
|
||||
state["step"] = 0
|
||||
|
||||
kwargs = {'device':p.device, 'dtype':p.dtype}
|
||||
kwargs = {"device": p.device, "dtype": p.dtype}
|
||||
|
||||
# 'delta' implements conventional momentum. There are
|
||||
# several different kinds of update going on, so rather than
|
||||
@ -255,48 +278,46 @@ class ScaledAdam(BatchedOptimizer):
|
||||
# parameter-change "delta", which combines all forms of
|
||||
# update. this is equivalent to how it's done in Adam,
|
||||
# except for the first few steps.
|
||||
state["delta"] = torch.zeros_like(
|
||||
p, memory_format=torch.preserve_format
|
||||
)
|
||||
state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
||||
|
||||
batch_size = p.shape[0]
|
||||
numel = p.numel() // batch_size
|
||||
numel = p.numel()
|
||||
|
||||
|
||||
if numel > 1:
|
||||
# "param_rms" just periodically records the scalar root-mean-square value of
|
||||
# the parameter tensor.
|
||||
# it has a shape like (batch_size, 1, 1, 1, 1)
|
||||
param_rms = (p**2).mean(dim=list(range(1, p.ndim)),
|
||||
keepdim=True).sqrt()
|
||||
param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
|
||||
state["param_rms"] = param_rms
|
||||
|
||||
state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
|
||||
state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape,
|
||||
**kwargs)
|
||||
|
||||
|
||||
# exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
|
||||
state["exp_avg_sq"] = torch.zeros_like(
|
||||
p, memory_format=torch.preserve_format
|
||||
state["scale_grads"] = torch.zeros(
|
||||
size_update_period, *param_rms.shape, **kwargs
|
||||
)
|
||||
|
||||
def _get_clipping_scale(self,
|
||||
group: dict,
|
||||
pairs: List[Tuple[Tensor, dict]]) -> float:
|
||||
# exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
|
||||
state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
||||
|
||||
def _get_clipping_scale(
|
||||
self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]
|
||||
) -> float:
|
||||
"""
|
||||
Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
|
||||
by this amount before applying the rest of the update.
|
||||
|
||||
Args:
|
||||
group: the parameter group, an item in self.param_groups
|
||||
pairs: a list of pairs of (param, state) where param is a batched set of parameters, with a .grad
|
||||
(1st dim is batch dim) and state is the state-dict where optimization parameters are kept.
|
||||
tuples: a list of tuples of (param, state, param_names)
|
||||
where param is a batched set of parameters,
|
||||
with a .grad (1st dim is batch dim)
|
||||
and state is the state-dict where optimization parameters are kept.
|
||||
param_names is a List[str] where each str is the name of a parameter
|
||||
in batched set of parameters "param".
|
||||
"""
|
||||
assert len(pairs) >= 1
|
||||
assert len(tuples) >= 1
|
||||
clipping_scale = group["clipping_scale"]
|
||||
(first_p, first_state) = pairs[0]
|
||||
(first_p, first_state, _) = tuples[0]
|
||||
step = first_state["step"]
|
||||
if clipping_scale is None or step == 0:
|
||||
# no clipping. return early on step == 0 because the other
|
||||
@ -305,7 +326,7 @@ class ScaledAdam(BatchedOptimizer):
|
||||
clipping_update_period = group["clipping_update_period"]
|
||||
|
||||
tot_sumsq = torch.tensor(0.0, device=first_p.device)
|
||||
for (p, state) in pairs:
|
||||
for (p, state, param_names) in tuples:
|
||||
grad = p.grad
|
||||
if grad.is_sparse:
|
||||
raise RuntimeError(
|
||||
@ -317,54 +338,128 @@ class ScaledAdam(BatchedOptimizer):
|
||||
tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()
|
||||
|
||||
tot_norm = tot_sumsq.sqrt()
|
||||
if not "model_norms" in first_state:
|
||||
first_state["model_norms"] = torch.zeros(clipping_update_period,
|
||||
device=p.device)
|
||||
if "model_norms" not in first_state:
|
||||
first_state["model_norms"] = torch.zeros(
|
||||
clipping_update_period, device=p.device
|
||||
)
|
||||
first_state["model_norms"][step % clipping_update_period] = tot_norm
|
||||
|
||||
if step % clipping_update_period == 0:
|
||||
# Print some stats.
|
||||
# We don't reach here if step == 0 because we would have returned
|
||||
# above.
|
||||
sorted_norms = first_state["model_norms"].sort()[0].to('cpu')
|
||||
sorted_norms = first_state["model_norms"].sort()[0].to("cpu")
|
||||
quartiles = []
|
||||
for n in range(0, 5):
|
||||
index = min(clipping_update_period - 1,
|
||||
(clipping_update_period // 4) * n)
|
||||
index = min(
|
||||
clipping_update_period - 1, (clipping_update_period // 4) * n
|
||||
)
|
||||
quartiles.append(sorted_norms[index].item())
|
||||
|
||||
median = quartiles[2]
|
||||
threshold = clipping_scale * median
|
||||
first_state["model_norm_threshold"] = threshold
|
||||
percent_clipped = (first_state["num_clipped"] * 100.0 / clipping_update_period
|
||||
if "num_clipped" in first_state else 0.0)
|
||||
percent_clipped = (
|
||||
first_state["num_clipped"] * 100.0 / clipping_update_period
|
||||
if "num_clipped" in first_state
|
||||
else 0.0
|
||||
)
|
||||
first_state["num_clipped"] = 0
|
||||
quartiles = ' '.join([ '%.3e' % x for x in quartiles ])
|
||||
logging.info(f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
|
||||
f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}")
|
||||
quartiles = " ".join(["%.3e" % x for x in quartiles])
|
||||
logging.info(
|
||||
f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
|
||||
f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
|
||||
)
|
||||
|
||||
if step < clipping_update_period:
|
||||
return 1.0 # We have not yet estimated a norm to clip to.
|
||||
else:
|
||||
try:
|
||||
model_norm_threshold = first_state["model_norm_threshold"]
|
||||
except:
|
||||
logging.info("Warning: model_norm_threshold not in state: possibly "
|
||||
"you changed config when restarting, adding clipping_scale option?")
|
||||
except KeyError:
|
||||
logging.info(
|
||||
"Warning: model_norm_threshold not in state: possibly "
|
||||
"you changed config when restarting, adding clipping_scale option?"
|
||||
)
|
||||
return 1.0
|
||||
ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
|
||||
if ans < 1.0:
|
||||
first_state["num_clipped"] += 1
|
||||
if ans < 0.1:
|
||||
logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
|
||||
logging.warn(
|
||||
f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
|
||||
)
|
||||
if self.show_dominant_parameters:
|
||||
assert p.shape[0] == len(param_names)
|
||||
self._show_gradient_dominating_parameter(tuples, tot_sumsq)
|
||||
return ans
|
||||
|
||||
def _show_gradient_dominating_parameter(
|
||||
self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor
|
||||
):
|
||||
"""
|
||||
Show information about the parameter which dominates tot_sumsq.
|
||||
|
||||
def _step_one_batch(self,
|
||||
group: dict,
|
||||
p: Tensor,
|
||||
state: dict,
|
||||
clipping_scale: float):
|
||||
Args:
|
||||
tuples: a list of tuples of (param, state, param_names)
|
||||
where param is a batched set of parameters,
|
||||
with a .grad (1st dim is batch dim)
|
||||
and state is the state-dict where optimization parameters are kept.
|
||||
param_names is a List[str] where each str is the name of a parameter
|
||||
in batched set of parameters "param".
|
||||
tot_sumsq: sumsq of all parameters. Though it could be calculated
|
||||
from tuples, we still pass it to save some time.
|
||||
"""
|
||||
all_sumsq_orig = {}
|
||||
for (p, state, batch_param_names) in tuples:
|
||||
# p is a stacked batch parameters.
|
||||
batch_grad = p.grad
|
||||
if p.numel() == p.shape[0]: # a batch of scalars
|
||||
batch_sumsq_orig = batch_grad**2
|
||||
# Dummy values used by the following `zip` statement.
|
||||
batch_rms_orig = torch.ones(p.shape[0])
|
||||
else:
|
||||
batch_rms_orig = state["param_rms"]
|
||||
batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum(
|
||||
dim=list(range(1, batch_grad.ndim))
|
||||
)
|
||||
|
||||
for name, sumsq_orig, rms, grad in zip(
|
||||
batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
|
||||
):
|
||||
|
||||
proportion_orig = sumsq_orig / tot_sumsq
|
||||
all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
|
||||
|
||||
assert torch.isclose(
|
||||
sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
|
||||
torch.tensor(1.0),
|
||||
)
|
||||
sorted_by_proportion = {
|
||||
k: v
|
||||
for k, v in sorted(
|
||||
all_sumsq_orig.items(), key=lambda item: item[1][0], reverse=True
|
||||
)
|
||||
}
|
||||
dominant_param_name = next(iter(sorted_by_proportion))
|
||||
(
|
||||
dominant_proportion,
|
||||
dominant_sumsq,
|
||||
dominant_rms,
|
||||
dominant_grad,
|
||||
) = sorted_by_proportion[dominant_param_name]
|
||||
logging.info(
|
||||
f"Parameter Dominanting tot_sumsq {dominant_param_name}"
|
||||
f" with proportion {dominant_proportion:.2f},"
|
||||
f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
|
||||
f"={dominant_sumsq:.3e},"
|
||||
f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
|
||||
f" orig_rms_sq={(dominant_rms**2).item():.3e}"
|
||||
)
|
||||
|
||||
def _step_one_batch(
|
||||
self, group: dict, p: Tensor, state: dict, clipping_scale: float
|
||||
):
|
||||
"""
|
||||
Do the step for one parameter, which is actually going to be a batch of
|
||||
`real` parameters, with dim 0 as the batch dim.
|
||||
@ -391,17 +486,18 @@ class ScaledAdam(BatchedOptimizer):
|
||||
# Update the size/scale of p, and set param_rms
|
||||
scale_grads = state["scale_grads"]
|
||||
scale_grads[step % size_update_period] = (p * grad).sum(
|
||||
dim=list(range(1, p.ndim)), keepdim=True)
|
||||
dim=list(range(1, p.ndim)), keepdim=True
|
||||
)
|
||||
if step % size_update_period == size_update_period - 1:
|
||||
param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..)
|
||||
param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)),
|
||||
keepdim=True).sqrt())
|
||||
param_rms.copy_(
|
||||
(p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
|
||||
)
|
||||
if step > 0:
|
||||
# self._size_update() learns the overall scale on the
|
||||
# parameter, by shrinking or expanding it.
|
||||
self._size_update(group, scale_grads, p, state)
|
||||
|
||||
|
||||
if numel == 1:
|
||||
# For parameters with 1 element we just use regular Adam.
|
||||
# Updates delta.
|
||||
@ -411,12 +507,9 @@ class ScaledAdam(BatchedOptimizer):
|
||||
|
||||
state["step"] = step + 1
|
||||
|
||||
|
||||
def _size_update(self,
|
||||
group: dict,
|
||||
scale_grads: Tensor,
|
||||
p: Tensor,
|
||||
state: dict) -> None:
|
||||
def _size_update(
|
||||
self, group: dict, scale_grads: Tensor, p: Tensor, state: dict
|
||||
) -> None:
|
||||
"""
|
||||
Called only where p.numel() > 1, this updates the scale of the parameter.
|
||||
If we imagine: p = underlying_param * scale.exp(), and we are doing
|
||||
@ -448,7 +541,8 @@ class ScaledAdam(BatchedOptimizer):
|
||||
scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..)
|
||||
scale_exp_avg_sq.mul_(beta2_corr).add_(
|
||||
(scale_grads**2).mean(dim=0), # mean over dim `size_update_period`
|
||||
alpha=1-beta2_corr) # shape is (batch_size, 1, 1, ...)
|
||||
alpha=1 - beta2_corr,
|
||||
) # shape is (batch_size, 1, 1, ...)
|
||||
|
||||
# The 1st time we reach here is when size_step == 1.
|
||||
size_step = (step + 1) // size_update_period
|
||||
@ -458,10 +552,12 @@ class ScaledAdam(BatchedOptimizer):
|
||||
|
||||
denom = scale_exp_avg_sq.sqrt() + eps
|
||||
|
||||
scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum(dim=0) / denom
|
||||
scale_step = (
|
||||
-size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom
|
||||
)
|
||||
|
||||
is_too_small = (param_rms < param_min_rms)
|
||||
is_too_large = (param_rms > param_max_rms)
|
||||
is_too_small = param_rms < param_min_rms
|
||||
is_too_large = param_rms > param_max_rms
|
||||
|
||||
# when the param gets too small, just don't shrink it any further.
|
||||
scale_step.masked_fill_(is_too_small, 0.0)
|
||||
@ -471,11 +567,7 @@ class ScaledAdam(BatchedOptimizer):
|
||||
# the factor of (1-beta1) relates to momentum.
|
||||
delta.add_(p * scale_step, alpha=(1 - beta1))
|
||||
|
||||
|
||||
def _step(self,
|
||||
group: dict,
|
||||
p: Tensor,
|
||||
state: dict):
|
||||
def _step(self, group: dict, p: Tensor, state: dict):
|
||||
"""
|
||||
This function does the core update of self.step(), in the case where the members of
|
||||
the batch have more than 1 element.
|
||||
@ -496,8 +588,7 @@ class ScaledAdam(BatchedOptimizer):
|
||||
step = state["step"]
|
||||
|
||||
exp_avg_sq = state["exp_avg_sq"]
|
||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
|
||||
value=(1-beta2))
|
||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))
|
||||
|
||||
this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0)
|
||||
bias_correction2 = 1 - beta2 ** (this_step + 1)
|
||||
@ -515,11 +606,7 @@ class ScaledAdam(BatchedOptimizer):
|
||||
delta.add_(grad * alpha)
|
||||
p.add_(delta)
|
||||
|
||||
|
||||
def _step_scalar(self,
|
||||
group: dict,
|
||||
p: Tensor,
|
||||
state: dict):
|
||||
def _step_scalar(self, group: dict, p: Tensor, state: dict):
|
||||
"""
|
||||
A simplified form of the core update for scalar tensors, where we cannot get a good
|
||||
estimate of the parameter rms.
|
||||
@ -531,8 +618,7 @@ class ScaledAdam(BatchedOptimizer):
|
||||
grad = p.grad
|
||||
|
||||
exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,)
|
||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
|
||||
value=1-beta2)
|
||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
|
||||
|
||||
# bias_correction2 is like in Adam. Don't bother with bias_correction1;
|
||||
# slower update at the start will help stability anyway.
|
||||
@ -545,7 +631,6 @@ class ScaledAdam(BatchedOptimizer):
|
||||
p.add_(delta)
|
||||
|
||||
|
||||
|
||||
class LRScheduler(object):
|
||||
"""
|
||||
Base-class for learning rate schedulers where the learning-rate depends on both the
|
||||
@ -555,18 +640,14 @@ class LRScheduler(object):
|
||||
def __init__(self, optimizer: Optimizer, verbose: bool = False):
|
||||
# Attach optimizer
|
||||
if not isinstance(optimizer, Optimizer):
|
||||
raise TypeError(
|
||||
"{} is not an Optimizer".format(type(optimizer).__name__)
|
||||
)
|
||||
raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__))
|
||||
self.optimizer = optimizer
|
||||
self.verbose = verbose
|
||||
|
||||
for group in optimizer.param_groups:
|
||||
group.setdefault("base_lr", group["lr"])
|
||||
|
||||
self.base_lrs = [
|
||||
group["base_lr"] for group in optimizer.param_groups
|
||||
]
|
||||
self.base_lrs = [group["base_lr"] for group in optimizer.param_groups]
|
||||
|
||||
self.epoch = 0
|
||||
self.batch = 0
|
||||
@ -682,11 +763,13 @@ class Eden(LRScheduler):
|
||||
factor = (
|
||||
(self.batch**2 + self.lr_batches**2) / self.lr_batches**2
|
||||
) ** -0.25 * (
|
||||
((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2)
|
||||
** -0.25
|
||||
((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25
|
||||
)
|
||||
warmup_factor = (
|
||||
1.0
|
||||
if self.batch >= self.warmup_batches
|
||||
else 0.5 + 0.5 * (self.batch / self.warmup_batches)
|
||||
)
|
||||
warmup_factor = (1.0 if self.batch >= self.warmup_batches
|
||||
else 0.5 + 0.5 * (self.batch / self.warmup_batches))
|
||||
|
||||
return [x * factor * warmup_factor for x in self.base_lrs]
|
||||
|
||||
@ -745,13 +828,14 @@ class Eve(Optimizer):
|
||||
parameters, if they fall below this we will stop applying weight decay.
|
||||
|
||||
|
||||
.. _Adam\: A Method for Stochastic Optimization:
|
||||
.. _Adam: A Method for Stochastic Optimization:
|
||||
https://arxiv.org/abs/1412.6980
|
||||
.. _Decoupled Weight Decay Regularization:
|
||||
https://arxiv.org/abs/1711.05101
|
||||
.. _On the Convergence of Adam and Beyond:
|
||||
https://openreview.net/forum?id=ryQu7f-RZ
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
@ -766,17 +850,11 @@ class Eve(Optimizer):
|
||||
if not 0.0 <= eps:
|
||||
raise ValueError("Invalid epsilon value: {}".format(eps))
|
||||
if not 0.0 <= betas[0] < 1.0:
|
||||
raise ValueError(
|
||||
"Invalid beta parameter at index 0: {}".format(betas[0])
|
||||
)
|
||||
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
|
||||
if not 0.0 <= betas[1] < 1.0:
|
||||
raise ValueError(
|
||||
"Invalid beta parameter at index 1: {}".format(betas[1])
|
||||
)
|
||||
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
|
||||
if not 0 <= weight_decay <= 0.1:
|
||||
raise ValueError(
|
||||
"Invalid weight_decay value: {}".format(weight_decay)
|
||||
)
|
||||
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
|
||||
if not 0 < target_rms <= 10.0:
|
||||
raise ValueError("Invalid target_rms value: {}".format(target_rms))
|
||||
defaults = dict(
|
||||
@ -812,9 +890,7 @@ class Eve(Optimizer):
|
||||
# Perform optimization step
|
||||
grad = p.grad
|
||||
if grad.is_sparse:
|
||||
raise RuntimeError(
|
||||
"AdamW does not support sparse gradients"
|
||||
)
|
||||
raise RuntimeError("AdamW does not support sparse gradients")
|
||||
|
||||
state = self.state[p]
|
||||
|
||||
@ -852,30 +928,31 @@ class Eve(Optimizer):
|
||||
if p.numel() > 1:
|
||||
# avoid applying this weight-decay on "scaling factors"
|
||||
# (which are scalar).
|
||||
is_above_target_rms = p.norm() > (
|
||||
target_rms * (p.numel() ** 0.5)
|
||||
)
|
||||
is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5))
|
||||
p.mul_(1 - (weight_decay * is_above_target_rms))
|
||||
|
||||
p.addcdiv_(exp_avg, denom, value=-step_size)
|
||||
|
||||
if random.random() < 0.0005:
|
||||
step = (exp_avg / denom) * step_size
|
||||
logging.info(f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}")
|
||||
|
||||
logging.info(
|
||||
f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}"
|
||||
)
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
def _test_scaled_adam(hidden_dim: int):
|
||||
import timeit
|
||||
|
||||
from scaling import ScaledLinear
|
||||
|
||||
E = 100
|
||||
B = 4
|
||||
T = 2
|
||||
logging.info("in test_eve_cain")
|
||||
# device = torch.device('cuda')
|
||||
device = torch.device('cpu')
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32
|
||||
|
||||
fix_random_seed(42)
|
||||
@ -889,21 +966,30 @@ def _test_scaled_adam(hidden_dim: int):
|
||||
fix_random_seed(42)
|
||||
Linear = torch.nn.Linear if iter == 0 else ScaledLinear
|
||||
|
||||
m = torch.nn.Sequential(Linear(E, hidden_dim),
|
||||
m = torch.nn.Sequential(
|
||||
Linear(E, hidden_dim),
|
||||
torch.nn.PReLU(),
|
||||
Linear(hidden_dim, hidden_dim),
|
||||
torch.nn.PReLU(),
|
||||
Linear(hidden_dim, E),
|
||||
).to(device)
|
||||
|
||||
train_pairs = [ (100.0 * torch.randn(B, T, E, device=device, dtype=dtype) * input_magnitudes,
|
||||
torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes) for _ in range(20) ]
|
||||
train_pairs = [
|
||||
(
|
||||
100.0
|
||||
* torch.randn(B, T, E, device=device, dtype=dtype)
|
||||
* input_magnitudes,
|
||||
torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes,
|
||||
)
|
||||
for _ in range(20)
|
||||
]
|
||||
|
||||
if iter == 0: optim = Eve(m.parameters(), lr=0.003)
|
||||
elif iter == 1: optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0)
|
||||
if iter == 0:
|
||||
optim = Eve(m.parameters(), lr=0.003)
|
||||
elif iter == 1:
|
||||
optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0)
|
||||
scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)
|
||||
|
||||
|
||||
start = timeit.default_timer()
|
||||
avg_loss = 0.0
|
||||
for epoch in range(180):
|
||||
@ -917,7 +1003,6 @@ def _test_scaled_adam(hidden_dim: int):
|
||||
# ) # allow 4 megabytes per sub-module
|
||||
# diagnostic = diagnostics.attach_diagnostics(m, opts)
|
||||
|
||||
|
||||
for n, (x, y) in enumerate(train_pairs):
|
||||
y_out = m(x)
|
||||
loss = ((y_out - y) ** 2).mean() * 100.0
|
||||
@ -935,7 +1020,9 @@ def _test_scaled_adam(hidden_dim: int):
|
||||
# scale2 = '%.2e' % (m[2].weight_scale.exp().item())
|
||||
# scale2b = '%.2e' % (m[2].bias_scale.exp().item())
|
||||
lr = scheduler.get_last_lr()[0]
|
||||
logging.info(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}") #, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
|
||||
logging.info(
|
||||
f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}"
|
||||
) # , norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
|
||||
loss.log().backward()
|
||||
optim.step()
|
||||
optim.zero_grad()
|
||||
@ -953,15 +1040,18 @@ def _test_scaled_adam(hidden_dim: int):
|
||||
logging.info(f"output_magnitudes = {output_magnitudes}")
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
import subprocess
|
||||
s = subprocess.check_output("git status -uno .; git log -1; git diff HEAD .", shell=True)
|
||||
|
||||
s = subprocess.check_output(
|
||||
"git status -uno .; git log -1; git diff HEAD .", shell=True
|
||||
)
|
||||
logging.info(s)
|
||||
import sys
|
||||
|
||||
if len(sys.argv) > 1:
|
||||
hidden_dim = int(sys.argv[1])
|
||||
else:
|
||||
|
||||
@ -90,7 +90,6 @@ LRSchedulerType = Union[
|
||||
]
|
||||
|
||||
|
||||
|
||||
def get_adjusted_batch_count(
|
||||
params: AttributeDict) -> float:
|
||||
# returns the number of batches we would have used so far if we had used the reference
|
||||
@ -99,6 +98,7 @@ def get_adjusted_batch_count(
|
||||
(params.max_duration * params.world_size))
|
||||
|
||||
|
||||
|
||||
def set_batch_count(
|
||||
model: Union[nn.Module, DDP], batch_count: float
|
||||
) -> None:
|
||||
@ -856,7 +856,6 @@ def train_one_epoch(
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
if batch_idx % 10 == 0:
|
||||
set_batch_count(model, get_adjusted_batch_count(params))
|
||||
|
||||
if batch_idx < cur_batch_idx:
|
||||
continue
|
||||
cur_batch_idx = batch_idx
|
||||
@ -879,7 +878,6 @@ def train_one_epoch(
|
||||
# NOTE: We use reduction==sum and loss is computed over utterances
|
||||
# in the batch and there is no normalization to it so far.
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
scheduler.step_batch(params.batch_idx_train)
|
||||
|
||||
scaler.step(optimizer)
|
||||
@ -1058,9 +1056,12 @@ def run(rank, world_size, args):
|
||||
model = DDP(model, device_ids=[rank],
|
||||
find_unused_parameters=True)
|
||||
|
||||
optimizer = ScaledAdam(model.parameters(),
|
||||
optimizer = ScaledAdam(
|
||||
model.parameters(),
|
||||
lr=params.base_lr,
|
||||
clipping_scale=2.0)
|
||||
clipping_scale=2.0,
|
||||
parameters_names=[ [p[0] for p in model.named_parameters()] ],
|
||||
)
|
||||
|
||||
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user