Merge debugging changes to optimizer.

Daniel Povey 2022-12-20 13:01:50 +08:00
parent b546ac866c
commit 28cac1c2dc
2 changed files with 333 additions and 242 deletions


@ -14,17 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import List, Optional, Union, Tuple, List
from lhotse.utils import fix_random_seed
import torch
from scaling import ActivationBalancer
import contextlib
import logging
import random
from collections import defaultdict
from typing import List, Optional, Tuple, Union
import torch
from lhotse.utils import fix_random_seed
from scaling import ActivationBalancer
from torch import Tensor
from torch.optim import Optimizer
import logging
import contextlib
class BatchedOptimizer(Optimizer):
@ -37,13 +37,12 @@ class BatchedOptimizer(Optimizer):
Args:
params:
"""
def __init__(self, params, defaults):
super(BatchedOptimizer, self).__init__(params, defaults)
@contextlib.contextmanager
def batched_params(self, param_group):
def batched_params(self, param_group, group_params_names):
"""
This function returns (technically, yields) a list of
tuples (p, state, p_names), where
@ -65,48 +64,64 @@ class BatchedOptimizer(Optimizer):
you can do:
<code>
with self.batched_params(group["params"]) as batches:
for p, state in batches:
for p, state, p_names in batches:
...
</code>
Args:
group: a parameter group, which is a list of parameters; should be
one of self.groups.
one of self.param_groups.
group_params_names: the name of each parameter in the group,
as a List[str].
"""
batches = defaultdict(list) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
batches = defaultdict(
list
) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
batches_names = defaultdict(
list
) # `batches_names` maps from tuple (dtype_as_str,*shape) to list of str
for p in param_group:
assert len(param_group) == len(group_params_names)
for p, named_p in zip(param_group, group_params_names):
key = (str(p.dtype), *p.shape)
batches[key].append(p)
batches_names[key].append(named_p)
batches_names_keys = list(batches_names.keys())
sorted_idx = sorted(
range(len(batches_names)), key=lambda i: batches_names_keys[i]
)
batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx]
batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]
stacked_params_dict = dict()
# turn batches into a list, in deterministic order.
batches = [ batches[key] for key in sorted(batches.keys()) ]
# pairs will contain pairs of (stacked_param, state), one for each batch
# in `batches`.
pairs = []
# tuples will contain tuples of (stacked_param, state, stacked_params_names),
# one for each batch in `batches`.
tuples = []
for batch in batches:
for batch, batch_names in zip(batches, batches_names):
p = batch[0]
# we arbitrarily store the state in the
# state corresponding to the 1st parameter in the
# group. class Optimizer will take care of saving/loading state.
state = self.state[p]
p_stacked = torch.stack(batch)
grad = torch.stack([torch.zeros_like(p) if p.grad is None else p.grad for p in batch ])
grad = torch.stack(
[torch.zeros_like(p) if p.grad is None else p.grad for p in batch]
)
p_stacked.grad = grad
stacked_params_dict[key] = p_stacked
pairs.append((p_stacked, state))
tuples.append((p_stacked, state, batch_names))
yield pairs # <-- calling code will do the actual optimization here!
yield tuples # <-- calling code will do the actual optimization here!
for ((stacked_params, _state), batch) in zip(pairs, batches):
for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
for i, p in enumerate(batch): # batch is list of Parameter
p.copy_(stacked_params[i])
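For orientation, here is a small standalone sketch (with made-up parameter names and shapes) of the grouping that batched_params performs: parameters are keyed by (dtype, *shape) so that identically-shaped tensors can be stacked and stepped as a single batch.
import torch
from collections import defaultdict

# Hypothetical parameters and names, chosen only to illustrate the grouping.
params = [torch.randn(3, 4), torch.randn(3, 4), torch.randn(7), torch.randn(3, 4)]
names = ["enc.w1", "enc.w2", "bias", "dec.w1"]

batches = defaultdict(list)
batches_names = defaultdict(list)
for p, name in zip(params, names):
    key = (str(p.dtype), *p.shape)   # e.g. ('torch.float32', 3, 4)
    batches[key].append(p)
    batches_names[key].append(name)

for key in sorted(batches.keys()):
    stacked = torch.stack(batches[key])  # shape (batch_size, *shape); dim 0 is the stacking dim
    print(key, tuple(stacked.shape), batches_names[key])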
class ScaledAdam(BatchedOptimizer):
"""
Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
@ -149,6 +164,7 @@ class ScaledAdam(BatchedOptimizer):
in the update.
clipping_update_period: if clipping_scale is specified, this is the period (in batches) over which we re-estimate the norm threshold used for clipping
"""
def __init__(
self,
params,
@ -162,9 +178,15 @@ class ScaledAdam(BatchedOptimizer):
scalar_max=10.0,
size_update_period=4,
clipping_update_period=100,
parameters_names=None,
show_dominant_parameters=True,
):
assert parameters_names is not None, (
"Please prepare parameters_names, "
"which is a List[List[str]]. Each List[str] is for a group, "
"and each str is for a parameter."
)
defaults = dict(
lr=lr,
clipping_scale=clipping_scale,
@ -179,11 +201,13 @@ class ScaledAdam(BatchedOptimizer):
)
super(ScaledAdam, self).__init__(params, defaults)
assert len(self.param_groups) == len(parameters_names)
self.parameters_names = parameters_names
self.show_dominant_parameters = show_dominant_parameters
def __setstate__(self, state):
super(ScaledAdam, self).__setstate__(state)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
@ -198,20 +222,23 @@ class ScaledAdam(BatchedOptimizer):
loss = closure()
batch = True
for group in self.param_groups:
with self.batched_params(group["params"]) as batches:
for group, group_params_names in zip(self.param_groups, self.parameters_names):
with self.batched_params(group["params"], group_params_names) as batches:
# batches is a list of tuples (stacked_param, state, param_names). stacked_param is like
# a regular parameter, and will have a .grad, but the 1st dim corresponds to
# a stacking dim; it is not a real dim.
if len(batches[0][1]) == 0: # if len(first state) == 0: not yet initialized
if (
len(batches[0][1]) == 0
): # if len(first state) == 0: not yet initialized
clipping_scale = 1
else:
clipping_scale = self._get_clipping_scale(group, batches)
for p, state in batches:
for p, state, _ in batches:
# Perform optimization step.
# grad is not going to be None, we handled that when creating the batches.
grad = p.grad
@ -225,13 +252,9 @@ class ScaledAdam(BatchedOptimizer):
self._step_one_batch(group, p, state, clipping_scale)
return loss
def _init_state(self,
group: dict,
p: Tensor,
state: dict):
def _init_state(self, group: dict, p: Tensor, state: dict):
"""
Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p
is actually the batch dimension, corresponding to batched-together
@ -247,7 +270,7 @@ class ScaledAdam(BatchedOptimizer):
state["step"] = 0
kwargs = {'device':p.device, 'dtype':p.dtype}
kwargs = {"device": p.device, "dtype": p.dtype}
# 'delta' implements conventional momentum. There are
# several different kinds of update going on, so rather than
@ -255,48 +278,46 @@ class ScaledAdam(BatchedOptimizer):
# parameter-change "delta", which combines all forms of
# update. this is equivalent to how it's done in Adam,
# except for the first few steps.
state["delta"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format)
batch_size = p.shape[0]
numel = p.numel() // batch_size
numel = p.numel()
if numel > 1:
# "param_rms" just periodically records the scalar root-mean-square value of
# the parameter tensor.
# it has a shape like (batch_size, 1, 1, 1, 1)
param_rms = (p**2).mean(dim=list(range(1, p.ndim)),
keepdim=True).sqrt()
param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
state["param_rms"] = param_rms
state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape,
**kwargs)
# exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
state["exp_avg_sq"] = torch.zeros_like(
p, memory_format=torch.preserve_format
state["scale_grads"] = torch.zeros(
size_update_period, *param_rms.shape, **kwargs
)
def _get_clipping_scale(self,
group: dict,
pairs: List[Tuple[Tensor, dict]]) -> float:
# exp_avg_sq is the weighted sum of scaled gradients, as in Adam.
state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
def _get_clipping_scale(
self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]
) -> float:
"""
Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
by this amount before applying the rest of the update.
Args:
group: the parameter group, an item in self.param_groups
pairs: a list of pairs of (param, state) where param is a batched set of parameters, with a .grad
(1st dim is batch dim) and state is the state-dict where optimization parameters are kept.
tuples: a list of tuples of (param, state, param_names)
where param is a batched set of parameters,
with a .grad (1st dim is batch dim)
and state is the state-dict where optimization parameters are kept.
param_names is a List[str], where each str is the name of a parameter
in the batched set of parameters "param".
"""
assert len(pairs) >= 1
assert len(tuples) >= 1
clipping_scale = group["clipping_scale"]
(first_p, first_state) = pairs[0]
(first_p, first_state, _) = tuples[0]
step = first_state["step"]
if clipping_scale is None or step == 0:
# no clipping. return early on step == 0 because the other
@ -305,7 +326,7 @@ class ScaledAdam(BatchedOptimizer):
clipping_update_period = group["clipping_update_period"]
tot_sumsq = torch.tensor(0.0, device=first_p.device)
for (p, state) in pairs:
for (p, state, param_names) in tuples:
grad = p.grad
if grad.is_sparse:
raise RuntimeError(
@ -317,54 +338,128 @@ class ScaledAdam(BatchedOptimizer):
tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()
tot_norm = tot_sumsq.sqrt()
if not "model_norms" in first_state:
first_state["model_norms"] = torch.zeros(clipping_update_period,
device=p.device)
if "model_norms" not in first_state:
first_state["model_norms"] = torch.zeros(
clipping_update_period, device=p.device
)
first_state["model_norms"][step % clipping_update_period] = tot_norm
if step % clipping_update_period == 0:
# Print some stats.
# We don't reach here if step == 0 because we would have returned
# above.
sorted_norms = first_state["model_norms"].sort()[0].to('cpu')
sorted_norms = first_state["model_norms"].sort()[0].to("cpu")
quartiles = []
for n in range(0, 5):
index = min(clipping_update_period - 1,
(clipping_update_period // 4) * n)
index = min(
clipping_update_period - 1, (clipping_update_period // 4) * n
)
quartiles.append(sorted_norms[index].item())
median = quartiles[2]
threshold = clipping_scale * median
first_state["model_norm_threshold"] = threshold
percent_clipped = (first_state["num_clipped"] * 100.0 / clipping_update_period
if "num_clipped" in first_state else 0.0)
percent_clipped = (
first_state["num_clipped"] * 100.0 / clipping_update_period
if "num_clipped" in first_state
else 0.0
)
first_state["num_clipped"] = 0
quartiles = ' '.join([ '%.3e' % x for x in quartiles ])
logging.info(f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}")
quartiles = " ".join(["%.3e" % x for x in quartiles])
logging.info(
f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
)
if step < clipping_update_period:
return 1.0 # We have not yet estimated a norm to clip to.
else:
try:
model_norm_threshold = first_state["model_norm_threshold"]
except:
logging.info("Warning: model_norm_threshold not in state: possibly "
"you changed config when restarting, adding clipping_scale option?")
except KeyError:
logging.info(
"Warning: model_norm_threshold not in state: possibly "
"you changed config when restarting, adding clipping_scale option?"
)
return 1.0
ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
if ans < 1.0:
first_state["num_clipped"] += 1
if ans < 0.1:
logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
logging.warn(
f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
)
if self.show_dominant_parameters:
assert p.shape[0] == len(param_names)
self._show_gradient_dominating_parameter(tuples, tot_sumsq)
return ans
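As a toy recap of the clipping rule above (numbers invented for illustration): the threshold is clipping_scale times the median of the gradient norms recorded over the last clipping_update_period steps, and the returned factor shrinks the update only when the current norm exceeds that threshold.
import torch

clipping_scale = 2.0
recent_norms = torch.tensor([0.8, 0.9, 1.0, 1.1, 1.3])   # hypothetical history of tot_norm
median = recent_norms.sort()[0][len(recent_norms) // 2]  # 1.0
threshold = clipping_scale * median                      # 2.0

tot_norm = torch.tensor(5.0)                             # an unusually large gradient norm
ans = min(1.0, (threshold / (tot_norm + 1.0e-20)).item())
print(ans)  # 0.4: gradients are scaled down before the update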
def _show_gradient_dominating_parameter(
self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor
):
"""
Show information about the parameter which dominates tot_sumsq.
def _step_one_batch(self,
group: dict,
p: Tensor,
state: dict,
clipping_scale: float):
Args:
tuples: a list of tuples of (param, state, param_names)
where param is a batched set of parameters,
with a .grad (1st dim is batch dim)
and state is the state-dict where optimization parameters are kept.
param_names is a List[str], where each str is the name of a parameter
in the batched set of parameters "param".
tot_sumsq: sumsq of all parameters. Though it could be calculated
from tuples, we still pass it in to save some time.
"""
all_sumsq_orig = {}
for (p, state, batch_param_names) in tuples:
# p is a stacked batch of parameters.
batch_grad = p.grad
if p.numel() == p.shape[0]: # a batch of scalars
batch_sumsq_orig = batch_grad**2
# Dummy values used by the following `zip` statement.
batch_rms_orig = torch.ones(p.shape[0])
else:
batch_rms_orig = state["param_rms"]
batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum(
dim=list(range(1, batch_grad.ndim))
)
for name, sumsq_orig, rms, grad in zip(
batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
):
proportion_orig = sumsq_orig / tot_sumsq
all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
assert torch.isclose(
sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
torch.tensor(1.0),
)
sorted_by_proportion = {
k: v
for k, v in sorted(
all_sumsq_orig.items(), key=lambda item: item[1][0], reverse=True
)
}
dominant_param_name = next(iter(sorted_by_proportion))
(
dominant_proportion,
dominant_sumsq,
dominant_rms,
dominant_grad,
) = sorted_by_proportion[dominant_param_name]
logging.info(
f"Parameter Dominanting tot_sumsq {dominant_param_name}"
f" with proportion {dominant_proportion:.2f},"
f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
f"={dominant_sumsq:.3e},"
f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
f" orig_rms_sq={(dominant_rms**2).item():.3e}"
)
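A minimal illustration (invented names and contributions) of the bookkeeping in _show_gradient_dominating_parameter: each batched parameter contributes ((grad * param_rms) ** 2).sum() to tot_sumsq, and the parameter holding the largest share is reported as dominating.
import torch

contrib = {  # hypothetical per-parameter sumsq contributions
    "encoder.layer0.weight": torch.tensor(9.0),
    "decoder.weight": torch.tensor(0.5),
    "joiner.bias": torch.tensor(0.5),
}
tot_sumsq = sum(contrib.values())                         # 10.0
proportions = {k: (v / tot_sumsq).item() for k, v in contrib.items()}
dominant = max(proportions, key=proportions.get)
print(dominant, proportions[dominant])                    # encoder.layer0.weight 0.9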
def _step_one_batch(
self, group: dict, p: Tensor, state: dict, clipping_scale: float
):
"""
Do the step for one parameter, which is actually going to be a batch of
`real` parameters, with dim 0 as the batch dim.
@ -391,17 +486,18 @@ class ScaledAdam(BatchedOptimizer):
# Update the size/scale of p, and set param_rms
scale_grads = state["scale_grads"]
scale_grads[step % size_update_period] = (p * grad).sum(
dim=list(range(1, p.ndim)), keepdim=True)
dim=list(range(1, p.ndim)), keepdim=True
)
if step % size_update_period == size_update_period - 1:
param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..)
param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)),
keepdim=True).sqrt())
param_rms.copy_(
(p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
)
if step > 0:
# self._size_update() learns the overall scale on the
# parameter, by shrinking or expanding it.
self._size_update(group, scale_grads, p, state)
if numel == 1:
# For parameters with 1 element we just use regular Adam.
# Updates delta.
@ -411,12 +507,9 @@ class ScaledAdam(BatchedOptimizer):
state["step"] = step + 1
def _size_update(self,
group: dict,
scale_grads: Tensor,
p: Tensor,
state: dict) -> None:
def _size_update(
self, group: dict, scale_grads: Tensor, p: Tensor, state: dict
) -> None:
"""
Called only where p.numel() > 1, this updates the scale of the parameter.
If we imagine: p = underlying_param * scale.exp(), and we are doing
@ -448,7 +541,8 @@ class ScaledAdam(BatchedOptimizer):
scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..)
scale_exp_avg_sq.mul_(beta2_corr).add_(
(scale_grads**2).mean(dim=0), # mean over dim `size_update_period`
alpha=1-beta2_corr) # shape is (batch_size, 1, 1, ...)
alpha=1 - beta2_corr,
) # shape is (batch_size, 1, 1, ...)
# The 1st time we reach here is when size_step == 1.
size_step = (step + 1) // size_update_period
@ -458,10 +552,12 @@ class ScaledAdam(BatchedOptimizer):
denom = scale_exp_avg_sq.sqrt() + eps
scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum(dim=0) / denom
scale_step = (
-size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom
)
is_too_small = (param_rms < param_min_rms)
is_too_large = (param_rms > param_max_rms)
is_too_small = param_rms < param_min_rms
is_too_large = param_rms > param_max_rms
# when the param gets too small, just don't shrink it any further.
scale_step.masked_fill_(is_too_small, 0.0)
@ -471,11 +567,7 @@ class ScaledAdam(BatchedOptimizer):
# the factor of (1-beta1) relates to momentum.
delta.add_(p * scale_step, alpha=(1 - beta1))
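A quick autograd check, offered as a toy sketch rather than part of this diff, of the identity the size update relies on: if p = underlying * scale.exp(), then the gradient of the loss with respect to scale is (p * dloss/dp).sum(), which is exactly what scale_grads accumulates.
import torch

underlying = torch.randn(4, 5)
scale = torch.zeros((), requires_grad=True)
p = underlying * scale.exp()
loss = (p ** 3).sum()                      # arbitrary differentiable loss
loss.backward()

p_detached = underlying * scale.exp().detach()
dloss_dp = 3 * p_detached ** 2             # analytic dloss/dp for this toy loss
manual = (p_detached * dloss_dp).sum()     # (p * grad).sum(), as in scale_grads
assert torch.allclose(scale.grad, manual)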
def _step(self,
group: dict,
p: Tensor,
state: dict):
def _step(self, group: dict, p: Tensor, state: dict):
"""
This function does the core update of self.step(), in the case where the members of
the batch have more than 1 element.
@ -496,8 +588,7 @@ class ScaledAdam(BatchedOptimizer):
step = state["step"]
exp_avg_sq = state["exp_avg_sq"]
exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
value=(1-beta2))
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))
this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0)
bias_correction2 = 1 - beta2 ** (this_step + 1)
@ -515,11 +606,7 @@ class ScaledAdam(BatchedOptimizer):
delta.add_(grad * alpha)
p.add_(delta)
def _step_scalar(self,
group: dict,
p: Tensor,
state: dict):
def _step_scalar(self, group: dict, p: Tensor, state: dict):
"""
A simplified form of the core update for scalar tensors, where we cannot get a good
estimate of the parameter rms.
@ -531,8 +618,7 @@ class ScaledAdam(BatchedOptimizer):
grad = p.grad
exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
value=1-beta2)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
# bias_correction2 is like in Adam. Don't bother with bias_correction1;
# slower update at the start will help stability anyway.
@ -545,7 +631,6 @@ class ScaledAdam(BatchedOptimizer):
p.add_(delta)
class LRScheduler(object):
"""
Base-class for learning rate schedulers where the learning-rate depends on both the
@ -555,18 +640,14 @@ class LRScheduler(object):
def __init__(self, optimizer: Optimizer, verbose: bool = False):
# Attach optimizer
if not isinstance(optimizer, Optimizer):
raise TypeError(
"{} is not an Optimizer".format(type(optimizer).__name__)
)
raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__))
self.optimizer = optimizer
self.verbose = verbose
for group in optimizer.param_groups:
group.setdefault("base_lr", group["lr"])
self.base_lrs = [
group["base_lr"] for group in optimizer.param_groups
]
self.base_lrs = [group["base_lr"] for group in optimizer.param_groups]
self.epoch = 0
self.batch = 0
@ -682,11 +763,13 @@ class Eden(LRScheduler):
factor = (
(self.batch**2 + self.lr_batches**2) / self.lr_batches**2
) ** -0.25 * (
((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2)
** -0.25
((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25
)
warmup_factor = (
1.0
if self.batch >= self.warmup_batches
else 0.5 + 0.5 * (self.batch / self.warmup_batches)
)
warmup_factor = (1.0 if self.batch >= self.warmup_batches
else 0.5 + 0.5 * (self.batch / self.warmup_batches))
return [x * factor * warmup_factor for x in self.base_lrs]
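A worked example of the Eden factor just returned, using illustrative values rather than the recipe's actual defaults: base_lr=0.05, lr_batches=5000, lr_epochs=6, warmup_batches=500, evaluated at batch 10000 of epoch 3.
base_lr, lr_batches, lr_epochs, warmup_batches = 0.05, 5000.0, 6.0, 500.0
batch, epoch = 10000.0, 3.0

factor = ((batch ** 2 + lr_batches ** 2) / lr_batches ** 2) ** -0.25 * (
    (epoch ** 2 + lr_epochs ** 2) / lr_epochs ** 2
) ** -0.25                                                 # ~0.669 * ~0.946
warmup_factor = 1.0 if batch >= warmup_batches else 0.5 + 0.5 * (batch / warmup_batches)
lr = base_lr * factor * warmup_factor                      # ~0.032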
@ -745,13 +828,14 @@ class Eve(Optimizer):
parameters, if they fall below this we will stop applying weight decay.
.. _Adam\: A Method for Stochastic Optimization:
.. _Adam: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _Decoupled Weight Decay Regularization:
https://arxiv.org/abs/1711.05101
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
def __init__(
self,
params,
@ -766,17 +850,11 @@ class Eve(Optimizer):
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError(
"Invalid beta parameter at index 0: {}".format(betas[0])
)
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError(
"Invalid beta parameter at index 1: {}".format(betas[1])
)
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
if not 0 <= weight_decay <= 0.1:
raise ValueError(
"Invalid weight_decay value: {}".format(weight_decay)
)
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
if not 0 < target_rms <= 10.0:
raise ValueError("Invalid target_rms value: {}".format(target_rms))
defaults = dict(
@ -812,9 +890,7 @@ class Eve(Optimizer):
# Perform optimization step
grad = p.grad
if grad.is_sparse:
raise RuntimeError(
"AdamW does not support sparse gradients"
)
raise RuntimeError("AdamW does not support sparse gradients")
state = self.state[p]
@ -852,30 +928,31 @@ class Eve(Optimizer):
if p.numel() > 1:
# avoid applying this weight-decay on "scaling factors"
# (which are scalar).
is_above_target_rms = p.norm() > (
target_rms * (p.numel() ** 0.5)
)
is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5))
p.mul_(1 - (weight_decay * is_above_target_rms))
p.addcdiv_(exp_avg, denom, value=-step_size)
if random.random() < 0.0005:
step = (exp_avg / denom) * step_size
logging.info(f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}")
logging.info(
f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}"
)
return loss
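A small sanity check (toy tensor, not from this diff) of the condition guarding Eve's weight decay above: p.norm() > target_rms * sqrt(p.numel()) is equivalent to saying that the root-mean-square value of p exceeds target_rms.
import torch

target_rms = 0.1
p = 0.2 * torch.randn(50, 20)                  # rms around 0.2, above the target
is_above = p.norm() > (target_rms * (p.numel() ** 0.5))
same_thing = (p ** 2).mean().sqrt() > target_rms
assert bool(is_above) == bool(same_thing)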
def _test_scaled_adam(hidden_dim: int):
import timeit
from scaling import ScaledLinear
E = 100
B = 4
T = 2
logging.info("in test_scaled_adam")
# device = torch.device('cuda')
device = torch.device('cpu')
device = torch.device("cpu")
dtype = torch.float32
fix_random_seed(42)
@ -889,21 +966,30 @@ def _test_scaled_adam(hidden_dim: int):
fix_random_seed(42)
Linear = torch.nn.Linear if iter == 0 else ScaledLinear
m = torch.nn.Sequential(Linear(E, hidden_dim),
m = torch.nn.Sequential(
Linear(E, hidden_dim),
torch.nn.PReLU(),
Linear(hidden_dim, hidden_dim),
torch.nn.PReLU(),
Linear(hidden_dim, E),
).to(device)
train_pairs = [ (100.0 * torch.randn(B, T, E, device=device, dtype=dtype) * input_magnitudes,
torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes) for _ in range(20) ]
train_pairs = [
(
100.0
* torch.randn(B, T, E, device=device, dtype=dtype)
* input_magnitudes,
torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes,
)
for _ in range(20)
]
if iter == 0: optim = Eve(m.parameters(), lr=0.003)
elif iter == 1: optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0)
if iter == 0:
optim = Eve(m.parameters(), lr=0.003)
elif iter == 1:
optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0)
scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)
start = timeit.default_timer()
avg_loss = 0.0
for epoch in range(180):
@ -917,7 +1003,6 @@ def _test_scaled_adam(hidden_dim: int):
# ) # allow 4 megabytes per sub-module
# diagnostic = diagnostics.attach_diagnostics(m, opts)
for n, (x, y) in enumerate(train_pairs):
y_out = m(x)
loss = ((y_out - y) ** 2).mean() * 100.0
@ -935,7 +1020,9 @@ def _test_scaled_adam(hidden_dim: int):
# scale2 = '%.2e' % (m[2].weight_scale.exp().item())
# scale2b = '%.2e' % (m[2].bias_scale.exp().item())
lr = scheduler.get_last_lr()[0]
logging.info(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}") #, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
logging.info(
f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}"
) # , norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
loss.log().backward()
optim.step()
optim.zero_grad()
@ -953,15 +1040,18 @@ def _test_scaled_adam(hidden_dim: int):
logging.info(f"output_magnitudes = {output_magnitudes}")
if __name__ == "__main__":
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
logging.getLogger().setLevel(logging.INFO)
import subprocess
s = subprocess.check_output("git status -uno .; git log -1; git diff HEAD .", shell=True)
s = subprocess.check_output(
"git status -uno .; git log -1; git diff HEAD .", shell=True
)
logging.info(s)
import sys
if len(sys.argv) > 1:
hidden_dim = int(sys.argv[1])
else:


@ -90,7 +90,6 @@ LRSchedulerType = Union[
]
def get_adjusted_batch_count(
params: AttributeDict) -> float:
# returns the number of batches we would have used so far if we had used the reference
@ -99,6 +98,7 @@ def get_adjusted_batch_count(
(params.max_duration * params.world_size))
def set_batch_count(
model: Union[nn.Module, DDP], batch_count: float
) -> None:
@ -856,7 +856,6 @@ def train_one_epoch(
for batch_idx, batch in enumerate(train_dl):
if batch_idx % 10 == 0:
set_batch_count(model, get_adjusted_batch_count(params))
if batch_idx < cur_batch_idx:
continue
cur_batch_idx = batch_idx
@ -879,7 +878,6 @@ def train_one_epoch(
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
scaler.scale(loss).backward()
scheduler.step_batch(params.batch_idx_train)
scaler.step(optimizer)
@ -1058,9 +1056,12 @@ def run(rank, world_size, args):
model = DDP(model, device_ids=[rank],
find_unused_parameters=True)
optimizer = ScaledAdam(model.parameters(),
optimizer = ScaledAdam(
model.parameters(),
lr=params.base_lr,
clipping_scale=2.0)
clipping_scale=2.0,
parameters_names=[ [p[0] for p in model.named_parameters()] ],
)
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
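For reference, a sketch of how parameters_names must mirror the parameter groups: one List[str] per group, ordered like the parameters in that group. The module path and the toy model below are assumptions for illustration only.
import torch.nn as nn
from optim import ScaledAdam  # assumed module name for the optimizer file in this diff

model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
names = [name for name, _ in model.named_parameters()]   # ['0.weight', '0.bias', ...]
params = [p for _, p in model.named_parameters()]

optimizer = ScaledAdam(params, lr=0.03, clipping_scale=2.0, parameters_names=[names])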