Merge branch 'k2-fsa:master' into dev/k2ssl

Commit e8fa10e53a
@@ -121,6 +121,139 @@ class BatchedOptimizer(Optimizer):
                 p.copy_(stacked_params[i])


+def basic_step(group, p, state, grad):
+    # computes basic Adam update using beta2 (dividing by gradient stddev) only.  no momentum yet.
+    lr = group["lr"]
+    if p.numel() == p.shape[0]:
+        lr = lr * group["scalar_lr_scale"]
+    beta2 = group["betas"][1]
+    eps = group["eps"]
+    # p shape: (batch_size,) or (batch_size, 1, [1,..])
+    try:
+        exp_avg_sq = state[
+            "exp_avg_sq"
+        ]  # shape: (batch_size,) or (batch_size, 1, [1,..])
+    except KeyError:
+        exp_avg_sq = torch.zeros(*p.shape, device=p.device, dtype=torch.float)
+        state["exp_avg_sq"] = exp_avg_sq
+
+    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+
+    # bias_correction2 is like in Adam.
+    # slower update at the start will help stability anyway.
+    bias_correction2 = 1 - beta2 ** (state["step"] + 1)
+    if bias_correction2 < 0.99:
+        # note: not in-place.
+        exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
+    denom = exp_avg_sq.sqrt().add_(eps)
+
+    return -lr * grad / denom
+
+
+def scaling_step(group, p, state, grad):
+    delta = basic_step(group, p, state, grad)
+    if p.numel() == p.shape[0]:
+        return delta  # there is no scaling for scalar parameters.  (p.shape[0] is the batch of parameters.)
+
+    step = state["step"]
+    size_update_period = group["size_update_period"]
+
+    try:
+        param_rms = state["param_rms"]
+        scale_grads = state["scale_grads"]
+        scale_exp_avg_sq = state["scale_exp_avg_sq"]
+    except KeyError:
+        # we know p.ndim > 1 because we'd have returned above if not, so don't worry
+        # about the special case of dim=[] that pytorch treats inconsistently.
+        param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
+        param_rms = param_rms.to(torch.float)
+        scale_exp_avg_sq = torch.zeros_like(param_rms)
+        scale_grads = torch.zeros(
+            size_update_period, *param_rms.shape, dtype=torch.float, device=p.device
+        )
+        state["param_rms"] = param_rms
+        state["scale_grads"] = scale_grads
+        state["scale_exp_avg_sq"] = scale_exp_avg_sq
+
+    # on every step, update the gradient w.r.t. the scale of the parameter, we
+    # store these as a batch and periodically update the size (for speed only, to
+    # avoid too many operations).
+    scale_grads[step % size_update_period] = (p * grad).sum(
+        dim=list(range(1, p.ndim)), keepdim=True
+    )
+
+    # periodically recompute the value of param_rms.
+    if step % size_update_period == size_update_period - 1:
+        param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt())
+
+    param_min_rms = group["param_min_rms"]
+
+    # scale the step size by param_rms.  This is the most important "scaling" part of
+    # ScaledAdam
+    delta *= param_rms.clamp(min=param_min_rms)
+
+    if step % size_update_period == size_update_period - 1 and step > 0:
+        # This block updates the size of parameter by adding a step ("delta") value in
+        # the direction of either shrinking or growing it.
+        beta2 = group["betas"][1]
+        size_lr = group["lr"] * group["scalar_lr_scale"]
+        param_max_rms = group["param_max_rms"]
+        eps = group["eps"]
+        batch_size = p.shape[0]
+        # correct beta2 for the size update period: we will have
+        # faster decay at this level.
+        beta2_corr = beta2**size_update_period
+        scale_exp_avg_sq.mul_(beta2_corr).add_(
+            (scale_grads**2).mean(dim=0),  # mean over dim `size_update_period`
+            alpha=1 - beta2_corr,
+        )  # shape is (batch_size, 1, 1, ...)
+
+        # The 1st time we reach here is when size_step == 1.
+        size_step = (step + 1) // size_update_period
+        bias_correction2 = 1 - beta2_corr**size_step
+
+        denom = scale_exp_avg_sq.sqrt() + eps
+
+        scale_step = (
+            -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom
+        )
+
+        is_too_small = param_rms < param_min_rms
+
+        # when the param gets too small, just don't shrink it any further.
+        scale_step.masked_fill_(is_too_small, 0.0)
+
+        # The following may help prevent instability: don't allow the scale step to be too large in
+        # either direction.
+        scale_step.clamp_(min=-0.1, max=0.1)
+
+        # and ensure the parameter rms after update never exceeds param_max_rms.
+        # We have to look at the trained model for parameters at or around the
+        # param_max_rms, because sometimes they can indicate a problem with the
+        # topology or settings.
+        scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms)
+
+        delta.add_(p * scale_step)
+
+    return delta
+
+
+def momentum_step(group, p, state, grad):
+    delta = scaling_step(group, p, state, grad)
+    beta1 = group["betas"][0]
+    try:
+        stored_delta = state["delta"]
+    except KeyError:
+        stored_delta = torch.zeros(*p.shape, device=p.device, dtype=torch.float)
+        state["delta"] = stored_delta
+    stored_delta.mul_(beta1)
+    stored_delta.add_(delta, alpha=(1 - beta1))
+    # we don't bother doing the "bias correction" part of Adam for beta1 because this is just
+    # an edge effect that affects the first 10 or so batches; and the effect of not doing it
+    # is just to do a slower update for the first few batches, which will help stability.
+    return stored_delta
+
+
 class ScaledAdam(BatchedOptimizer):
     """
     Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
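Below is a minimal, self-contained sketch (not part of the commit) of how the three new helpers compose into one update. The group settings, tensor shapes, and the direct call to momentum_step are made up for illustration; in the diff these functions are driven from ScaledAdam.step().

import torch

group = {
    "lr": 0.04, "betas": (0.9, 0.98), "eps": 1e-8, "scalar_lr_scale": 0.1,
    "size_update_period": 4, "param_min_rms": 1e-5, "param_max_rms": 3.0,
}
p = torch.randn(8, 256)            # a "batch" of 8 stacked parameters, 256 elements each
grad = 0.1 * torch.randn_like(p)   # stand-in for p.grad
state = {"step": 0}

# basic_step returns the Adam-normalized step, scaling_step multiplies it by the
# parameter rms (and learns an overall scale), momentum_step smooths it with beta1.
delta = momentum_step(group, p, state, grad)
p += delta
state["step"] += 1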
@@ -352,58 +485,26 @@ class ScaledAdam(BatchedOptimizer):
                         raise RuntimeError(
                             "ScaledAdam optimizer does not support sparse gradients"
                         )
-                    # State initialization
-                    if len(state) == 0:
-                        self._init_state(group, p, state)
-
-                    self._step_one_batch(group, p, state, clipping_scale)
+                    try:
+                        cur_step = state["step"]
+                    except KeyError:
+                        state["step"] = 0
+                        cur_step = 0
+
+                    grad = (
+                        p.grad if clipping_scale == 1.0 else p.grad.mul_(clipping_scale)
+                    )
+                    p += momentum_step(group, p.detach(), state, grad)
+
+                    if p.numel() == p.shape[0]:  # scalar parameter
+                        scalar_max = group["scalar_max"]
+                        p.clamp_(min=-scalar_max, max=scalar_max)
+
+                    state["step"] = cur_step + 1

         return loss

-    def _init_state(self, group: dict, p: Tensor, state: dict):
-        """
-        Initializes state dict for parameter 'p'.  Assumes that dim 0 of tensor p
-        is actually the batch dimension, corresponding to batched-together
-        parameters of a given shape.
-
-        Args:
-            group: Dict to look up configuration values.
-            p: The parameter that we are initializing the state for
-            state: Dict from string to whatever state we are initializing
-        """
-        size_update_period = group["size_update_period"]
-
-        state["step"] = 0
-
-        kwargs = {"device": p.device, "dtype": p.dtype}
-
-        # 'delta' implements conventional momentum.  There are
-        # several different kinds of update going on, so rather than
-        # compute "exp_avg" like in Adam, we store and decay a
-        # parameter-change "delta", which combines all forms of
-        # update.  this is equivalent to how it's done in Adam,
-        # except for the first few steps.
-        state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format)
-
-        batch_size = p.shape[0]
-        numel = p.numel() // batch_size
-
-        if numel > 1:
-            # "param_rms" just periodically records the scalar root-mean-square value of
-            # the parameter tensor.
-            # it has a shape like (batch_size, 1, 1, 1, 1)
-            param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
-            state["param_rms"] = param_rms
-
-            state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
-            state["scale_grads"] = torch.zeros(
-                size_update_period, *param_rms.shape, **kwargs
-            )
-
-        # exp_avg_sq is the weighted sum of scaled gradients.  as in Adam.
-        state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
-
     def _get_clipping_scale(
         self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]
     ) -> float:
@@ -484,7 +585,7 @@ class ScaledAdam(BatchedOptimizer):
             )
             first_state["num_clipped"] = 0
             quartiles = " ".join(["%.3e" % x for x in quartiles])
-            logging.warn(
+            logging.warning(
                 f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
                 f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
             )
@@ -499,8 +600,8 @@ class ScaledAdam(BatchedOptimizer):
                 ans = 0.0
             if ans < 1.0:
                 first_state["num_clipped"] += 1
-            if ans < 0.1:
-                logging.warn(
+            if ans < 0.5:
+                logging.warning(
                     f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
                 )
                 if self.show_dominant_parameters:
@@ -508,6 +609,7 @@ class ScaledAdam(BatchedOptimizer):
                     self._show_gradient_dominating_parameter(
                         tuples, tot_sumsq, group["scalar_lr_scale"]
                     )
+                self._show_param_with_unusual_grad(tuples)

         if ans == 0.0:
             for (p, state, param_names) in tuples:
@@ -515,6 +617,55 @@ class ScaledAdam(BatchedOptimizer):

         return ans

+    def _show_param_with_unusual_grad(
+        self,
+        tuples: List[Tuple[Tensor, dict, List[str]]],
+    ):
+        """
+        Print information about parameter which has the largest ratio of grad-on-this-batch
+        divided by normal grad size.
+          tuples: a list of tuples of (param, state, param_names)
+            where param is a batched set of parameters,
+            with a .grad (1st dim is batch dim)
+            and state is the state-dict where optimization parameters are kept.
+            param_names is a List[str] while each str is name for a parameter
+            in batched set of parameters "param".
+        """
+        largest_ratio = 0.0
+        largest_name = ""
+        # ratios_names is a list of 3-tuples: (grad_ratio, param_name, tensor)
+        ratios_names = []
+        for (p, state, batch_param_names) in tuples:
+            dims = list(range(1, p.ndim))
+
+            def mean(x):
+                # workaround for bad interface of torch's "mean" for when dims is the empty list.
+                if len(dims) > 0:
+                    return x.mean(dim=dims)
+                else:
+                    return x
+
+            grad_ratio = (
+                (mean(p.grad**2) / state["exp_avg_sq"].mean(dim=dims))
+                .sqrt()
+                .to("cpu")
+            )
+
+            ratios_names += zip(
+                grad_ratio.tolist(), batch_param_names, p.grad.unbind(dim=0)
+            )
+
+        ratios_names = sorted(ratios_names, reverse=True)
+        ratios_names = ratios_names[:10]
+        ratios_names = [
+            (ratio, name, largest_index(tensor))
+            for (ratio, name, tensor) in ratios_names
+        ]
+
+        logging.warning(
+            f"Parameters with most larger-than-usual grads, with ratios, are: {ratios_names}"
+        )
+
     def _show_gradient_dominating_parameter(
         self,
         tuples: List[Tuple[Tensor, dict, List[str]]],
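A hypothetical illustration (not from the diff) of the grad-ratio that _show_param_with_unusual_grad reports, recomputed by hand for one batched parameter with made-up values:

import torch

p = torch.randn(2, 4, requires_grad=True)        # batch of 2 parameters, 4 elements each
p.grad = torch.randn(2, 4)
state = {"exp_avg_sq": torch.full((2, 4), 0.01)}
dims = [1]
grad_ratio = ((p.grad**2).mean(dim=dims) / state["exp_avg_sq"].mean(dim=dims)).sqrt()
print(grad_ratio)  # one ratio per parameter in the batch; large values flag unusually big grads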
@@ -572,7 +723,7 @@ class ScaledAdam(BatchedOptimizer):
             dominant_rms,
             dominant_grad,
         ) = sorted_by_proportion[dominant_param_name]
-        logging.warn(
+        logging.warning(
             f"Parameter dominating tot_sumsq {dominant_param_name}"
             f" with proportion {dominant_proportion:.2f},"
             f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
@@ -581,182 +732,11 @@ class ScaledAdam(BatchedOptimizer):
             f" orig_rms_sq={(dominant_rms**2).item():.3e}"
         )

-    def _step_one_batch(
-        self, group: dict, p: Tensor, state: dict, clipping_scale: float
-    ):
-        """
-        Do the step for one parameter, which is actually going to be a batch of
-        `real` parameters, with dim 0 as the batch dim.
-        Args:
-            group: dict to look up configuration values
-            p: parameter to update (actually multiple parameters stacked together
-                as a batch)
-            state: state-dict for p, to look up the optimizer state
-        """
-        lr = group["lr"]
-        size_update_period = group["size_update_period"]
-        beta1 = group["betas"][0]
-
-        grad = p.grad
-        if clipping_scale != 1.0:
-            grad *= clipping_scale
-        step = state["step"]
-        delta = state["delta"]
-
-        delta.mul_(beta1)
-        batch_size = p.shape[0]
-        numel = p.numel() // batch_size
-        if numel > 1:
-            # Update the size/scale of p, and set param_rms
-            scale_grads = state["scale_grads"]
-            scale_grads[step % size_update_period] = (p * grad).sum(
-                dim=list(range(1, p.ndim)), keepdim=True
-            )
-            if step % size_update_period == size_update_period - 1:
-                param_rms = state["param_rms"]  # shape: (batch_size, 1, 1, ..)
-                param_rms.copy_(
-                    (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
-                )
-                if step > 0:
-                    # self._size_update() learns the overall scale on the
-                    # parameter, by shrinking or expanding it.
-                    self._size_update(group, scale_grads, p, state)
-
-        if numel == 1:
-            # For parameters with 1 element we just use regular Adam.
-            # Updates delta.
-            self._step_scalar(group, p, state)
-        else:
-            self._step(group, p, state)
-
-        state["step"] = step + 1
-
-    def _size_update(
-        self, group: dict, scale_grads: Tensor, p: Tensor, state: dict
-    ) -> None:
-        """
-        Called only where p.numel() > 1, this updates the scale of the parameter.
-        If we imagine: p = underlying_param * scale.exp(), and we are doing
-        gradient descent on underlying param and on scale, this function does the update
-        on `scale`.
-
-        Args:
-            group: dict to look up configuration values
-            scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing
-                grads w.r.t. the scales.
-            p: The parameter to update
-            state: The state-dict of p
-        """
-
-        param_rms = state["param_rms"]
-        beta1, beta2 = group["betas"]
-        size_lr = group["lr"] * group["scalar_lr_scale"]
-        param_min_rms = group["param_min_rms"]
-        param_max_rms = group["param_max_rms"]
-        eps = group["eps"]
-        step = state["step"]
-        batch_size = p.shape[0]
-
-        size_update_period = scale_grads.shape[0]
-        # correct beta2 for the size update period: we will have
-        # faster decay at this level.
-        beta2_corr = beta2**size_update_period
-
-        scale_exp_avg_sq = state["scale_exp_avg_sq"]  # shape: (batch_size, 1, 1, ..)
-        scale_exp_avg_sq.mul_(beta2_corr).add_(
-            (scale_grads**2).mean(dim=0),  # mean over dim `size_update_period`
-            alpha=1 - beta2_corr,
-        )  # shape is (batch_size, 1, 1, ...)
-
-        # The 1st time we reach here is when size_step == 1.
-        size_step = (step + 1) // size_update_period
-        bias_correction2 = 1 - beta2_corr**size_step
-        # we don't bother with bias_correction1; this will help prevent divergence
-        # at the start of training.
-
-        denom = scale_exp_avg_sq.sqrt() + eps
-
-        scale_step = (
-            -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom
-        )
-
-        is_too_small = param_rms < param_min_rms
-
-        # when the param gets too small, just don't shrink it any further.
-        scale_step.masked_fill_(is_too_small, 0.0)
-
-        # and ensure the parameter rms after update never exceeds param_max_rms.
-        # We have to look at the trained model for parameters at or around the
-        # param_max_rms, because sometimes they can indicate a problem with the
-        # topology or settings.
-        scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms)
-
-        delta = state["delta"]
-        # the factor of (1-beta1) relates to momentum.
-        delta.add_(p * scale_step, alpha=(1 - beta1))
-
-    def _step(self, group: dict, p: Tensor, state: dict):
-        """
-        This function does the core update of self.step(), in the case where the members of
-        the batch have more than 1 element.
-
-        Args:
-            group: A dict which will be used to look up configuration values
-            p: The parameter to be updated
-            grad: The grad of p
-            state: The state-dict corresponding to parameter p
-
-        This function modifies p.
-        """
-        grad = p.grad
-        lr = group["lr"]
-        beta1, beta2 = group["betas"]
-        eps = group["eps"]
-        param_min_rms = group["param_min_rms"]
-        step = state["step"]
-
-        exp_avg_sq = state["exp_avg_sq"]
-        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))
-
-        this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0)
-        bias_correction2 = 1 - beta2 ** (this_step + 1)
-        if bias_correction2 < 0.99:
-            # note: not in-place.
-            exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
-
-        denom = exp_avg_sq.sqrt()
-        denom += eps
-        grad = grad / denom
-
-        alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms)
-
-        delta = state["delta"]
-        delta.add_(grad * alpha)
-        p.add_(delta)
-
-    def _step_scalar(self, group: dict, p: Tensor, state: dict):
-        """
-        A simplified form of the core update for scalar tensors, where we cannot get a good
-        estimate of the parameter rms.
-        """
-        beta1, beta2 = group["betas"]
-        scalar_max = group["scalar_max"]
-        eps = group["eps"]
-        lr = group["lr"] * group["scalar_lr_scale"]
-        grad = p.grad
-
-        exp_avg_sq = state["exp_avg_sq"]  # shape: (batch_size,)
-        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
-
-        # bias_correction2 is like in Adam.  Don't bother with bias_correction1;
-        # slower update at the start will help stability anyway.
-        bias_correction2 = 1 - beta2 ** (state["step"] + 1)
-        denom = (exp_avg_sq / bias_correction2).sqrt() + eps
-
-        delta = state["delta"]
-        delta.add_(grad / denom, alpha=-lr * (1 - beta1))
-        p.clamp_(min=-scalar_max, max=scalar_max)
-        p.add_(delta)


+def largest_index(x: Tensor):
+    x = x.contiguous()
+    argmax = x.abs().argmax().item()
+    return [(argmax // x.stride(i)) % x.size(i) for i in range(x.ndim)]
+
+
 class LRScheduler(object):
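A quick, hypothetical check of the new largest_index helper defined above: it maps the flat argmax of |x| back to a multi-dimensional index via the contiguous strides.

import torch

x = torch.tensor([[0.1, -2.0], [0.5, 1.5]])
# |x| peaks at flat position 1; with strides (2, 1) that is index [0, 1].
assert largest_index(x) == [0, 1]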
@@ -787,9 +767,9 @@ class LRScheduler(object):
         is not the optimizer.
         """
         return {
             # the user might try to override the base_lr, so don't include this in the state.
             # previously they were included.
             # "base_lrs": self.base_lrs,
             "epoch": self.epoch,
             "batch": self.batch,
         }
@@ -807,7 +787,6 @@ class LRScheduler(object):
         self.__dict__.update(state_dict)
         self.base_lrs = base_lrs

-
     def get_last_lr(self) -> List[float]:
         """Return last computed learning rate by current scheduler. Will be a list of float."""
         return self._last_lr
@@ -853,7 +832,7 @@ class LRScheduler(object):
     def print_lr(self, is_verbose, group, lr):
         """Display the current learning rate."""
         if is_verbose:
-            logging.warn(
+            logging.warning(
                 f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
                 f" of group {group} to {lr:.4e}."
             )
@@ -1184,7 +1163,7 @@ def _test_scaled_adam(hidden_dim: int):
         if iter == 0:
             optim = Eve(m.parameters(), lr=0.003)
         elif iter == 1:
-            optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0)
+            optim = ScaledAdam(m.named_parameters(), lr=0.03, clipping_scale=2.0)
         scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)

         start = timeit.default_timer()
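A hedged usage sketch matching the updated test above: ScaledAdam is now constructed from named_parameters() so the optimizer can report per-parameter diagnostics by name. The model and loss here are made-up placeholders.

import torch

model = torch.nn.Linear(256, 256)
optim = ScaledAdam(model.named_parameters(), lr=0.03, clipping_scale=2.0)

loss = model(torch.randn(4, 256)).pow(2).mean()
loss.backward()
optim.step()
optim.zero_grad()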
@@ -10,9 +10,9 @@ stop_stage=4

 dl_dir=$PWD/download

-dataset_parts="Premium" # Basic for all 10k hours data, Premium for about 10% of the data
+dataset_parts="Premium" # Basic for all 7226 hours data, Premium for 945 hours subset.

 text_extractor="pypinyin_initials_finals" # default is espeak for English
 audio_extractor="Encodec" # or Fbank
 audio_feats_dir=data/tokenized

@@ -63,7 +63,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
     --audio-extractor ${audio_extractor} \
     --batch-duration 2500 --prefix "wenetspeech4tts" \
     --src-dir "data/manifests" \
     --split 100 \
     --output-dir "${audio_feats_dir}/wenetspeech4tts_${dataset_parts}_split_100"
   cp ${audio_feats_dir}/wenetspeech4tts_${dataset_parts}_split_100/unique_text_tokens.k2symbols ${audio_feats_dir}
 fi
@@ -63,12 +63,22 @@ def get_tensor_stats(
             "rms" -> square before summing, we'll take sqrt later
             "value" -> just sum x itself
             "max", "min" -> take the maximum or minimum [over all other dims but dim] instead of summing
+            "rms-sort" -> this is a bit different than the others, it's based on computing the
+                rms over the specified dim and returning percentiles of the result (11 of them).
     Returns:
         stats: a Tensor of shape (x.shape[dim],).
         count: an integer saying how many items were counted in each element
            of stats.
     """

+    if stats_type == "rms-sort":
+        rms = (x**2).mean(dim=dim).sqrt()
+        rms = rms.flatten()
+        rms = rms.sort()[0]
+        rms = rms[(torch.arange(11) * rms.numel() // 10).clamp(max=rms.numel() - 1)]
+        count = 1.0
+        return rms, count
+
     count = x.numel() // x.shape[dim]

     if stats_type == "eigs":
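A small, hypothetical illustration (not part of the diff) of what the new "rms-sort" branch returns: 11 evenly spaced order statistics of the sorted per-slice rms.

import torch

x = torch.randn(100, 7)
rms = (x**2).mean(dim=0).sqrt().flatten().sort()[0]           # per-column rms, sorted
idx = (torch.arange(11) * rms.numel() // 10).clamp(max=rms.numel() - 1)
print(rms[idx])  # roughly the 0th, 10th, ..., 100th percentile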
@@ -164,7 +174,17 @@ class TensorDiagnostic(object):
         for dim in range(ndim):
             this_dim_stats = self.stats[dim]
             if ndim > 1:
-                stats_types = ["abs", "max", "min", "positive", "value", "rms"]
+                # rms-sort is different from the others, it's based on summing over just this
+                # dim, then sorting and returning the percentiles.
+                stats_types = [
+                    "abs",
+                    "max",
+                    "min",
+                    "positive",
+                    "value",
+                    "rms",
+                    "rms-sort",
+                ]
                 if x.shape[dim] <= self.opts.max_eig_dim:
                     stats_types.append("eigs")
             else: