Merge debugging changes to optimizer.

Daniel Povey 2022-12-20 13:01:50 +08:00
parent b546ac866c
commit 28cac1c2dc
2 changed files with 333 additions and 242 deletions

View File

@@ -14,17 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import logging
import random
from collections import defaultdict
from typing import List, Optional, Tuple, Union

import torch
from lhotse.utils import fix_random_seed
from scaling import ActivationBalancer
from torch import Tensor
from torch.optim import Optimizer


class BatchedOptimizer(Optimizer):
@@ -37,13 +37,12 @@ class BatchedOptimizer(Optimizer):
    Args:
      params:
    """

    def __init__(self, params, defaults):
        super(BatchedOptimizer, self).__init__(params, defaults)

    @contextlib.contextmanager
    def batched_params(self, param_group, group_params_names):
        """
        This function returns (technically, yields) a list of
          tuples (p, state, p_names), where
@@ -65,106 +64,129 @@ class BatchedOptimizer(Optimizer):
        you can do:
        <code>
          with self.batched_params(group["params"]) as batches:
             for p, state, p_names in batches:
                 ...
        </code>
        Args:
          group: a parameter group, which is a list of parameters; should be
                one of self.param_groups.
          group_params_names: name for each parameter in group,
                which is List[str].
        """
        batches = defaultdict(
            list
        )  # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
        batches_names = defaultdict(
            list
        )  # `batches_names` maps from tuple (dtype_as_str,*shape) to list of str

        assert len(param_group) == len(group_params_names)
        for p, named_p in zip(param_group, group_params_names):
            key = (str(p.dtype), *p.shape)
            batches[key].append(p)
            batches_names[key].append(named_p)

        batches_names_keys = list(batches_names.keys())
        sorted_idx = sorted(
            range(len(batches_names)), key=lambda i: batches_names_keys[i]
        )
        batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx]
        batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]

        stacked_params_dict = dict()

        # turn batches into a list, in deterministic order.
        # tuples will contain tuples of (stacked_param, state, stacked_params_names),
        # one for each batch in `batches`.
        tuples = []

        for batch, batch_names in zip(batches, batches_names):
            p = batch[0]
            # we arbitrarily store the state in the
            # state corresponding to the 1st parameter in the
            # group.  class Optimizer will take care of saving/loading state.
            state = self.state[p]
            p_stacked = torch.stack(batch)
            grad = torch.stack(
                [torch.zeros_like(p) if p.grad is None else p.grad for p in batch]
            )
            p_stacked.grad = grad
            stacked_params_dict[key] = p_stacked
            tuples.append((p_stacked, state, batch_names))

        yield tuples  # <-- calling code will do the actual optimization here!

        for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
            for i, p in enumerate(batch):  # batch is list of Parameter
                p.copy_(stacked_params[i])
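
To make the new contract concrete, here is a minimal sketch of how calling code drives batched_params() after this change; the loop mirrors what ScaledAdam.step() does below, while `opt` and the toy update are illustrative, not part of the commit:

# Assumes `opt` is a BatchedOptimizer subclass that stores one List[str] of
# parameter names per param group in `opt.parameters_names`.
for group, names in zip(opt.param_groups, opt.parameters_names):
    with opt.batched_params(group["params"], names) as batches:
        for p_stacked, state, p_names in batches:
            # p_stacked has an extra leading stacking dim and a stacked .grad;
            # p_names lists the real parameters folded into this batch.
            p_stacked.add_(p_stacked.grad, alpha=-group["lr"])  # toy SGD-style update
# On exiting the `with` block, the updated slices are copied back into the
# real (unstacked) parameters.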
class ScaledAdam(BatchedOptimizer):
    """
    Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
    proportional to the norm of that parameter; and also learn the scale of the parameter,
    in log space, subject to upper and lower limits (as if we had factored each parameter as
    param = underlying_param * log_scale.exp())

    Args:
         params: The parameters or param_groups to optimize (like other Optimizer subclasses)
             lr: The learning rate.  We will typically use a learning rate schedule that starts
                 at 0.03 and decreases over time, i.e. much higher than other common
                 optimizers.
    clipping_scale: (e.g. 2.0)
                 A scale for gradient-clipping: if specified, the normalized gradients
                 over the whole model will be clipped to have 2-norm equal to
                 `clipping_scale` times the median 2-norm over the most recent period
                 of `clipping_update_period` minibatches.  By "normalized gradients",
                 we mean after multiplying by the rms parameter value for this tensor
                 [for non-scalars]; this is appropriate because our update is scaled
                 by this quantity.
          betas: beta1, beta2 are momentum constants for regular momentum, and moving sum-sq grad.
                 Must satisfy 0 < beta1 <= beta2 < 1.
    scalar_lr_scale: A scaling factor on the learning rate, that we use to update the
                 scale of each parameter tensor and scalar parameters of the model.
                 If each parameter were decomposed
                 as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale
                 would be the scaling factor on the learning rate of p_scale.
            eps: A general-purpose epsilon to prevent division by zero
    param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
                 learning the scale on the parameters (we'll constrain the rms of each non-scalar
                 parameter tensor to be >= this value)
    param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
                 learning the scale on the parameters (we'll constrain the rms of each non-scalar
                 parameter tensor to be <= this value)
     scalar_max: Maximum absolute value for scalar parameters (applicable if your
                 model has any parameters with numel() == 1).
    size_update_period: The periodicity, in steps, with which we update the size (scale)
                 of the parameter tensor.  This is provided to save a little time
                 in the update.
    clipping_update_period: if clipping_scale is specified, this is the period
    """

    def __init__(
        self,
        params,
        lr=3e-02,
        clipping_scale=None,
        betas=(0.9, 0.98),
        scalar_lr_scale=0.1,
        eps=1.0e-08,
        param_min_rms=1.0e-05,
        param_max_rms=3.0,
        scalar_max=10.0,
        size_update_period=4,
        clipping_update_period=100,
        parameters_names=None,
        show_dominant_parameters=True,
    ):

        assert parameters_names is not None, (
            "Please prepare parameters_names, "
            "which is a List[List[str]]. Each List[str] is for a group "
            "and each str is for a parameter"
        )
        defaults = dict(
            lr=lr,
            clipping_scale=clipping_scale,
@@ -179,11 +201,13 @@ class ScaledAdam(BatchedOptimizer):
        )

        super(ScaledAdam, self).__init__(params, defaults)
        assert len(self.param_groups) == len(parameters_names)
        self.parameters_names = parameters_names
        self.show_dominant_parameters = show_dominant_parameters
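
Since parameters_names is now required, callers have to pass one list of names per parameter group. A minimal construction sketch, mirroring the train.py change at the bottom of this commit (the toy model is illustrative):

import torch

model = torch.nn.Linear(4, 4)  # stand-in for the real model
# One List[str] per param group; a single group here, so one list with all names.
names = [[name for name, _ in model.named_parameters()]]
optim = ScaledAdam(
    model.parameters(),
    lr=0.03,
    clipping_scale=2.0,
    parameters_names=names,
)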
    def __setstate__(self, state):
        super(ScaledAdam, self).__setstate__(state)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.
@@ -198,20 +222,23 @@ class ScaledAdam(BatchedOptimizer):
                loss = closure()

        batch = True

        for group, group_params_names in zip(self.param_groups, self.parameters_names):

            with self.batched_params(group["params"], group_params_names) as batches:

                # batches is list of pairs (stacked_param, state).  stacked_param is like
                # a regular parameter, and will have a .grad, but the 1st dim corresponds to
                # a stacking dim, it is not a real dim.

                if (
                    len(batches[0][1]) == 0
                ):  # if len(first state) == 0: not yet initialized
                    clipping_scale = 1
                else:
                    clipping_scale = self._get_clipping_scale(group, batches)

                for p, state, _ in batches:
                    # Perform optimization step.
                    # grad is not going to be None, we handled that when creating the batches.
                    grad = p.grad
@@ -225,13 +252,9 @@ class ScaledAdam(BatchedOptimizer):
                    self._step_one_batch(group, p, state, clipping_scale)

        return loss

    def _init_state(self, group: dict, p: Tensor, state: dict):
        """
        Initializes state dict for parameter 'p'.  Assumes that dim 0 of tensor p
        is actually the batch dimension, corresponding to batched-together
@@ -247,7 +270,7 @@ class ScaledAdam(BatchedOptimizer):
        state["step"] = 0

        kwargs = {"device": p.device, "dtype": p.dtype}

        # 'delta' implements conventional momentum.  There are
        # several different kinds of update going on, so rather than
@@ -255,48 +278,46 @@ class ScaledAdam(BatchedOptimizer):
        # parameter-change "delta", which combines all forms of
        # update.  this is equivalent to how it's done in Adam,
        # except for the first few steps.
        state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format)

        batch_size = p.shape[0]
        numel = p.numel() // batch_size
        numel = p.numel()

        if numel > 1:
            # "param_rms" just periodically records the scalar root-mean-square value of
            # the parameter tensor.
            # it has a shape like (batch_size, 1, 1, 1, 1)
            param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
            state["param_rms"] = param_rms

            state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
            state["scale_grads"] = torch.zeros(
                size_update_period, *param_rms.shape, **kwargs
            )

        # exp_avg_sq is the weighted sum of scaled gradients.  as in Adam.
        state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
    def _get_clipping_scale(
        self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]
    ) -> float:
        """
        Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
        by this amount before applying the rest of the update.

        Args:
           group: the parameter group, an item in self.param_groups
           tuples: a list of tuples of (param, state, param_names)
                where param is a batched set of parameters,
                with a .grad (1st dim is batch dim)
                and state is the state-dict where optimization parameters are kept.
                param_names is a List[str], each str being the name of a parameter
                in the batched set of parameters "param".
        """
        assert len(tuples) >= 1
        clipping_scale = group["clipping_scale"]
        (first_p, first_state, _) = tuples[0]
        step = first_state["step"]
        if clipping_scale is None or step == 0:
            # no clipping.  return early on step == 0 because the other
@@ -305,7 +326,7 @@ class ScaledAdam(BatchedOptimizer):
        clipping_update_period = group["clipping_update_period"]

        tot_sumsq = torch.tensor(0.0, device=first_p.device)
        for (p, state, param_names) in tuples:
            grad = p.grad
            if grad.is_sparse:
                raise RuntimeError(
@@ -314,57 +335,131 @@ class ScaledAdam(BatchedOptimizer):
            if p.numel() == p.shape[0]:  # a batch of scalars
                tot_sumsq += (grad**2).sum()  # sum() to change shape [1] to []
            else:
                tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()

        tot_norm = tot_sumsq.sqrt()
        if "model_norms" not in first_state:
            first_state["model_norms"] = torch.zeros(
                clipping_update_period, device=p.device
            )
        first_state["model_norms"][step % clipping_update_period] = tot_norm

        if step % clipping_update_period == 0:
            # Print some stats.
            # We don't reach here if step == 0 because we would have returned
            # above.
            sorted_norms = first_state["model_norms"].sort()[0].to("cpu")
            quartiles = []
            for n in range(0, 5):
                index = min(
                    clipping_update_period - 1, (clipping_update_period // 4) * n
                )
                quartiles.append(sorted_norms[index].item())

            median = quartiles[2]
            threshold = clipping_scale * median
            first_state["model_norm_threshold"] = threshold
            percent_clipped = (
                first_state["num_clipped"] * 100.0 / clipping_update_period
                if "num_clipped" in first_state
                else 0.0
            )
            first_state["num_clipped"] = 0
            quartiles = " ".join(["%.3e" % x for x in quartiles])
            logging.info(
                f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
                f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
            )

        if step < clipping_update_period:
            return 1.0  # We have not yet estimated a norm to clip to.
        else:
            try:
                model_norm_threshold = first_state["model_norm_threshold"]
            except KeyError:
                logging.info(
                    "Warning: model_norm_threshold not in state: possibly "
                    "you changed config when restarting, adding clipping_scale option?"
                )
                return 1.0
            ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
            if ans < 1.0:
                first_state["num_clipped"] += 1
            if ans < 0.1:
                logging.warn(
                    f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
                )
                if self.show_dominant_parameters:
                    assert p.shape[0] == len(param_names)
                    self._show_gradient_dominating_parameter(tuples, tot_sumsq)
            return ans
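
Stripped of the optimizer-state bookkeeping, the clipping rule above is median tracking: keep the last clipping_update_period gradient norms, set the threshold to clipping_scale times their median, and scale gradients down whenever the current norm exceeds it. A self-contained sketch of that rule (illustrative only, not the code above):

import torch

def clipping_scale_sketch(
    norm_history: torch.Tensor,  # most recent `clipping_update_period` grad norms
    clipping_scale: float,
    tot_norm: torch.Tensor,      # current total (rms-normalized) grad norm
) -> float:
    median = norm_history.sort()[0][norm_history.numel() // 2]
    threshold = clipping_scale * median
    # <= 1.0: we only ever scale gradients down, never up.
    return min(1.0, (threshold / (tot_norm + 1.0e-20)).item())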
    def _show_gradient_dominating_parameter(
        self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor
    ):
        """
        Show information about the parameter which dominates tot_sumsq.

        Args:
            tuples: a list of tuples of (param, state, param_names)
                where param is a batched set of parameters,
                with a .grad (1st dim is batch dim)
                and state is the state-dict where optimization parameters are kept.
                param_names is a List[str], each str being the name of a parameter
                in the batched set of parameters "param".
            tot_sumsq: sumsq of all parameters.  Though it could be calculated
                from tuples, we still pass it to save some time.
        """
        all_sumsq_orig = {}
        for (p, state, batch_param_names) in tuples:
            # p is a stacked batch of parameters.
            batch_grad = p.grad
            if p.numel() == p.shape[0]:  # a batch of scalars
                batch_sumsq_orig = batch_grad**2
                # Dummy values used by the following `zip` statement.
                batch_rms_orig = torch.ones(p.shape[0])
            else:
                batch_rms_orig = state["param_rms"]
                batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum(
                    dim=list(range(1, batch_grad.ndim))
                )

            for name, sumsq_orig, rms, grad in zip(
                batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
            ):
                proportion_orig = sumsq_orig / tot_sumsq
                all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)

        assert torch.isclose(
            sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
            torch.tensor(1.0),
        )
        sorted_by_proportion = {
            k: v
            for k, v in sorted(
                all_sumsq_orig.items(), key=lambda item: item[1][0], reverse=True
            )
        }
        dominant_param_name = next(iter(sorted_by_proportion))
        (
            dominant_proportion,
            dominant_sumsq,
            dominant_rms,
            dominant_grad,
        ) = sorted_by_proportion[dominant_param_name]
        logging.info(
            f"Parameter dominating tot_sumsq {dominant_param_name}"
            f" with proportion {dominant_proportion:.2f},"
            f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
            f"={dominant_sumsq:.3e},"
            f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
            f" orig_rms_sq={(dominant_rms**2).item():.3e}"
        )
    def _step_one_batch(
        self, group: dict, p: Tensor, state: dict, clipping_scale: float
    ):
        """
        Do the step for one parameter, which is actually going to be a batch of
        `real` parameters, with dim 0 as the batch dim.
@@ -391,17 +486,18 @@ class ScaledAdam(BatchedOptimizer):
            # Update the size/scale of p, and set param_rms
            scale_grads = state["scale_grads"]
            scale_grads[step % size_update_period] = (p * grad).sum(
                dim=list(range(1, p.ndim)), keepdim=True
            )
            if step % size_update_period == size_update_period - 1:
                param_rms = state["param_rms"]  # shape: (batch_size, 1, 1, ..)
                param_rms.copy_(
                    (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
                )
                if step > 0:
                    # self._size_update() learns the overall scale on the
                    # parameter, by shrinking or expanding it.
                    self._size_update(group, scale_grads, p, state)

        if numel == 1:
            # For parameters with 1 element we just use regular Adam.
            # Updates delta.
@@ -411,24 +507,21 @@ class ScaledAdam(BatchedOptimizer):
        state["step"] = step + 1

    def _size_update(
        self, group: dict, scale_grads: Tensor, p: Tensor, state: dict
    ) -> None:
        """
        Called only where p.numel() > 1, this updates the scale of the parameter.
        If we imagine: p = underlying_param * scale.exp(), and we are doing
        gradient descent on underlying param and on scale, this function does the update
        on `scale`.

        Args:
           group: dict to look up configuration values
           scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing
               grads w.r.t. the scales.
           p:  The parameter to update
           state: The state-dict of p
        """

        param_rms = state["param_rms"]
@@ -443,25 +536,28 @@ class ScaledAdam(BatchedOptimizer):
        size_update_period = scale_grads.shape[0]
        # correct beta2 for the size update period: we will have
        # faster decay at this level.
        beta2_corr = beta2**size_update_period

        scale_exp_avg_sq = state["scale_exp_avg_sq"]  # shape: (batch_size, 1, 1, ..)
        scale_exp_avg_sq.mul_(beta2_corr).add_(
            (scale_grads**2).mean(dim=0),  # mean over dim `size_update_period`
            alpha=1 - beta2_corr,
        )  # shape is (batch_size, 1, 1, ...)

        # The 1st time we reach here is when size_step == 1.
        size_step = (step + 1) // size_update_period
        bias_correction2 = 1 - beta2_corr**size_step
        # we don't bother with bias_correction1; this will help prevent divergence
        # at the start of training.

        denom = scale_exp_avg_sq.sqrt() + eps

        scale_step = (
            -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom
        )

        is_too_small = param_rms < param_min_rms
        is_too_large = param_rms > param_max_rms

        # when the param gets too small, just don't shrink it any further.
        scale_step.masked_fill_(is_too_small, 0.0)
@@ -469,13 +565,9 @@ class ScaledAdam(BatchedOptimizer):
        scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)
        delta = state["delta"]
        # the factor of (1-beta1) relates to momentum.
        delta.add_(p * scale_step, alpha=(1 - beta1))
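
Why (p * grad).sum() is the quantity accumulated into scale_grads: with the factorization p = underlying_param * scale.exp(), the chain rule gives d(loss)/d(scale) = sum(d(loss)/dp * p), i.e. the gradient times the parameter, summed over the tensor. A quick numerical check of that identity (a standalone sketch, not part of the optimizer):

import torch

underlying = torch.randn(10)
scale = torch.tensor(0.3, requires_grad=True)
p = underlying * scale.exp()
loss = (p**3).sum()      # any smooth loss of p will do
loss.backward()
grad_wrt_p = 3 * p**2    # analytic d(loss)/dp for this toy loss
# scale.grad matches (grad * p).sum(), which is what scale_grads stores per step.
assert torch.isclose(scale.grad, (grad_wrt_p * p).sum())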
    def _step(self, group: dict, p: Tensor, state: dict):
        """
        This function does the core update of self.step(), in the case where the members of
        the batch have more than 1 element.
@@ -496,8 +588,7 @@ class ScaledAdam(BatchedOptimizer):
        step = state["step"]
        exp_avg_sq = state["exp_avg_sq"]
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))

        this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0)
        bias_correction2 = 1 - beta2 ** (this_step + 1)
@@ -509,17 +600,13 @@ class ScaledAdam(BatchedOptimizer):
        denom += eps
        grad = grad / denom

        alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms)

        delta = state["delta"]
        delta.add_(grad * alpha)
        p.add_(delta)
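
Condensed, the non-scalar update above is Adam with the step rescaled by the tensor's own rms, which is roughly what keeps the effective learning rate independent of the parameter's overall scale. A rough functional sketch of that effective rule, ignoring momentum and the separate size update (illustrative, not the real implementation):

import torch

def scaled_adam_step_sketch(p, grad, exp_avg_sq, step, lr, beta2, eps, param_min_rms):
    # Adam-style second-moment tracking.
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    bias_correction2 = 1 - beta2 ** (step + 1)
    denom = (exp_avg_sq / bias_correction2).sqrt() + eps
    # The step is proportional to the parameter's rms (floored at param_min_rms).
    param_rms = (p**2).mean().sqrt().clamp(min=param_min_rms)
    p.add_(-lr * param_rms * grad / denom)
    return p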
    def _step_scalar(self, group: dict, p: Tensor, state: dict):
        """
        A simplified form of the core update for scalar tensors, where we cannot get a good
        estimate of the parameter rms.
@@ -531,8 +618,7 @@ class ScaledAdam(BatchedOptimizer):
        grad = p.grad
        exp_avg_sq = state["exp_avg_sq"]  # shape: (batch_size,)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        # bias_correction2 is like in Adam.  Don't bother with bias_correction1;
        # slower update at the start will help stability anyway.
@@ -540,12 +626,11 @@ class ScaledAdam(BatchedOptimizer):
        denom = (exp_avg_sq / bias_correction2).sqrt() + eps

        delta = state["delta"]
        delta.add_(grad / denom, alpha=-lr * (1 - beta1))
        p.clamp_(min=-scalar_max, max=scalar_max)
        p.add_(delta)


class LRScheduler(object):
    """
    Base-class for learning rate schedulers where the learning-rate depends on both the
@@ -555,18 +640,14 @@ class LRScheduler(object):
    def __init__(self, optimizer: Optimizer, verbose: bool = False):
        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__))
        self.optimizer = optimizer
        self.verbose = verbose

        for group in optimizer.param_groups:
            group.setdefault("base_lr", group["lr"])

        self.base_lrs = [group["base_lr"] for group in optimizer.param_groups]

        self.epoch = 0
        self.batch = 0
@@ -680,13 +761,15 @@ class Eden(LRScheduler):
    def get_lr(self):
        factor = (
            (self.batch**2 + self.lr_batches**2) / self.lr_batches**2
        ) ** -0.25 * (
            ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25
        )
        warmup_factor = (
            1.0
            if self.batch >= self.warmup_batches
            else 0.5 + 0.5 * (self.batch / self.warmup_batches)
        )

        return [x * factor * warmup_factor for x in self.base_lrs]
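
As a quick sanity check of the formula (illustrative values, and assuming a default warmup_batches of 500): with lr_batches=200 and lr_epochs=5, the epoch term at epoch 0 is 1.0 and the batch term at batch=200 is 2 ** -0.25 ≈ 0.84, so a base_lr of 0.03 scales to about 0.025 before warmup, and to about 0.0177 once the 0.7 warmup factor at that point is applied.

# Standalone recomputation of Eden's get_lr() (warmup_batches=500.0 is an assumed default).
def eden_lr_sketch(base_lr, batch, epoch, lr_batches=200.0, lr_epochs=5.0,
                   warmup_batches=500.0):
    factor = ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 * (
        (epoch**2 + lr_epochs**2) / lr_epochs**2
    ) ** -0.25
    warmup = 1.0 if batch >= warmup_batches else 0.5 + 0.5 * (batch / warmup_batches)
    return base_lr * factor * warmup

print(eden_lr_sketch(0.03, batch=200, epoch=0))   # ~0.0177 (still in warmup)
print(eden_lr_sketch(0.03, batch=1000, epoch=0))  # ~0.0133 (past warmup)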
@@ -745,13 +828,14 @@ class Eve(Optimizer):
    parameters, if they fall below this we will stop applying weight decay.


    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(
        self,
        params,
@@ -766,17 +850,11 @@ class Eve(Optimizer):
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0 <= weight_decay <= 0.1:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not 0 < target_rms <= 10.0:
            raise ValueError("Invalid target_rms value: {}".format(target_rms))
        defaults = dict(
@@ -812,9 +890,7 @@ class Eve(Optimizer):
                # Perform optimization step
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError("AdamW does not support sparse gradients")

                state = self.state[p]
@@ -841,7 +917,7 @@ class Eve(Optimizer):
                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_(
                    group["eps"]
                )
@@ -852,30 +928,31 @@ class Eve(Optimizer):
                if p.numel() > 1:
                    # avoid applying this weight-decay on "scaling factors"
                    # (which are scalar).
                    is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5))
                    p.mul_(1 - (weight_decay * is_above_target_rms))

                p.addcdiv_(exp_avg, denom, value=-step_size)

                if random.random() < 0.0005:
                    step = (exp_avg / denom) * step_size
                    logging.info(
                        f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}"
                    )

        return loss


def _test_scaled_adam(hidden_dim: int):
    import timeit

    from scaling import ScaledLinear

    E = 100
    B = 4
    T = 2
    logging.info("in test_eve_cain")
    # device = torch.device('cuda')
    device = torch.device("cpu")
    dtype = torch.float32

    fix_random_seed(42)
@@ -889,79 +966,92 @@ def _test_scaled_adam(hidden_dim: int):
        fix_random_seed(42)
        Linear = torch.nn.Linear if iter == 0 else ScaledLinear

        m = torch.nn.Sequential(
            Linear(E, hidden_dim),
            torch.nn.PReLU(),
            Linear(hidden_dim, hidden_dim),
            torch.nn.PReLU(),
            Linear(hidden_dim, E),
        ).to(device)

        train_pairs = [
            (
                100.0
                * torch.randn(B, T, E, device=device, dtype=dtype)
                * input_magnitudes,
                torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes,
            )
            for _ in range(20)
        ]

        if iter == 0:
            optim = Eve(m.parameters(), lr=0.003)
        elif iter == 1:
            optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0)
        scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)

        start = timeit.default_timer()
        avg_loss = 0.0
        for epoch in range(180):
            scheduler.step_epoch()
            # if epoch == 100 and iter in [2,3]:
            #    optim.reset_speedup()  # check it doesn't crash.

            # if epoch == 130:
            #    opts = diagnostics.TensorDiagnosticOptions(
            #        2 ** 22
            #    )  # allow 4 megabytes per sub-module
            #    diagnostic = diagnostics.attach_diagnostics(m, opts)

            for n, (x, y) in enumerate(train_pairs):
                y_out = m(x)
                loss = ((y_out - y) ** 2).mean() * 100.0
                if epoch == 0 and n == 0:
                    avg_loss = loss.item()
                else:
                    avg_loss = 0.98 * avg_loss + 0.02 * loss.item()
                if n == 0 and epoch % 5 == 0:
                    # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
                    # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
                    # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
                    # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item()
                    # scale1 = '%.2e' % (m[0].weight_scale.exp().item())
                    # scale1b = '%.2e' % (m[0].bias_scale.exp().item())
                    # scale2 = '%.2e' % (m[2].weight_scale.exp().item())
                    # scale2b = '%.2e' % (m[2].bias_scale.exp().item())
                    lr = scheduler.get_last_lr()[0]
                    logging.info(
                        f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}"
                    )  # , norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
                loss.log().backward()
                optim.step()
                optim.zero_grad()
                scheduler.step_batch()

        # diagnostic.print_diagnostics()

        stop = timeit.default_timer()
        logging.info(f"Iter={iter}, Time taken: {stop - start}")

        logging.info(f"last lr = {scheduler.get_last_lr()}")
        # logging.info("state dict = ", scheduler.state_dict())
        # logging.info("optim state_dict = ", optim.state_dict())
        logging.info(f"input_magnitudes = {input_magnitudes}")
        logging.info(f"output_magnitudes = {output_magnitudes}")
if __name__ == "__main__":
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    logging.getLogger().setLevel(logging.INFO)
    import subprocess

    s = subprocess.check_output(
        "git status -uno .; git log -1; git diff HEAD .", shell=True
    )
    logging.info(s)
    import sys

    if len(sys.argv) > 1:
        hidden_dim = int(sys.argv[1])
    else:

View File

@@ -90,7 +90,6 @@ LRSchedulerType = Union[
]


def get_adjusted_batch_count(
        params: AttributeDict) -> float:
    # returns the number of batches we would have used so far if we had used the reference
@@ -99,6 +98,7 @@ def get_adjusted_batch_count(
            (params.max_duration * params.world_size))


def set_batch_count(
    model: Union[nn.Module, DDP], batch_count: float
) -> None:
@@ -856,7 +856,6 @@ def train_one_epoch(
    for batch_idx, batch in enumerate(train_dl):
        if batch_idx % 10 == 0:
            set_batch_count(model, get_adjusted_batch_count(params))
        if batch_idx < cur_batch_idx:
            continue
        cur_batch_idx = batch_idx
@@ -879,7 +878,6 @@ def train_one_epoch(
            # NOTE: We use reduction==sum and loss is computed over utterances
            # in the batch and there is no normalization to it so far.
            scaler.scale(loss).backward()
            scheduler.step_batch(params.batch_idx_train)
            scaler.step(optimizer)
@@ -1058,9 +1056,12 @@ def run(rank, world_size, args):
        model = DDP(model, device_ids=[rank],
                    find_unused_parameters=True)

    optimizer = ScaledAdam(
        model.parameters(),
        lr=params.base_lr,
        clipping_scale=2.0,
        parameters_names=[[p[0] for p in model.named_parameters()]],
    )

    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)