mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Merge branch 'k2-fsa:master' into dev/k2ssl
This commit is contained in:
commit
e8fa10e53a
@ -121,6 +121,139 @@ class BatchedOptimizer(Optimizer):
|
||||
p.copy_(stacked_params[i])
|
||||
|
||||
|
||||
def basic_step(group, p, state, grad):
|
||||
# computes basic Adam update using beta2 (dividing by gradient stddev) only. no momentum yet.
|
||||
lr = group["lr"]
|
||||
if p.numel() == p.shape[0]:
|
||||
lr = lr * group["scalar_lr_scale"]
|
||||
beta2 = group["betas"][1]
|
||||
eps = group["eps"]
|
||||
# p shape: (batch_size,) or (batch_size, 1, [1,..])
|
||||
try:
|
||||
exp_avg_sq = state[
|
||||
"exp_avg_sq"
|
||||
] # shape: (batch_size,) or (batch_size, 1, [1,..])
|
||||
except KeyError:
|
||||
exp_avg_sq = torch.zeros(*p.shape, device=p.device, dtype=torch.float)
|
||||
state["exp_avg_sq"] = exp_avg_sq
|
||||
|
||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
|
||||
|
||||
# bias_correction2 is like in Adam.
|
||||
# slower update at the start will help stability anyway.
|
||||
bias_correction2 = 1 - beta2 ** (state["step"] + 1)
|
||||
if bias_correction2 < 0.99:
|
||||
# note: not in-place.
|
||||
exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
|
||||
denom = exp_avg_sq.sqrt().add_(eps)
|
||||
|
||||
return -lr * grad / denom
|
||||
|
||||
|
||||
def scaling_step(group, p, state, grad):
|
||||
delta = basic_step(group, p, state, grad)
|
||||
if p.numel() == p.shape[0]:
|
||||
return delta # there is no scaling for scalar parameters. (p.shape[0] is the batch of parameters.)
|
||||
|
||||
step = state["step"]
|
||||
size_update_period = group["size_update_period"]
|
||||
|
||||
try:
|
||||
param_rms = state["param_rms"]
|
||||
scale_grads = state["scale_grads"]
|
||||
scale_exp_avg_sq = state["scale_exp_avg_sq"]
|
||||
except KeyError:
|
||||
# we know p.ndim > 1 because we'd have returned above if not, so don't worry
|
||||
# about the speial case of dim=[] that pytorch treats inconsistently.
|
||||
param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
|
||||
param_rms = param_rms.to(torch.float)
|
||||
scale_exp_avg_sq = torch.zeros_like(param_rms)
|
||||
scale_grads = torch.zeros(
|
||||
size_update_period, *param_rms.shape, dtype=torch.float, device=p.device
|
||||
)
|
||||
state["param_rms"] = param_rms
|
||||
state["scale_grads"] = scale_grads
|
||||
state["scale_exp_avg_sq"] = scale_exp_avg_sq
|
||||
|
||||
# on every step, update the gradient w.r.t. the scale of the parameter, we
|
||||
# store these as a batch and periodically update the size (for speed only, to
|
||||
# avoid too many operations).
|
||||
scale_grads[step % size_update_period] = (p * grad).sum(
|
||||
dim=list(range(1, p.ndim)), keepdim=True
|
||||
)
|
||||
|
||||
# periodically recompute the value of param_rms.
|
||||
if step % size_update_period == size_update_period - 1:
|
||||
param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt())
|
||||
|
||||
param_min_rms = group["param_min_rms"]
|
||||
|
||||
# scale the step size by param_rms. This is the most important "scaling" part of
|
||||
# ScaledAdam
|
||||
delta *= param_rms.clamp(min=param_min_rms)
|
||||
|
||||
if step % size_update_period == size_update_period - 1 and step > 0:
|
||||
# This block updates the size of parameter by adding a step ("delta") value in
|
||||
# the direction of either shrinking or growing it.
|
||||
beta2 = group["betas"][1]
|
||||
size_lr = group["lr"] * group["scalar_lr_scale"]
|
||||
param_max_rms = group["param_max_rms"]
|
||||
eps = group["eps"]
|
||||
batch_size = p.shape[0]
|
||||
# correct beta2 for the size update period: we will have
|
||||
# faster decay at this level.
|
||||
beta2_corr = beta2**size_update_period
|
||||
scale_exp_avg_sq.mul_(beta2_corr).add_(
|
||||
(scale_grads**2).mean(dim=0), # mean over dim `size_update_period`
|
||||
alpha=1 - beta2_corr,
|
||||
) # shape is (batch_size, 1, 1, ...)
|
||||
|
||||
# The 1st time we reach here is when size_step == 1.
|
||||
size_step = (step + 1) // size_update_period
|
||||
bias_correction2 = 1 - beta2_corr**size_step
|
||||
|
||||
denom = scale_exp_avg_sq.sqrt() + eps
|
||||
|
||||
scale_step = (
|
||||
-size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom
|
||||
)
|
||||
|
||||
is_too_small = param_rms < param_min_rms
|
||||
|
||||
# when the param gets too small, just don't shrink it any further.
|
||||
scale_step.masked_fill_(is_too_small, 0.0)
|
||||
|
||||
# The following may help prevent instability: don't allow the scale step to be too large in
|
||||
# either direction.
|
||||
scale_step.clamp_(min=-0.1, max=0.1)
|
||||
|
||||
# and ensure the parameter rms after update never exceeds param_max_rms.
|
||||
# We have to look at the trained model for parameters at or around the
|
||||
# param_max_rms, because sometimes they can indicate a problem with the
|
||||
# topology or settings.
|
||||
scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms)
|
||||
|
||||
delta.add_(p * scale_step)
|
||||
|
||||
return delta
|
||||
|
||||
|
||||
def momentum_step(group, p, state, grad):
|
||||
delta = scaling_step(group, p, state, grad)
|
||||
beta1 = group["betas"][0]
|
||||
try:
|
||||
stored_delta = state["delta"]
|
||||
except KeyError:
|
||||
stored_delta = torch.zeros(*p.shape, device=p.device, dtype=torch.float)
|
||||
state["delta"] = stored_delta
|
||||
stored_delta.mul_(beta1)
|
||||
stored_delta.add_(delta, alpha=(1 - beta1))
|
||||
# we don't bother doing the "bias correction" part of Adam for beta1 because this is just
|
||||
# an edge effect that affects the first 10 or so batches; and the effect of not doing it
|
||||
# is just to do a slower update for the first few batches, which will help stability.
|
||||
return stored_delta
|
||||
|
||||
|
||||
class ScaledAdam(BatchedOptimizer):
|
||||
"""
|
||||
Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
|
||||
@ -352,58 +485,26 @@ class ScaledAdam(BatchedOptimizer):
|
||||
raise RuntimeError(
|
||||
"ScaledAdam optimizer does not support sparse gradients"
|
||||
)
|
||||
# State initialization
|
||||
if len(state) == 0:
|
||||
self._init_state(group, p, state)
|
||||
|
||||
self._step_one_batch(group, p, state, clipping_scale)
|
||||
try:
|
||||
cur_step = state["step"]
|
||||
except KeyError:
|
||||
state["step"] = 0
|
||||
cur_step = 0
|
||||
|
||||
grad = (
|
||||
p.grad if clipping_scale == 1.0 else p.grad.mul_(clipping_scale)
|
||||
)
|
||||
p += momentum_step(group, p.detach(), state, grad)
|
||||
|
||||
if p.numel() == p.shape[0]: # scalar parameter
|
||||
scalar_max = group["scalar_max"]
|
||||
p.clamp_(min=-scalar_max, max=scalar_max)
|
||||
|
||||
state["step"] = cur_step + 1
|
||||
|
||||
return loss
|
||||
|
||||
def _init_state(self, group: dict, p: Tensor, state: dict):
|
||||
"""
|
||||
Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p
|
||||
is actually the batch dimension, corresponding to batched-together
|
||||
parameters of a given shape.
|
||||
|
||||
|
||||
Args:
|
||||
group: Dict to look up configuration values.
|
||||
p: The parameter that we are initializing the state for
|
||||
state: Dict from string to whatever state we are initializing
|
||||
"""
|
||||
size_update_period = group["size_update_period"]
|
||||
|
||||
state["step"] = 0
|
||||
|
||||
kwargs = {"device": p.device, "dtype": p.dtype}
|
||||
|
||||
# 'delta' implements conventional momentum. There are
|
||||
# several different kinds of update going on, so rather than
|
||||
# compute "exp_avg" like in Adam, we store and decay a
|
||||
# parameter-change "delta", which combines all forms of
|
||||
# update. this is equivalent to how it's done in Adam,
|
||||
# except for the first few steps.
|
||||
state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
||||
|
||||
batch_size = p.shape[0]
|
||||
numel = p.numel() // batch_size
|
||||
|
||||
if numel > 1:
|
||||
# "param_rms" just periodically records the scalar root-mean-square value of
|
||||
# the parameter tensor.
|
||||
# it has a shape like (batch_size, 1, 1, 1, 1)
|
||||
param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
|
||||
state["param_rms"] = param_rms
|
||||
|
||||
state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
|
||||
state["scale_grads"] = torch.zeros(
|
||||
size_update_period, *param_rms.shape, **kwargs
|
||||
)
|
||||
|
||||
# exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
|
||||
state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
||||
|
||||
def _get_clipping_scale(
|
||||
self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]
|
||||
) -> float:
|
||||
@ -484,7 +585,7 @@ class ScaledAdam(BatchedOptimizer):
|
||||
)
|
||||
first_state["num_clipped"] = 0
|
||||
quartiles = " ".join(["%.3e" % x for x in quartiles])
|
||||
logging.warn(
|
||||
logging.warning(
|
||||
f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
|
||||
f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
|
||||
)
|
||||
@ -499,8 +600,8 @@ class ScaledAdam(BatchedOptimizer):
|
||||
ans = 0.0
|
||||
if ans < 1.0:
|
||||
first_state["num_clipped"] += 1
|
||||
if ans < 0.1:
|
||||
logging.warn(
|
||||
if ans < 0.5:
|
||||
logging.warning(
|
||||
f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
|
||||
)
|
||||
if self.show_dominant_parameters:
|
||||
@ -508,6 +609,7 @@ class ScaledAdam(BatchedOptimizer):
|
||||
self._show_gradient_dominating_parameter(
|
||||
tuples, tot_sumsq, group["scalar_lr_scale"]
|
||||
)
|
||||
self._show_param_with_unusual_grad(tuples)
|
||||
|
||||
if ans == 0.0:
|
||||
for (p, state, param_names) in tuples:
|
||||
@ -515,6 +617,55 @@ class ScaledAdam(BatchedOptimizer):
|
||||
|
||||
return ans
|
||||
|
||||
def _show_param_with_unusual_grad(
|
||||
self,
|
||||
tuples: List[Tuple[Tensor, dict, List[str]]],
|
||||
):
|
||||
"""
|
||||
Print information about parameter which has the largest ratio of grad-on-this-batch
|
||||
divided by normal grad size.
|
||||
tuples: a list of tuples of (param, state, param_names)
|
||||
where param is a batched set of parameters,
|
||||
with a .grad (1st dim is batch dim)
|
||||
and state is the state-dict where optimization parameters are kept.
|
||||
param_names is a List[str] while each str is name for a parameter
|
||||
in batched set of parameters "param".
|
||||
"""
|
||||
largest_ratio = 0.0
|
||||
largest_name = ""
|
||||
# ratios_names is a list of 3-tuples: (grad_ratio, param_name, tensor)
|
||||
ratios_names = []
|
||||
for (p, state, batch_param_names) in tuples:
|
||||
dims = list(range(1, p.ndim))
|
||||
|
||||
def mean(x):
|
||||
# workaround for bad interface of torch's "mean" for when dims is the empty list.
|
||||
if len(dims) > 0:
|
||||
return x.mean(dim=dims)
|
||||
else:
|
||||
return x
|
||||
|
||||
grad_ratio = (
|
||||
(mean(p.grad**2) / state["exp_avg_sq"].mean(dim=dims))
|
||||
.sqrt()
|
||||
.to("cpu")
|
||||
)
|
||||
|
||||
ratios_names += zip(
|
||||
grad_ratio.tolist(), batch_param_names, p.grad.unbind(dim=0)
|
||||
)
|
||||
|
||||
ratios_names = sorted(ratios_names, reverse=True)
|
||||
ratios_names = ratios_names[:10]
|
||||
ratios_names = [
|
||||
(ratio, name, largest_index(tensor))
|
||||
for (ratio, name, tensor) in ratios_names
|
||||
]
|
||||
|
||||
logging.warning(
|
||||
f"Parameters with most larger-than-usual grads, with ratios, are: {ratios_names}"
|
||||
)
|
||||
|
||||
def _show_gradient_dominating_parameter(
|
||||
self,
|
||||
tuples: List[Tuple[Tensor, dict, List[str]]],
|
||||
@ -572,7 +723,7 @@ class ScaledAdam(BatchedOptimizer):
|
||||
dominant_rms,
|
||||
dominant_grad,
|
||||
) = sorted_by_proportion[dominant_param_name]
|
||||
logging.warn(
|
||||
logging.warning(
|
||||
f"Parameter dominating tot_sumsq {dominant_param_name}"
|
||||
f" with proportion {dominant_proportion:.2f},"
|
||||
f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
|
||||
@ -581,182 +732,11 @@ class ScaledAdam(BatchedOptimizer):
|
||||
f" orig_rms_sq={(dominant_rms**2).item():.3e}"
|
||||
)
|
||||
|
||||
def _step_one_batch(
|
||||
self, group: dict, p: Tensor, state: dict, clipping_scale: float
|
||||
):
|
||||
"""
|
||||
Do the step for one parameter, which is actually going to be a batch of
|
||||
`real` parameters, with dim 0 as the batch dim.
|
||||
Args:
|
||||
group: dict to look up configuration values
|
||||
p: parameter to update (actually multiple parameters stacked together
|
||||
as a batch)
|
||||
state: state-dict for p, to look up the optimizer state
|
||||
"""
|
||||
lr = group["lr"]
|
||||
size_update_period = group["size_update_period"]
|
||||
beta1 = group["betas"][0]
|
||||
|
||||
grad = p.grad
|
||||
if clipping_scale != 1.0:
|
||||
grad *= clipping_scale
|
||||
step = state["step"]
|
||||
delta = state["delta"]
|
||||
|
||||
delta.mul_(beta1)
|
||||
batch_size = p.shape[0]
|
||||
numel = p.numel() // batch_size
|
||||
if numel > 1:
|
||||
# Update the size/scale of p, and set param_rms
|
||||
scale_grads = state["scale_grads"]
|
||||
scale_grads[step % size_update_period] = (p * grad).sum(
|
||||
dim=list(range(1, p.ndim)), keepdim=True
|
||||
)
|
||||
if step % size_update_period == size_update_period - 1:
|
||||
param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..)
|
||||
param_rms.copy_(
|
||||
(p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
|
||||
)
|
||||
if step > 0:
|
||||
# self._size_update() learns the overall scale on the
|
||||
# parameter, by shrinking or expanding it.
|
||||
self._size_update(group, scale_grads, p, state)
|
||||
|
||||
if numel == 1:
|
||||
# For parameters with 1 element we just use regular Adam.
|
||||
# Updates delta.
|
||||
self._step_scalar(group, p, state)
|
||||
else:
|
||||
self._step(group, p, state)
|
||||
|
||||
state["step"] = step + 1
|
||||
|
||||
def _size_update(
|
||||
self, group: dict, scale_grads: Tensor, p: Tensor, state: dict
|
||||
) -> None:
|
||||
"""
|
||||
Called only where p.numel() > 1, this updates the scale of the parameter.
|
||||
If we imagine: p = underlying_param * scale.exp(), and we are doing
|
||||
gradient descent on underlying param and on scale, this function does the update
|
||||
on `scale`.
|
||||
|
||||
Args:
|
||||
group: dict to look up configuration values
|
||||
scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing
|
||||
grads w.r.t. the scales.
|
||||
p: The parameter to update
|
||||
state: The state-dict of p
|
||||
"""
|
||||
|
||||
param_rms = state["param_rms"]
|
||||
beta1, beta2 = group["betas"]
|
||||
size_lr = group["lr"] * group["scalar_lr_scale"]
|
||||
param_min_rms = group["param_min_rms"]
|
||||
param_max_rms = group["param_max_rms"]
|
||||
eps = group["eps"]
|
||||
step = state["step"]
|
||||
batch_size = p.shape[0]
|
||||
|
||||
size_update_period = scale_grads.shape[0]
|
||||
# correct beta2 for the size update period: we will have
|
||||
# faster decay at this level.
|
||||
beta2_corr = beta2**size_update_period
|
||||
|
||||
scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..)
|
||||
scale_exp_avg_sq.mul_(beta2_corr).add_(
|
||||
(scale_grads**2).mean(dim=0), # mean over dim `size_update_period`
|
||||
alpha=1 - beta2_corr,
|
||||
) # shape is (batch_size, 1, 1, ...)
|
||||
|
||||
# The 1st time we reach here is when size_step == 1.
|
||||
size_step = (step + 1) // size_update_period
|
||||
bias_correction2 = 1 - beta2_corr**size_step
|
||||
# we don't bother with bias_correction1; this will help prevent divergence
|
||||
# at the start of training.
|
||||
|
||||
denom = scale_exp_avg_sq.sqrt() + eps
|
||||
|
||||
scale_step = (
|
||||
-size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom
|
||||
)
|
||||
|
||||
is_too_small = param_rms < param_min_rms
|
||||
|
||||
# when the param gets too small, just don't shrink it any further.
|
||||
scale_step.masked_fill_(is_too_small, 0.0)
|
||||
|
||||
# and ensure the parameter rms after update never exceeds param_max_rms.
|
||||
# We have to look at the trained model for parameters at or around the
|
||||
# param_max_rms, because sometimes they can indicate a problem with the
|
||||
# topology or settings.
|
||||
scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms)
|
||||
|
||||
delta = state["delta"]
|
||||
# the factor of (1-beta1) relates to momentum.
|
||||
delta.add_(p * scale_step, alpha=(1 - beta1))
|
||||
|
||||
def _step(self, group: dict, p: Tensor, state: dict):
|
||||
"""
|
||||
This function does the core update of self.step(), in the case where the members of
|
||||
the batch have more than 1 element.
|
||||
|
||||
Args:
|
||||
group: A dict which will be used to look up configuration values
|
||||
p: The parameter to be updated
|
||||
grad: The grad of p
|
||||
state: The state-dict corresponding to parameter p
|
||||
|
||||
This function modifies p.
|
||||
"""
|
||||
grad = p.grad
|
||||
lr = group["lr"]
|
||||
beta1, beta2 = group["betas"]
|
||||
eps = group["eps"]
|
||||
param_min_rms = group["param_min_rms"]
|
||||
step = state["step"]
|
||||
|
||||
exp_avg_sq = state["exp_avg_sq"]
|
||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))
|
||||
|
||||
this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0)
|
||||
bias_correction2 = 1 - beta2 ** (this_step + 1)
|
||||
if bias_correction2 < 0.99:
|
||||
# note: not in-place.
|
||||
exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
|
||||
|
||||
denom = exp_avg_sq.sqrt()
|
||||
denom += eps
|
||||
grad = grad / denom
|
||||
|
||||
alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms)
|
||||
|
||||
delta = state["delta"]
|
||||
delta.add_(grad * alpha)
|
||||
p.add_(delta)
|
||||
|
||||
def _step_scalar(self, group: dict, p: Tensor, state: dict):
|
||||
"""
|
||||
A simplified form of the core update for scalar tensors, where we cannot get a good
|
||||
estimate of the parameter rms.
|
||||
"""
|
||||
beta1, beta2 = group["betas"]
|
||||
scalar_max = group["scalar_max"]
|
||||
eps = group["eps"]
|
||||
lr = group["lr"] * group["scalar_lr_scale"]
|
||||
grad = p.grad
|
||||
|
||||
exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,)
|
||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
|
||||
|
||||
# bias_correction2 is like in Adam. Don't bother with bias_correction1;
|
||||
# slower update at the start will help stability anyway.
|
||||
bias_correction2 = 1 - beta2 ** (state["step"] + 1)
|
||||
denom = (exp_avg_sq / bias_correction2).sqrt() + eps
|
||||
|
||||
delta = state["delta"]
|
||||
delta.add_(grad / denom, alpha=-lr * (1 - beta1))
|
||||
p.clamp_(min=-scalar_max, max=scalar_max)
|
||||
p.add_(delta)
|
||||
def largest_index(x: Tensor):
|
||||
x = x.contiguous()
|
||||
argmax = x.abs().argmax().item()
|
||||
return [(argmax // x.stride(i)) % x.size(i) for i in range(x.ndim)]
|
||||
|
||||
|
||||
class LRScheduler(object):
|
||||
@ -787,9 +767,9 @@ class LRScheduler(object):
|
||||
is not the optimizer.
|
||||
"""
|
||||
return {
|
||||
# the user might try to override the base_lr, so don't include this in the state.
|
||||
# previously they were included.
|
||||
# "base_lrs": self.base_lrs,
|
||||
# the user might try to override the base_lr, so don't include this in the state.
|
||||
# previously they were included.
|
||||
# "base_lrs": self.base_lrs,
|
||||
"epoch": self.epoch,
|
||||
"batch": self.batch,
|
||||
}
|
||||
@ -807,7 +787,6 @@ class LRScheduler(object):
|
||||
self.__dict__.update(state_dict)
|
||||
self.base_lrs = base_lrs
|
||||
|
||||
|
||||
def get_last_lr(self) -> List[float]:
|
||||
"""Return last computed learning rate by current scheduler. Will be a list of float."""
|
||||
return self._last_lr
|
||||
@ -853,7 +832,7 @@ class LRScheduler(object):
|
||||
def print_lr(self, is_verbose, group, lr):
|
||||
"""Display the current learning rate."""
|
||||
if is_verbose:
|
||||
logging.warn(
|
||||
logging.warning(
|
||||
f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
|
||||
f" of group {group} to {lr:.4e}."
|
||||
)
|
||||
@ -1184,7 +1163,7 @@ def _test_scaled_adam(hidden_dim: int):
|
||||
if iter == 0:
|
||||
optim = Eve(m.parameters(), lr=0.003)
|
||||
elif iter == 1:
|
||||
optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0)
|
||||
optim = ScaledAdam(m.named_parameters(), lr=0.03, clipping_scale=2.0)
|
||||
scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)
|
||||
|
||||
start = timeit.default_timer()
|
||||
|
@ -10,9 +10,9 @@ stop_stage=4
|
||||
|
||||
dl_dir=$PWD/download
|
||||
|
||||
dataset_parts="Premium" # Basic for all 10k hours data, Premium for about 10% of the data
|
||||
dataset_parts="Premium" # Basic for all 7226 hours data, Premium for 945 hours subset.
|
||||
|
||||
text_extractor="pypinyin_initials_finals" # default is espeak for English
|
||||
text_extractor="pypinyin_initials_finals" # default is espeak for English
|
||||
audio_extractor="Encodec" # or Fbank
|
||||
audio_feats_dir=data/tokenized
|
||||
|
||||
@ -63,7 +63,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
--audio-extractor ${audio_extractor} \
|
||||
--batch-duration 2500 --prefix "wenetspeech4tts" \
|
||||
--src-dir "data/manifests" \
|
||||
--split 100 \
|
||||
--split 100 \
|
||||
--output-dir "${audio_feats_dir}/wenetspeech4tts_${dataset_parts}_split_100"
|
||||
cp ${audio_feats_dir}/wenetspeech4tts_${dataset_parts}_split_100/unique_text_tokens.k2symbols ${audio_feats_dir}
|
||||
fi
|
||||
|
@ -63,12 +63,22 @@ def get_tensor_stats(
|
||||
"rms" -> square before summing, we'll take sqrt later
|
||||
"value" -> just sum x itself
|
||||
"max", "min" -> take the maximum or minimum [over all other dims but dim] instead of summing
|
||||
"rms-sort" -> this is a bit different than the others, it's based on computing the
|
||||
rms over the specified dim and returning percentiles of the result (11 of them).
|
||||
Returns:
|
||||
stats: a Tensor of shape (x.shape[dim],).
|
||||
count: an integer saying how many items were counted in each element
|
||||
of stats.
|
||||
"""
|
||||
|
||||
if stats_type == "rms-sort":
|
||||
rms = (x**2).mean(dim=dim).sqrt()
|
||||
rms = rms.flatten()
|
||||
rms = rms.sort()[0]
|
||||
rms = rms[(torch.arange(11) * rms.numel() // 10).clamp(max=rms.numel() - 1)]
|
||||
count = 1.0
|
||||
return rms, count
|
||||
|
||||
count = x.numel() // x.shape[dim]
|
||||
|
||||
if stats_type == "eigs":
|
||||
@ -164,7 +174,17 @@ class TensorDiagnostic(object):
|
||||
for dim in range(ndim):
|
||||
this_dim_stats = self.stats[dim]
|
||||
if ndim > 1:
|
||||
stats_types = ["abs", "max", "min", "positive", "value", "rms"]
|
||||
# rms-sort is different from the others, it's based on summing over just this
|
||||
# dim, then sorting and returning the percentiles.
|
||||
stats_types = [
|
||||
"abs",
|
||||
"max",
|
||||
"min",
|
||||
"positive",
|
||||
"value",
|
||||
"rms",
|
||||
"rms-sort",
|
||||
]
|
||||
if x.shape[dim] <= self.opts.max_eig_dim:
|
||||
stats_types.append("eigs")
|
||||
else:
|
||||
|
Loading…
x
Reference in New Issue
Block a user