mirror of https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00

Committing partial work..

This commit is contained in:
  parent d25df4af5e
  commit 27da50a1f6
@@ -119,8 +119,10 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                 loss = closure()

         for group in self.param_groups:
-            eps = group["eps"]
-            size_update_period = group["size_update_period"]
+            do_batch = False
+
+            if len([p for p in group["params"] if p.grad is None]) != 0:
+                self._init_group_state(group)

             for p in group["params"]:
                 if p.grad is None:
@@ -136,77 +138,136 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of

                 # State initialization
                 if len(state) == 0:
-                    state["step"] = 0
-
-                    kwargs = {'device':p.device, 'dtype':p.dtype}
-
-                    # 'delta' implements conventional momentum. There are
-                    # several different kinds of update going on, so rather than
-                    # compute "exp_avg" like in Adam, we store and decay a
-                    # parameter-change "delta", which combines all forms of
-                    # update. this is equivalent to how it's done in Adam,
-                    # except for the first few steps.
-                    state["delta"] = torch.zeros_like(
-                        p, memory_format=torch.preserve_format
-                    )
-
-                    if p.numel() > 1:
-                        # "param_rms" just periodically records the scalar root-mean-square value of
-                        # the parameter tensor.
-                        param_rms = (p**2).mean().sqrt().add_(eps)
-                        state["param_rms"] = param_rms
-
-                        state["scale_exp_avg_sq"] = torch.zeros((), **kwargs)
-                        # scalar exp_avg_sq. not the same as scale_exp_avg_sq, this is used to
-                        # determine the magnitude of the regular step
-                        state["scalar_exp_avg_sq"] = torch.zeros((), **kwargs)
-                        state["scale_grads"] = torch.zeros(size_update_period,
-                                                           **kwargs)
-
-                        # exp_avg_sq is in the projected space. It is periodically zeroed, and we
-                        # set zero_step.
-                        state["exp_avg_sq"] = torch.zeros_like(
-                            p, memory_format=torch.preserve_format
-                        )
-                        state["zero_step"] = 0
-
-                    for dim in range(p.ndim):
-                        size = p.shape[dim]
-                        if size == 1 or size == p.numel():
-                            continue
-
-                        # Q_{dim} is the learning-rate matrix for this
-                        # dimension, a matrix indexed [diagonalized_coordinate, canonical_coordinate].
-                        # this is used twice in the update step, once transposed and once without transpose.
-                        state[f"Q_{dim}"] = torch.eye(size, **kwargs)
-
-                        # param_cov_{dim} is the averaged-over-time gradient of parameters on this dimension, treating
-                        # all other dims as a batch axis.
-                        state[f"param_cov_{dim}"] = torch.zeros(size, size, **kwargs)
-
-                        # grad_cov_{dim} is the covariance of gradients on this axis (without
-                        # any co-ordinate changes), treating all other axes as a batch axis.
-                        # This is needed when we re-diagonalize, and to compute the
-                        # grad_rms_{dim}. We store it as a decaying average, decaying with beta2
-
-                        # only for purposes of computing the scalar factor on the
-                        # learning rate; and, as a result of this, also contributes something
-                        # to the gradient w.r.t. f"Q_{dim}", which is one reason
-                        # why we allocate a variable to keep track of its moving average
-                        # instead of just using a temporary and smoothing the scalar factor.
-                        state[f"grad_cov_{dim}"] = torch.zeros(size, size, **kwargs)
-
-                self._step_one_param(group, p, state)
+                    self._init_state(group, p, state)
+                if not do_batch:
+                    self._step_one_param(group, p, state, batch=False)
+                if do_batch:
+                    batches = self._get_batches(group)

         return loss

+
+    def _get_batches(self,
+                     group: dict):
+        """
+        Arrange the parameters in the group into batches, and return a "fake" version of
+        group["params"] that contains the parameters stacked into batches, arranged by
+        size.
+        """
+        is_initialized = False
+        if "batches" not in group:
+            self._init_batches(group)
+
+
+    def _init_batches(self,
+                      group: dict):
+        """
+        Arrange the parameters in this parameter group by dtype and shape, and group
+        them into batches; this involves reallocating the tensors that are members
+        of self.state[p] for each param, so that there can be an underlying group
+        """
+
+        batches = []
+        for (dtype, shape), params in itertools.groupby(group["params"],
+                                                        lambda p: (p.dtype, p.shape)):
+            # (dtype,shape) is the "key" of the grouping, "params" is the list
+            # of parameters that share the dtype and shape. group["params"] is
+            # a list of all the parameters in this group.
+            #param_batches =
+
+
+    def _init_group_state(self,
+                          group: dict,
+                          batch: bool):
+        """
+        Initializes state dict for parameter group `group`. Expects all elements of the state
+        to be uninitialized, i.e. you cannot have gradient for some parameters now and other
+        parameters later, it must be all or nothing. This is also a requirement for DDP,
+        so it's not much inconvenience.
+        Args:
+          group: Parameter group to initialize state for. A list of nn.Parameter.
+          batch: if True, we will arrange parameters into batches for more efficient
+                 optimization.
+        """
+
+        eps = group["eps"]
+        size_update_period = group["size_update_period"]
+
+        for p in group["params"]:
+            assert p.grad is not None, "All parameters must have grads, or none."
+            grad = p.grad
+            if grad.is_sparse:
+                raise RuntimeError(
+                    "PrAdam optimizer does not support sparse gradients"
+                )
+            state = self.state[p]
+
+            state["step"] = 0
+
+            kwargs = {'device':p.device, 'dtype':p.dtype}
+
+            # 'delta' implements conventional momentum. There are
+            # several different kinds of update going on, so rather than
+            # compute "exp_avg" like in Adam, we store and decay a
+            # parameter-change "delta", which combines all forms of
+            # update. this is equivalent to how it's done in Adam,
+            # except for the first few steps.
+            state["delta"] = torch.zeros_like(
+                p, memory_format=torch.preserve_format
+            )
+
+            if p.numel() > 1:
+                # "param_rms" just periodically records the scalar root-mean-square value of
+                # the parameter tensor.
+                param_rms = (p**2).mean().sqrt().add_(eps)
+                state["param_rms"] = param_rms
+
+                state["scale_exp_avg_sq"] = torch.zeros((), **kwargs)
+                # scalar exp_avg_sq. not the same as scale_exp_avg_sq, this is used to
+                # determine the magnitude of the regular step
+                state["scalar_exp_avg_sq"] = torch.zeros((), **kwargs)
+                state["scale_grads"] = torch.zeros(size_update_period,
+                                                   **kwargs)
+
+                # exp_avg_sq is in the projected space. It is periodically zeroed, and we
+                # set zero_step.
+                state["exp_avg_sq"] = torch.zeros_like(
+                    p, memory_format=torch.preserve_format
+                )
+                state["zero_step"] = 0
+
+            for dim in range(p.ndim):
+                size = p.shape[dim]
+                if size == 1 or size == p.numel():
+                    continue
+
+                # Q_{dim} is the learning-rate matrix for this
+                # dimension, a matrix indexed [diagonalized_coordinate, canonical_coordinate].
+                # this is used twice in the update step, once transposed and once without transpose.
+                state[f"Q_{dim}"] = torch.eye(size, **kwargs)
+
+                # param_cov_{dim} is the averaged-over-time gradient of parameters on this dimension, treating
+                # all other dims as a batch axis.
+                state[f"param_cov_{dim}"] = torch.zeros(size, size, **kwargs)
+
+                # grad_cov_{dim} is the covariance of gradients on this axis (without
+                # any co-ordinate changes), treating all other axes as a batch axis.
+                # This is needed when we re-diagonalize, and to compute the
+                # grad_rms_{dim}. We store it as a decaying average, decaying with beta2
+
+                # only for purposes of computing the scalar factor on the
+                # learning rate; and, as a result of this, also contributes something
+                # to the gradient w.r.t. f"Q_{dim}", which is one reason
+                # why we allocate a variable to keep track of its moving average
+                # instead of just using a temporary and smoothing the scalar factor.
+                state[f"grad_cov_{dim}"] = torch.zeros(size, size, **kwargs)
+
+
     def _step_one_param(self,
                         group: dict,
                         p: Tensor,
-                        state: dict):
+                        state: dict,
+                        batch: bool):
         lr = group["lr"]
         size_lr = lr * group["size_lr_scale"]
         beta1, beta2 = group["betas"]
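Note on the `_init_batches` stub above: `itertools.groupby` only merges runs of consecutive items, so the parameter list would need to be sorted by the same `(dtype, shape)` key first, and the file would need `import itertools`, which this diff does not show. Below is a minimal stand-alone sketch (not part of the commit, with made-up parameter shapes) of that grouping, stacking each group along a new leading batch dim, which is what the `batch=True` code paths introduced in this commit assume.

import itertools
import torch

params = [torch.nn.Parameter(torch.zeros(*shape)) for shape in
          [(4, 8), (3,), (4, 8), (5, 5), (3,)]]

# str()/tuple() so the key sorts cleanly; groupby only merges consecutive items,
# hence the sorted() call before grouping.
key = lambda p: (str(p.dtype), tuple(p.shape))
batches = []
for (dtype, shape), ps in itertools.groupby(sorted(params, key=key), key):
    # stack the like-shaped parameters into one tensor with a leading batch dim
    batches.append(torch.stack([p.detach() for p in ps]))

for b in batches:
    print(b.shape)   # e.g. torch.Size([2, 3]), torch.Size([2, 4, 8]), torch.Size([1, 5, 5])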
@@ -221,33 +282,37 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         step = state["step"]
         delta = state["delta"]
         delta.mul_(beta1)
-        numel = p.numel()
+        numel = p.numel() // (p.shape[0] if batch else 1)
         if numel > 1:
             # Update the size/scale of p, and set param_rms
             scale_grads = state["scale_grads"]
-            scale_grads[step % size_update_period] = (p * grad).sum()
+            scale_grads[step % size_update_period] = _sum(p * grad,
+                                                          exclude_dims=[0] if batch else [])
             if step % size_update_period == size_update_period - 1:
                 # this learns the overall scale on the parameter, by shrinking or
                 # expanding it.
                 param_rms = state["param_rms"]
-                param_rms.copy_((p ** 2).mean().sqrt().clamp_(min=eps))
+                param_rms.copy_(_mean(p ** 2, exclude_dims=[0] if batch else []).sqrt().clamp_(min=eps))
                 if step > 0:
                     self._size_update(p, state,
                                       scale_grads, param_rms,
                                       beta1, beta2, step, size_lr,
-                                      param_min_rms, param_max_rms)
+                                      param_min_rms, param_max_rms,
+                                      batch)

         if numel == 1:
             # For parameters with very few elements we just use a form
             # of Adam with a scale factor to reflect the overall
             # parameter rms. Updates delta.
-            self._step_scalar(scalar_max, beta1, beta2, eps, lr, p, grad, state)
+            self._step_scalar(scalar_max, beta1, beta2,
+                              eps, lr, p, grad, state,
+                              batch)
         else:
             if step % lr_update_period == 0 and step > 0:
-                self._accum_param_covs(group, p, state)
-                self._update_lrs(group, p, state)
+                self._accum_param_covs(group, p, state, batch)
+                self._update_lrs(group, p, state, batch)
                 self._zero_exp_avg_sq(state)
-            self._step(group, p, grad, state)
+            self._step(group, p, grad, state, batch)
         p.add_(delta)
         state["step"] = step + 1

@@ -262,7 +327,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                      step: int,
                      size_lr: float,
                      param_min_rms: float,
-                     param_max_rms: float) -> None:
+                     param_max_rms: float,
+                     batch: bool) -> None:
         """
         Called only where p.numel() > 1, this updates the scale of the parameter.
         If we imagine: p = underlying_param * scale.exp(), and we are doing
@@ -281,11 +347,15 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
           size_lr: The learning rate to apply to `scale`
           param_min_rms: User-specified minimum on the rms value of p, we will constrain it to >= this
           param_max_rms: User-specified maximum on the rms value of p, we will constrain it to <= this
+          batch: If true, this is actually a batch of parameters stacked together,
+                 and the 1st dim is the batch dim.
         """
-        # scale_grads is a tensor of shape (size_update_period,) containing a list of
-        # products (p * grad).sum() for the `size_update_period` most recent steps
-        size_update_period, = scale_grads.shape
+        if batch:
+            batch_size, size_update_period = scale_grads.shape
+        else:
+            # scale_grads is a tensor of shape (size_update_period,) containing a list of
+            # products (p * grad).sum() for the `size_update_period` most recent steps
+            size_update_period, = scale_grads.shape

         # correct beta2 for the size update period: we will have
         # faster decay at this level.
@@ -293,7 +363,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of

         scale_exp_avg_sq = state["scale_exp_avg_sq"]
         scale_exp_avg_sq.mul_(beta2_corr).add_(
-            (scale_grads ** 2).mean(), alpha=1-beta2_corr)
+            _mean(scale_grads ** 2, exclude_dims=[0] if batch else []),
+            alpha=1-beta2_corr)


         # The 1st time we reach here is when size_step == 1.
@@ -305,7 +376,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         eps = 1.0e-10 # this value is not critical.
         denom = scale_exp_avg_sq.sqrt() + eps

-        scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum() / denom
+        scale_step = -size_lr * (bias_correction2 ** 0.5) * _sum(scale_grads,
+                                                                 exclude_dims=[0] if batch else [],
+                                                                 keepdim=True) / denom

         is_too_small = (param_rms < param_min_rms)
         is_too_large = (param_rms > param_max_rms)
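For readers of `_size_update`: the quantity accumulated in `scale_grads` is the gradient of the loss with respect to the log-scale of the parameter, under the decomposition p = underlying_param * scale.exp() described in the docstring. The snippet below is a quick self-contained autograd check of that identity; it is a sketch for illustration, not part of the commit.

import torch

underlying = torch.randn(10, 20)
scale = torch.zeros((), requires_grad=True)
p = underlying * scale.exp()
loss = (p ** 3).sum()                      # any differentiable loss will do

# gradient w.r.t. p, keeping the graph so we can also differentiate w.r.t. scale
(grad_p,) = torch.autograd.grad(loss, p, retain_graph=True)
(grad_scale,) = torch.autograd.grad(loss, scale)

# d(loss)/d(scale) equals (p * grad).sum(), which is what scale_grads stores
print(torch.allclose(grad_scale, (p * grad_p).sum()))   # True (up to float error)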
@@ -320,7 +393,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
     def _accum_param_covs(self,
                           group: dict,
                           p: Tensor,
-                          state: dict) -> None:
+                          state: dict,
+                          batch: bool) -> None:
         """
         Updates state[f"param_cov_{dim}"] for each dim of tensor p.
         """
@@ -329,20 +403,27 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         step = state["step"]
         this_weight = (group["param_cov_freshness"] * lr_update_period /
                        (step + lr_update_period))

         for dim in range(p.ndim):
             size = p.shape[dim]
-            if size == 1 or size == p.numel():
+            if size == 1 or size == p.numel() or (batch and dim == 0):
                 continue

             # M is the parameter matrix, of shape (-1, size) where the -1 covers
             # all other dimensions of the tensor.
             M_full = p.transpose(dim, -1)
-            M = M_full.reshape(-1, size)
+            if batch:  # dim 0 is batch dim
+                M = M_full.reshape(p.shape[0], -1, size)
+            else:
+                M = M_full.reshape(-1, size)
+            # will be batched matrix multiply if batch==True
             this_param_cov = torch.matmul(M.t(), M)
-            # normalize scale of this param_cov, in casre parameter scale
+            # normalize scale of this param_cov, in case parameter scale
             # changes significantly during training, which would cause some
             # parts of the training timeline to be more highly weighted.
-            this_param_cov /= (this_param_cov.diag().mean() + eps)
+            this_param_cov /= _mean(_diag(this_param_cov),
+                                    exclude_dims=[0] if batch else [],
+                                    keepdim=True) + eps
             state[f"param_cov_{dim}"].mul_(1-this_weight).add_(this_param_cov,
                                                                alpha=this_weight)

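One caveat about the batched branch above (an observation on this partial work, not a change made here): `Tensor.t()` is only defined for tensors with at most two dimensions, so `torch.matmul(M.t(), M)` cannot by itself be the "batched matrix multiply" the new comment mentions once `M` carries a leading batch dim; something like `M.transpose(-2, -1)` would presumably be needed there and in the similar `this_g.t()` call further down. A stand-alone sketch of the batched form:

import torch

M = torch.randn(5, 7, 3)                    # (batch, rows, size)
cov = torch.matmul(M.transpose(-2, -1), M)  # (batch, size, size) batched covariance
print(cov.shape)                            # torch.Size([5, 3, 3])
# M.t() on the 3-D tensor above would raise a RuntimeError instead.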
@@ -359,7 +440,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
     def _update_lrs(self,
                     group: dict,
                     p: Tensor,
-                    state: dict) -> None:
+                    state: dict,
+                    batch: bool) -> None:
         """
         Computes learning-rate matrices Q for each dim of this tensor.
         Args:
@@ -370,9 +452,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
           state: state dict for the current parameter
         """
         ndim = p.ndim
-        numel = p.numel()
+        numel = p.numel() // (p.shape[0] if batch else 1)

-        if numel in p.shape:
+        if numel in (p[0].shape if batch else p.shape):
             return # Nothing to do for this parameter matrix. E.g. a bias or a scalar.

         scale_arr = [None] * ndim
@@ -382,21 +464,30 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         eps = 1.0e-20
         cur_p = p + eps * torch.randn_like(p)

-        dims_to_sum = list(range(ndim-1))

         for dim in range(ndim):
             size = p.shape[dim]
-            if size == 1 or size == p.numel():
+            if size == 1 or size == p.numel() or (batch and dim == 0):
                 continue
             param_cov = state[f"param_cov_{dim}"]
+            # if batch==True, U, S and V have an extra leading dim.
             U, S, V = _svd(param_cov)
             Q = state[f"Q_{dim}"]
-            Q[:] = U.t()
+            Q[:] = U.transpose(-1, -2)  # 3 dims if batch, else 2.

             M = cur_p.transpose(dim, -1)
+            if batch:
+                while U.ndim < M.ndim:
+                    U = U.unsqueeze(1)
             N = torch.matmul(M, U)
-            cur_param_var = (N * N).mean(dim=dims_to_sum)
-            # OK, cur_param_var would be the same as S if the variance stats
+            # cur_param_var is a diagonal parameter variance over dimension `dim`,
+            # will have shape something like [1, size, 1]; or [batch_size, 1, size, 1]
+            # if batch == True.
+            cur_param_var = _mean(N**2,
+                                  exclude_dims=[0, dim] if batch else [dim],
+                                  keepdim=True)
+            S = S.reshape(cur_param_var.shape)
+
+            # OK, cur_param_var would have the same values as S if the variance stats
             # param_cov_{dim} were accumulated from this exact parameter matrix,
             # but actually they contain other versions of the parameter
             # covariance so they will, in general, be less extreme. We scale
@@ -404,6 +495,12 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             # ensure it doesn't have any too-small eigenvalues (where the
             # stats permit).
             scale = (S.clamp(min=eps) / cur_param_var.clamp(min=eps)).sqrt()

+            if batch:
+                batch_size = p.shape[0]
+                scale = scale.reshape(batch_size, 1, size)
+            else:
+                scale = scale.reshape(size)
             N *= scale
             cur_p = N.transpose(dim, -1)

@@ -431,16 +528,17 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         for i in range(4): # for 4 iterations..
             for dim in range(ndim):
                 size = p.shape[dim]
-                if size == 1 or size == p.numel():
+                if size == 1 or size == p.numel() or (batch and dim == 0):
                     continue

                 M = cur_p.transpose(dim, -1)
                 # correct for the fact that we have already normalized this dim in cur_p.
                 if cur_scales[dim] is not None:
                     M *= cur_scales[dim]
-                rms = (M**2).mean(dim=dims_to_sum).sqrt()
+                rms = _mean(M**2,
+                            exclude_dims=[0, dim] if batch else [dim]).sqrt()
                 rank = numel // size
-                smoothed_rms = self._smooth_param_rms(group, rms, rank)
+                smoothed_rms = self._smooth_param_rms(group, rms, rank, batch)
                 cur_scales[dim] = smoothed_rms
                 M /= smoothed_rms # normalize this dim..

@@ -525,12 +623,12 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             Q[:] = torch.matmul(U.t(), Q)



     def _step(self,
               group: dict,
               p: Tensor,
               grad: Tensor,
-              state: dict):
+              state: dict,
+              batch: bool):
         """
         This function does the core update of self._step, in the case where the tensor
         has more than 1 elements. It multiplies the moving-average gradient tensor by a
@@ -542,6 +640,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
           p: The parameter to be updated
           grad: The grad of p
           state: The state-dict corresponding to parameter p
+          batch: if true, the first dimension of p is a batch dim, i.e. a batch of
+                 parameter matrices.

         This function modifies p.
         """
@@ -557,17 +657,21 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             # grad_cov_{dim}, which are for periodically updating the
             # learning-rate matrices.
             size = grad.shape[dim]
-            if size == 1 or size == p.numel():
+            if size == 1 or size == p.numel() or (batch and dim == 0):
                 continue
             if step % grad_cov_period == 0 or step < grad_cov_period:
-                grad_cov = state[f"grad_cov_{dim}"]
-                this_m = p.transpose(-1, dim).reshape(-1, size)  # parameter matrix M
-                this_g = grad.transpose(-1, dim).reshape(-1, size)  # M_grad
+                grad_cov = state[f"grad_cov_{dim}"]  # shaped: (batch, size, size) if batch, else (size, size)
+                if batch:
+                    this_m = p.transpose(-1, dim).reshape(p.shape[0], -1, size)  # parameter matrix M (batch)
+                    this_g = grad.transpose(-1, dim).reshape(p.shape[0], -1, size)  # M_grad (batch)
+                else:
+                    this_m = p.transpose(-1, dim).reshape(-1, size)  # parameter matrix M
+                    this_g = grad.transpose(-1, dim).reshape(-1, size)  # M_grad
                 # could perhaps accumulate grad_cov less frequently; it's only
                 # needed when we rediagonalize which is not that common.
                 grad_cov.mul_(beta2).add_(torch.matmul(this_g.t(), this_g))

-        grad = self._project(grad, state, forward=True)
+        grad = self._project(grad, state, batch, forward=True)

         exp_avg_sq = state["exp_avg_sq"]
         exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
@@ -583,10 +687,12 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         grad = grad / denom

         # and project back..
-        grad = self._project(grad, state, forward=False)
+        grad = self._project(grad, state, batch, forward=False)

         scalar_exp_avg_sq = state["scalar_exp_avg_sq"]
-        scalar_exp_avg_sq.mul_(beta2).add_((grad**2).mean(),
+        scalar_exp_avg_sq.mul_(beta2).add_(_mean(grad**2,
+                                                 exclude_dims=[0] if batch else [],
+                                                 keepdim=batch),
                                            alpha=1-beta2)


@@ -595,12 +701,15 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         # use the same bias_correction2.
         scalar_exp_avg_sq = scalar_exp_avg_sq * (1.0 / bias_correction2)

-        denom = scalar_exp_avg_sq.sqrt() + eps  # scalar.
+        denom = scalar_exp_avg_sq.sqrt() + eps  # scalar, or [batch_size,1,1..]

         alpha = state["param_rms"] * (1-beta1) * -lr / denom

         delta = state["delta"]
-        delta.add_(grad, alpha=alpha)
+        if batch:
+            delta.add_(grad * alpha)
+        else:
+            delta.add_(grad, alpha=alpha)


     def _step_scalar(self,
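The batched branch above writes `delta.add_(grad * alpha)` rather than `delta.add_(grad, alpha=alpha)` because the `alpha=` argument of `Tensor.add_()` takes a plain Python number, while in the batched case `alpha` is a tensor holding one value per stacked parameter, so it has to be applied by broadcasting. A small stand-alone illustration (not part of the commit):

import torch

delta = torch.zeros(4, 3, 5)
grad = torch.randn(4, 3, 5)
alpha = -0.01 * torch.rand(4, 1, 1)   # one scalar per batch element

delta.add_(grad * alpha)              # broadcasts alpha over the batch dim
# delta.add_(grad, alpha=alpha)       # would fail: alpha must be a number, not a Tensor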
@@ -611,19 +720,20 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                      lr: float,
                      p: Tensor,
                      grad: Tensor,
-                     state: dict):
+                     state: dict,
+                     batch: bool):
         """
         A form of the core update for scalar tensors, where we cannot get a good
         estimate of the parameter rms.
         """
-        exp_avg_sq = state["exp_avg_sq"]
+        exp_avg_sq = state["exp_avg_sq"]  # shape: (batch_size,) if batch else ()
         exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
                                         value=1-beta2)

         # bias_correction2 is like in Adam. Don't bother with bias_correction1;
         # slower update at the start will help stability anyway.
         bias_correction2 = 1 - beta2 ** (state["step"] + 1)
-        denom = (exp_avg_sq.sum() / (exp_avg_sq.numel() * bias_correction2)).sqrt() + eps
+        denom = (exp_avg_sq / bias_correction2).sqrt() + eps
         alpha = -lr * (1-beta1) / denom

         delta = state["delta"]
@@ -633,6 +743,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
     def _project(self,
                  x: Tensor,
                  state: dict,
+                 batch: bool,
                  forward: bool) -> Tensor:
         """
         Multiply a tensor x by proj_{dim} on each of its dimensions.
@@ -641,6 +752,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of

         Args:
            x: The tensor to project, e.g. a gradient or a parameter change.
+           state: dict to look up the projections
+           batch: if True, the 1st dim of x is a batch dim and is not projected.
            forward: if True, go in the forward direction (from canonical to diagonalized
                co-ordinates); if False, the reverse. They differ by a transpose,
                not an inverse, and do not make a round trip.
@@ -648,7 +761,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         numel = x.numel()
         for dim in range(x.ndim):
             size = x.shape[dim]
-            if size == 1 or size == numel:
+            if size == 1 or size == numel or (batch and dim == 0):
                 continue
             Q = state[f"Q_{dim}"]
             if forward:
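As a rough illustration of what `_project` does per dimension (a sketch under assumptions, since the body after `if forward:` is not shown in this hunk): with `Q_{dim}` indexed `[diagonalized_coordinate, canonical_coordinate]`, one plausible convention is to move the projected dim to the end and matmul with `Q` or its transpose, which matches the docstring's point that the two directions differ by a transpose rather than an inverse. The helper name below is hypothetical and not from the commit.

import torch

def project_one_dim(x: torch.Tensor, Q: torch.Tensor, d: int, forward: bool) -> torch.Tensor:
    y = x.transpose(d, -1)                    # put the projected dim last
    # forward: canonical -> diagonalized coords (multiply by Q^T on the right);
    # reverse: multiply by Q.  A transpose, not an inverse.
    y = torch.matmul(y, Q.t() if forward else Q)
    return y.transpose(d, -1)

x = torch.randn(8, 16)
Q = torch.linalg.qr(torch.randn(16, 16)).Q    # some orthogonal matrix, for the demo
g = project_one_dim(x, Q, d=1, forward=True)
x2 = project_one_dim(g, Q, d=1, forward=False)
print(torch.allclose(x, x2, atol=1e-5))       # True here only because this Q is orthogonal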
@@ -669,12 +782,13 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                           rank: int) -> Tensor:
         """
         Smooths and normalizes to mean=1 a tensor of diagonalized parameter rms values for one dimension
-        of a tensor. This will be used to construct a learning-rate matrix.
+        of a tensor, or a batch of such diagonalized rms values.
+        This will be used to construct a learning-rate matrix.

         Args:
           group: dict for configuration values
-          rms: a tensor with one dimension, representing diagonalized parameter
-               root-mean-square values.
+          rms: Either a tensor with shape (size,) representing diagonalized parameter
+               root-mean-square values, or a tensor of shape (batch_size, size).
           rank: the assumed rank of the covariance matrix from which rms was derived,
                used to decide how much smoothing to apply. This is actually the
                minimal rank, if the parameter matrix were to stay fixed during training,
@@ -688,7 +802,10 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         rms = rms ** param_pow
         smooth = (smooth0 +
                   (smooth1 - smooth0) * size / (size + rank))
-        mean = rms.mean()
+        batch = (rms.ndim > 1)
+        mean = _mean(rms,
+                     exclude_dims=[0] if batch else [],
+                     keepdim=True)
         rms += eps + smooth * mean
         new_mean = (eps + (smooth + 1) * mean)
         ans = rms / new_mean
@@ -697,12 +814,58 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         # Apply max_rms
         max_rms = 5.0
         ans.clamp_(max=max_rms*2)
-        ans /= ans.mean()
+        ans /= _mean(ans, exclude_dims=[0] if batch else [],
+                     keepdim=True)
         ans.clamp_(max=max_rms)
-        ans /= ans.mean()
+        ans /= _mean(ans, exclude_dims=[0] if batch else [],
+                     keepdim=True)
         return ans


+def _diag(x: Tensor):
+    """
+    like torch diag(), but supports batch dim, i.e. input of shape (B, M, M) returns
+    output of shape (B, M)
+    """
+    if x.ndim == 3:
+        (B, M, M2) = x.shape
+        assert M == M2
+        stride = x.stride()
+        return x.as_strided(size=(B, M), stride=(stride[0], stride[1] + stride[2])).contiguous()
+    else:
+        return x.diag()
+
+
+def _sum(x: Tensor,
+         exclude_dims: List[int] = [],
+         keepdim: bool = False):
+    """
+    Version of torch.sum where you specify dims to exclude, not dims to sum.
+    """
+    if len(exclude_dims) == 0:
+        return x.sum(keepdim=keepdim)
+    elif x.ndim == 1:
+        return x   # if one dim is excluded, there are no dims to sum, and
+                   # x.sum(dim=[]) sums all dims so we should not call sum().
+    exclude_dims = [i if i >= 0 else i + x.ndim for i in exclude_dims]
+    dims = [i for i in range(x.ndim) if i not in exclude_dims]
+    return x.sum(dim=dims, keepdim=keepdim)
+
+
+def _mean(x: Tensor,
+          exclude_dims: List[int] = [],
+          keepdim: bool = False):
+    """
+    Version of torch.mean where you specify dims to exclude, not dims to mean.
+    """
+    if len(exclude_dims) == 0:
+        return x.mean(keepdim=keepdim)
+    elif x.ndim == 1:
+        return x   # if one dim is excluded, there are no dims to mean, and
+                   # x.mean(dim=[]) averages all dims so we should not call mean().
+    exclude_dims = [i if i >= 0 else i + x.ndim for i in exclude_dims]
+    dims = [i for i in range(x.ndim) if i not in exclude_dims]
+    return x.mean(dim=dims, keepdim=keepdim)
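A quick usage sketch for the `_diag` / `_sum` / `_mean` helpers added above (not part of the commit). It assumes `import torch` and `from typing import List` are already present at the top of the file, which this diff does not show.

import torch

x = torch.arange(24.).reshape(2, 3, 4)

# _mean / _sum reduce every dim *except* the excluded ones:
m = _mean(x, exclude_dims=[0], keepdim=True)   # shape (2, 1, 1): one mean per batch element
s = _sum(x, exclude_dims=[-1])                 # shape (4,): sums over dims 0 and 1

# _diag extracts the per-matrix diagonal of a stacked batch of square matrices;
# torch.diagonal(c, dim1=-2, dim2=-1) gives the same values.
c = torch.randn(2, 5, 5)
print(torch.allclose(_diag(c), torch.diagonal(c, dim1=-2, dim2=-1)))   # True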