Committing partial work..
This commit is contained in:
parent d25df4af5e
commit 27da50a1f6
@@ -119,8 +119,10 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
loss = closure()

for group in self.param_groups:
    eps = group["eps"]
    size_update_period = group["size_update_period"]
    do_batch = False

    if len([p for p in group["params"] if p.grad is None]) != 0:
        self._init_group_state(group)

    for p in group["params"]:
        if p.grad is None:
@@ -136,77 +138,136 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of

# State initialization
if len(state) == 0:

    state["step"] = 0

    kwargs = {'device': p.device, 'dtype': p.dtype}

    # 'delta' implements conventional momentum.  There are
    # several different kinds of update going on, so rather than
    # compute "exp_avg" like in Adam, we store and decay a
    # parameter-change "delta", which combines all forms of
    # update.  This is equivalent to how it's done in Adam,
    # except for the first few steps.
    state["delta"] = torch.zeros_like(
        p, memory_format=torch.preserve_format
    )

    if p.numel() > 1:
        # "param_rms" just periodically records the scalar root-mean-square value of
        # the parameter tensor.
        param_rms = (p**2).mean().sqrt().add_(eps)
        state["param_rms"] = param_rms

        state["scale_exp_avg_sq"] = torch.zeros((), **kwargs)
        # scalar exp_avg_sq.  Not the same as scale_exp_avg_sq: this is used to
        # determine the magnitude of the regular step.
        state["scalar_exp_avg_sq"] = torch.zeros((), **kwargs)
        state["scale_grads"] = torch.zeros(size_update_period,
                                           **kwargs)

    # exp_avg_sq is in the projected space.  It is periodically zeroed, and we
    # set zero_step.
    state["exp_avg_sq"] = torch.zeros_like(
        p, memory_format=torch.preserve_format
    )
    state["zero_step"] = 0

    for dim in range(p.ndim):
        size = p.shape[dim]
        if size == 1 or size == p.numel():
            continue

        # Q_{dim} is the learning-rate matrix for this
        # dimension, a matrix indexed [diagonalized_coordinate, canonical_coordinate].
        # It is used twice in the update step, once transposed and once without transpose.
        state[f"Q_{dim}"] = torch.eye(size, **kwargs)

        # param_cov_{dim} is the averaged-over-time covariance of the parameters on this dimension,
        # treating all other dims as a batch axis.
        state[f"param_cov_{dim}"] = torch.zeros(size, size, **kwargs)

        # grad_cov_{dim} is the covariance of gradients on this axis (without
        # any co-ordinate changes), treating all other axes as a batch axis.
        # This is needed when we re-diagonalize, and to compute the
        # grad_rms_{dim}.  We store it as a decaying average, decaying with beta2

        # only for purposes of computing the scalar factor on the
        # learning rate; and, as a result of this, also contributes something
        # to the gradient w.r.t. f"Q_{dim}", which is one reason
        # why we allocate a variable to keep track of its moving average
        # instead of just using a temporary and smoothing the scalar factor.
        state[f"grad_cov_{dim}"] = torch.zeros(size, size, **kwargs)
self._step_one_param(group, p, state)
self._init_state(group, p, state)
if not do_batch:
    self._step_one_param(group, p, state, batch=False)
if do_batch:
    batches = self._get_batches(group)

return loss

def _get_batches(self,
                 group: dict):
    """
    Arrange the parameters in the group into batches, and return a "fake" version of
    group["params"] that contains the parameters stacked into batches, arranged by
    size.
    """
    is_initialized = False
    if "batches" not in group:
        self._init_batches(group)


def _init_batches(self,
                  group: dict):
    """
    Arrange the parameters in this parameter group by dtype and shape, and group
    them into batches; this involves reallocating the tensors that are members
    of self.state[p] for each param, so that there can be an underlying group
    """

    batches = []
    for (dtype, shape), params in itertools.groupby(group["params"],
                                                    lambda p: (p.dtype, p.shape)):
        # (dtype, shape) is the "key" of the grouping, "params" is the list
        # of parameters that share the dtype and shape.  group["params"] is
        # a list of all the parameters in this group.
        #param_batches =
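
A rough illustration of what this grouping is aiming at: a minimal sketch (not part of the commit) that stacks same-shaped parameters into one batched tensor. It assumes the params are sorted first so itertools.groupby sees equal keys adjacently; the helper name _stack_params_sketch is made up for illustration.

import itertools
import torch

def _stack_params_sketch(params):
    # sort first: groupby only merges *adjacent* items with equal keys
    key = lambda p: (str(p.dtype), tuple(p.shape))
    batches = []
    for (dtype, shape), ps in itertools.groupby(sorted(params, key=key), key):
        # stack all parameters sharing dtype and shape along a new batch dim 0,
        # so that one batched update can serve the whole group
        batches.append(torch.stack([p.detach() for p in list(ps)], dim=0))
    return batches
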
def _init_group_state(self,
                      group: dict,
                      batch: bool):
    """
    Initializes state dict for parameter group `group`.  Expects all elements of the state
    to be uninitialized, i.e. you cannot have gradient for some parameters now and other
    parameters later; it must be all or nothing.  This is also a requirement for DDP,
    so it's not much inconvenience.
    Args:
       group: Parameter group to initialize state for.  A list of nn.Parameter.
       batch: if True, we will arrange parameters into batches for more efficient
              optimization.
    """

    eps = group["eps"]
    size_update_period = group["size_update_period"]

    for p in group["params"]:
        assert p.grad is not None, "All parameters must have grads, or none."
        grad = p.grad
        if grad.is_sparse:
            raise RuntimeError(
                "PrAdam optimizer does not support sparse gradients"
            )
        state = self.state[p]

        state["step"] = 0

        kwargs = {'device': p.device, 'dtype': p.dtype}

        # 'delta' implements conventional momentum.  There are
        # several different kinds of update going on, so rather than
        # compute "exp_avg" like in Adam, we store and decay a
        # parameter-change "delta", which combines all forms of
        # update.  This is equivalent to how it's done in Adam,
        # except for the first few steps.
        state["delta"] = torch.zeros_like(
            p, memory_format=torch.preserve_format
        )

        if p.numel() > 1:
            # "param_rms" just periodically records the scalar root-mean-square value of
            # the parameter tensor.
            param_rms = (p**2).mean().sqrt().add_(eps)
            state["param_rms"] = param_rms

            state["scale_exp_avg_sq"] = torch.zeros((), **kwargs)
            # scalar exp_avg_sq.  Not the same as scale_exp_avg_sq: this is used to
            # determine the magnitude of the regular step.
            state["scalar_exp_avg_sq"] = torch.zeros((), **kwargs)
            state["scale_grads"] = torch.zeros(size_update_period,
                                               **kwargs)

        # exp_avg_sq is in the projected space.  It is periodically zeroed, and we
        # set zero_step.
        state["exp_avg_sq"] = torch.zeros_like(
            p, memory_format=torch.preserve_format
        )
        state["zero_step"] = 0

        for dim in range(p.ndim):
            size = p.shape[dim]
            if size == 1 or size == p.numel():
                continue

            # Q_{dim} is the learning-rate matrix for this
            # dimension, a matrix indexed [diagonalized_coordinate, canonical_coordinate].
            # It is used twice in the update step, once transposed and once without transpose.
            state[f"Q_{dim}"] = torch.eye(size, **kwargs)

            # param_cov_{dim} is the averaged-over-time covariance of the parameters on this dimension,
            # treating all other dims as a batch axis.
            state[f"param_cov_{dim}"] = torch.zeros(size, size, **kwargs)

            # grad_cov_{dim} is the covariance of gradients on this axis (without
            # any co-ordinate changes), treating all other axes as a batch axis.
            # This is needed when we re-diagonalize, and to compute the
            # grad_rms_{dim}.  We store it as a decaying average, decaying with beta2

            # only for purposes of computing the scalar factor on the
            # learning rate; and, as a result of this, also contributes something
            # to the gradient w.r.t. f"Q_{dim}", which is one reason
            # why we allocate a variable to keep track of its moving average
            # instead of just using a temporary and smoothing the scalar factor.
            state[f"grad_cov_{dim}"] = torch.zeros(size, size, **kwargs)

def _step_one_param(self,
                    group: dict,
                    p: Tensor,
                    state: dict):
                    state: dict,
                    batch: bool):
    lr = group["lr"]
    size_lr = lr * group["size_lr_scale"]
    beta1, beta2 = group["betas"]
@@ -221,33 +282,37 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
    step = state["step"]
    delta = state["delta"]
    delta.mul_(beta1)
    numel = p.numel()
    numel = p.numel() // (p.shape[0] if batch else 1)
    if numel > 1:
        # Update the size/scale of p, and set param_rms
        scale_grads = state["scale_grads"]
        scale_grads[step % size_update_period] = (p * grad).sum()
        scale_grads[step % size_update_period] = _sum(p * grad,
                                                      exclude_dims=[0] if batch else [])
        if step % size_update_period == size_update_period - 1:
            # this learns the overall scale on the parameter, by shrinking or
            # expanding it.
            param_rms = state["param_rms"]
            param_rms.copy_((p ** 2).mean().sqrt().clamp_(min=eps))
            param_rms.copy_(_mean(p ** 2, exclude_dims=[0] if batch else []).sqrt().clamp_(min=eps))
            if step > 0:
                self._size_update(p, state,
                                  scale_grads, param_rms,
                                  beta1, beta2, step, size_lr,
                                  param_min_rms, param_max_rms)
                                  param_min_rms, param_max_rms,
                                  batch)

    if numel == 1:
        # For parameters with very few elements we just use a form
        # of Adam with a scale factor to reflect the overall
        # parameter rms.  Updates delta.
        self._step_scalar(scalar_max, beta1, beta2, eps, lr, p, grad, state)
        self._step_scalar(scalar_max, beta1, beta2,
                          eps, lr, p, grad, state,
                          batch)
    else:
        if step % lr_update_period == 0 and step > 0:
            self._accum_param_covs(group, p, state)
            self._update_lrs(group, p, state)
            self._accum_param_covs(group, p, state, batch)
            self._update_lrs(group, p, state, batch)
            self._zero_exp_avg_sq(state)
        self._step(group, p, grad, state)
        self._step(group, p, grad, state, batch)
    p.add_(delta)
    state["step"] = step + 1
@@ -262,7 +327,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                 step: int,
                 size_lr: float,
                 param_min_rms: float,
                 param_max_rms: float) -> None:
                 param_max_rms: float,
                 batch: bool) -> None:
    """
    Called only where p.numel() > 1, this updates the scale of the parameter.
    If we imagine: p = underlying_param * scale.exp(), and we are doing
@@ -281,11 +347,15 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
      size_lr: The learning rate to apply to `scale`
      param_min_rms: User-specified minimum on the rms value of p, we will constrain it to >= this
      param_max_rms: User-specified maximum on the rms value of p, we will constrain it to <= this
      batch: If true, this is actually a batch of parameters stacked together,
             and the 1st dim is the batch dim.
    """

    # scale_grads is a tensor of shape (size_update_period,) containing a list of
    # products (p * grad).sum() for the `size_update_period` most recent steps
    size_update_period, = scale_grads.shape
    if batch:
        batch_size, size_update_period = scale_grads.shape
    else:
        # scale_grads is a tensor of shape (size_update_period,) containing a list of
        # products (p * grad).sum() for the `size_update_period` most recent steps
        size_update_period, = scale_grads.shape

    # correct beta2 for the size update period: we will have
    # faster decay at this level.
@@ -293,7 +363,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of

    scale_exp_avg_sq = state["scale_exp_avg_sq"]
    scale_exp_avg_sq.mul_(beta2_corr).add_(
        (scale_grads ** 2).mean(), alpha=1-beta2_corr)
        _mean(scale_grads ** 2, exclude_dims=[0] if batch else []),
        alpha=1-beta2_corr)

    # The 1st time we reach here is when size_step == 1.
@@ -305,7 +376,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
    eps = 1.0e-10  # this value is not critical.
    denom = scale_exp_avg_sq.sqrt() + eps

    scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum() / denom
    scale_step = -size_lr * (bias_correction2 ** 0.5) * _sum(scale_grads,
                                                             exclude_dims=[0] if batch else [],
                                                             keepdim=True) / denom

    is_too_small = (param_rms < param_min_rms)
    is_too_large = (param_rms > param_max_rms)
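
To make the scale update above concrete, here is a hedged, non-batch sketch of the same computation; the bias-correction detail is an assumption, since that part of the function is not visible in this hunk.

import torch

def size_update_sketch(p, scale_grads, scale_exp_avg_sq,
                       size_lr, beta2, size_update_period, step, eps=1.0e-10):
    # faster decay, since this update only happens once per size_update_period steps
    beta2_corr = beta2 ** size_update_period
    scale_exp_avg_sq.mul_(beta2_corr).add_((scale_grads ** 2).mean(),
                                           alpha=1 - beta2_corr)
    size_step = (step + 1) // size_update_period          # assumed definition
    bias_correction2 = 1 - beta2_corr ** size_step
    denom = scale_exp_avg_sq.sqrt() + eps
    # gradient of the loss w.r.t. a log-scale on p is sum(p * grad), accumulated in scale_grads
    scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum() / denom
    return scale_step * p    # the contribution that would be added into `delta`
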

@@ -320,7 +393,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
def _accum_param_covs(self,
                      group: dict,
                      p: Tensor,
                      state: dict) -> None:
                      state: dict,
                      batch: bool) -> None:
    """
    Updates state[f"param_cov_{dim}"] for each dim of tensor p.
    """
@@ -329,20 +403,27 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
    step = state["step"]
    this_weight = (group["param_cov_freshness"] * lr_update_period /
                   (step + lr_update_period))

    for dim in range(p.ndim):
        size = p.shape[dim]
        if size == 1 or size == p.numel():
        if size == 1 or size == p.numel() or (batch and dim == 0):
            continue

        # M is the parameter matrix, of shape (-1, size) where the -1 covers
        # all other dimensions of the tensor.
        M_full = p.transpose(dim, -1)
        M = M_full.reshape(-1, size)
        if batch:  # dim 0 is batch dim
            M = M_full.reshape(p.shape[0], -1, size)
        else:
            M = M_full.reshape(-1, size)
        # will be batched matrix multiply if batch==True
        this_param_cov = torch.matmul(M.transpose(-1, -2), M)
        # normalize scale of this param_cov, in casre parameter scale
        # normalize scale of this param_cov, in case parameter scale
        # changes significantly during training, which would cause some
        # parts of the training timeline to be more highly weighted.
        this_param_cov /= (this_param_cov.diag().mean() + eps)
        this_param_cov /= _mean(_diag(this_param_cov),
                                exclude_dims=[0] if batch else [],
                                keepdim=True) + eps
        state[f"param_cov_{dim}"].mul_(1-this_weight).add_(this_param_cov,
                                                           alpha=this_weight)
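
A self-contained, non-batch sketch of the accumulation above (illustrative only; eps and this_weight mirror the surrounding code):

import torch

def accum_param_cov_sketch(p, param_cov, dim, this_weight, eps=1.0e-10):
    size = p.shape[dim]
    # flatten every other dim into rows; columns index the chosen dim
    M = p.transpose(dim, -1).reshape(-1, size)
    this_param_cov = torch.matmul(M.t(), M)
    # normalize so that a drifting parameter scale does not re-weight the timeline
    this_param_cov = this_param_cov / (this_param_cov.diag().mean() + eps)
    param_cov.mul_(1 - this_weight).add_(this_param_cov, alpha=this_weight)
    return param_cov
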

@@ -359,7 +440,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
def _update_lrs(self,
                group: dict,
                p: Tensor,
                state: dict) -> None:
                state: dict,
                batch: bool) -> None:
    """
    Computes learning-rate matrices Q for each dim of this tensor.
    Args:
@@ -370,9 +452,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
      state: state dict for the current parameter
    """
    ndim = p.ndim
    numel = p.numel()
    numel = p.numel() // (p.shape[0] if batch else 1)

    if numel in p.shape:
    if numel in (p[0].shape if batch else p.shape):
        return  # Nothing to do for this parameter matrix.  E.g. a bias or a scalar.

    scale_arr = [None] * ndim
@@ -382,21 +464,30 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
    eps = 1.0e-20
    cur_p = p + eps * torch.randn_like(p)

    dims_to_sum = list(range(ndim-1))

    for dim in range(ndim):
        size = p.shape[dim]
        if size == 1 or size == p.numel():
        if size == 1 or size == p.numel() or (batch and dim == 0):
            continue
        param_cov = state[f"param_cov_{dim}"]
        # if batch==True, U, S and V have an extra leading dim.
        U, S, V = _svd(param_cov)
        Q = state[f"Q_{dim}"]
        Q[:] = U.t()
        Q[:] = U.transpose(-1, -2)  # 3 dims if batch, else 2.

        M = cur_p.transpose(dim, -1)
        if batch:
            while U.ndim < M.ndim:
                U = U.unsqueeze(1)
        N = torch.matmul(M, U)
        cur_param_var = (N * N).mean(dim=dims_to_sum)
        # OK, cur_param_var would be the same as S if the variance stats
        # cur_param_var is the diagonal parameter variance over dimension `dim`; it
        # will have shape something like [1, size, 1], or [batch_size, 1, size, 1]
        # if batch == True.
        cur_param_var = _mean(N**2,
                              exclude_dims=[0, dim] if batch else [dim],
                              keepdim=True)
        S = S.reshape(cur_param_var.shape)

        # OK, cur_param_var would have the same values as S if the variance stats
        # param_cov_{dim} were accumulated from this exact parameter matrix,
        # but actually they contain other versions of the parameter
        # covariance so they will, in general, be less extreme.  We scale
@@ -404,6 +495,12 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
        # ensure it doesn't have any too-small eigenvalues (where the
        # stats permit).
        scale = (S.clamp(min=eps) / cur_param_var.clamp(min=eps)).sqrt()

        if batch:
            batch_size = p.shape[0]
            scale = scale.reshape(batch_size, 1, size)
        else:
            scale = scale.reshape(size)
        N *= scale
        cur_p = N.transpose(dim, -1)
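
The re-diagonalization step above can be summarized by the following non-batch sketch; torch.linalg.svd stands in for the module's _svd helper, and param_cov is assumed symmetric positive semi-definite:

import torch

def rediagonalize_sketch(p, param_cov, dim):
    # param_cov ~ U diag(S) U^T for symmetric PSD input, so U's columns are its eigenvectors
    U, S, _ = torch.linalg.svd(param_cov)
    Q = U.t()                      # rows indexed by diagonalized, columns by canonical co-ordinates
    M = p.transpose(dim, -1)       # move this dim to the last position
    N = torch.matmul(M, U)         # express that dim in the diagonalized basis
    return Q, S, N.transpose(dim, -1)
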

@@ -431,16 +528,17 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
    for i in range(4):  # for 4 iterations..
        for dim in range(ndim):
            size = p.shape[dim]
            if size == 1 or size == p.numel():
            if size == 1 or size == p.numel() or (batch and dim == 0):
                continue

            M = cur_p.transpose(dim, -1)
            # correct for the fact that we have already normalized this dim in cur_p.
            if cur_scales[dim] is not None:
                M *= cur_scales[dim]
            rms = (M**2).mean(dim=dims_to_sum).sqrt()
            rms = _mean(M**2,
                        exclude_dims=[0, dim] if batch else [dim]).sqrt()
            rank = numel // size
            smoothed_rms = self._smooth_param_rms(group, rms, rank)
            smoothed_rms = self._smooth_param_rms(group, rms, rank, batch)
            cur_scales[dim] = smoothed_rms
            M /= smoothed_rms  # normalize this dim..
@@ -525,12 +623,12 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
            Q[:] = torch.matmul(U.t(), Q)


def _step(self,
          group: dict,
          p: Tensor,
          grad: Tensor,
          state: dict):
          state: dict,
          batch: bool):
    """
    This function does the core update of self._step_one_param, in the case where the tensor
    has more than one element.  It multiplies the moving-average gradient tensor by a
@@ -542,6 +640,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
      p: The parameter to be updated
      grad: The grad of p
      state: The state-dict corresponding to parameter p
      batch: if true, the first dimension of p is a batch dim, i.e. a batch of
             parameter matrices.

    This function modifies p.
    """
@@ -557,17 +657,21 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
        # grad_cov_{dim}, which are for periodically updating the
        # learning-rate matrices.
        size = grad.shape[dim]
        if size == 1 or size == p.numel():
        if size == 1 or size == p.numel() or (batch and dim == 0):
            continue
        if step % grad_cov_period == 0 or step < grad_cov_period:
            grad_cov = state[f"grad_cov_{dim}"]
            this_m = p.transpose(-1, dim).reshape(-1, size)  # parameter matrix M
            this_g = grad.transpose(-1, dim).reshape(-1, size)  # M_grad
            grad_cov = state[f"grad_cov_{dim}"]  # shaped: (batch, size, size) if batch, else (size, size)
            if batch:
                this_m = p.transpose(-1, dim).reshape(p.shape[0], -1, size)  # parameter matrix M (batch)
                this_g = grad.transpose(-1, dim).reshape(p.shape[0], -1, size)  # M_grad (batch)
            else:
                this_m = p.transpose(-1, dim).reshape(-1, size)  # parameter matrix M
                this_g = grad.transpose(-1, dim).reshape(-1, size)  # M_grad
            # could perhaps accumulate grad_cov less frequently; it's only
            # needed when we rediagonalize which is not that common.
            grad_cov.mul_(beta2).add_(torch.matmul(this_g.transpose(-1, -2), this_g))

    grad = self._project(grad, state, forward=True)
    grad = self._project(grad, state, batch, forward=True)

    exp_avg_sq = state["exp_avg_sq"]
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
@@ -583,10 +687,12 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
    grad = grad / denom

    # and project back..
    grad = self._project(grad, state, forward=False)
    grad = self._project(grad, state, batch, forward=False)

    scalar_exp_avg_sq = state["scalar_exp_avg_sq"]
    scalar_exp_avg_sq.mul_(beta2).add_((grad**2).mean(),
    scalar_exp_avg_sq.mul_(beta2).add_(_mean(grad**2,
                                             exclude_dims=[0] if batch else [],
                                             keepdim=batch),
                                       alpha=1-beta2)

@@ -595,12 +701,15 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
    # use the same bias_correction2.
    scalar_exp_avg_sq = scalar_exp_avg_sq * (1.0 / bias_correction2)

    denom = scalar_exp_avg_sq.sqrt() + eps  # scalar.
    denom = scalar_exp_avg_sq.sqrt() + eps  # scalar, or [batch_size,1,1..]

    alpha = state["param_rms"] * (1-beta1) * -lr / denom

    delta = state["delta"]
    delta.add_(grad, alpha=alpha)
    if batch:
        delta.add_(grad * alpha)
    else:
        delta.add_(grad, alpha=alpha)
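
Putting the pieces of this function together, a condensed non-batch sketch of the update; the bias correction and the `project` argument (standing in for self._project) are assumptions, since parts of the function fall outside this hunk.

import torch

def core_step_sketch(grad, state, project, lr, beta1, beta2, eps=1.0e-8):
    # project the gradient into the diagonalized co-ordinates
    g = project(grad, forward=True)
    state["exp_avg_sq"].mul_(beta2).addcmul_(g, g, value=1 - beta2)
    # assumed bias correction, counted from the step where exp_avg_sq was last zeroed
    bias_correction2 = 1 - beta2 ** (state["step"] + 1 - state["zero_step"])
    denom = (state["exp_avg_sq"] / bias_correction2).sqrt() + eps
    g = project(g / denom, forward=False)   # Adam-normalized step, back in canonical co-ordinates
    state["scalar_exp_avg_sq"].mul_(beta2).add_((g ** 2).mean(), alpha=1 - beta2)
    denom = (state["scalar_exp_avg_sq"] / bias_correction2).sqrt() + eps
    # scale the whole step by the parameter rms, then fold it into the momentum buffer
    alpha = state["param_rms"] * (1 - beta1) * -lr / denom
    state["delta"].add_(g * alpha)
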
def _step_scalar(self,
@@ -611,19 +720,20 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                 lr: float,
                 p: Tensor,
                 grad: Tensor,
                 state: dict):
                 state: dict,
                 batch: bool):
    """
    A form of the core update for scalar tensors, where we cannot get a good
    estimate of the parameter rms.
    """
    exp_avg_sq = state["exp_avg_sq"]
    exp_avg_sq = state["exp_avg_sq"]  # shape: (batch_size,) if batch else ()
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
                                    value=1-beta2)

    # bias_correction2 is like in Adam.  Don't bother with bias_correction1;
    # slower update at the start will help stability anyway.
    bias_correction2 = 1 - beta2 ** (state["step"] + 1)
    denom = (exp_avg_sq.sum() / (exp_avg_sq.numel() * bias_correction2)).sqrt() + eps
    denom = (exp_avg_sq / bias_correction2).sqrt() + eps
    alpha = -lr * (1-beta1) / denom

    delta = state["delta"]
@@ -633,6 +743,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
def _project(self,
             x: Tensor,
             state: dict,
             batch: bool,
             forward: bool) -> Tensor:
    """
    Multiply a tensor x by proj_{dim} on each of its dimensions.
@@ -641,6 +752,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of

    Args:
        x: The tensor to project, e.g. a gradient or a parameter change.
        state: dict to look up the projections
        batch: if True, the 1st dim of x is a batch dim and is not projected.
        forward: if True, go in the forward direction (from canonical to diagonalized
             co-ordinates); if False, the reverse.  They differ by a transpose,
             not an inverse, and do not make a round trip.
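
A hedged sketch of what this projection amounts to in the non-batch case, with Qs standing in for the state[f"Q_{dim}"] matrices; which multiplication is called "forward" is an assumption here:

import torch

def project_sketch(x, Qs, forward=True):
    # Qs: dict mapping dim -> Q of shape (size, size), indexed
    # [diagonalized_coordinate, canonical_coordinate]
    numel = x.numel()
    for dim, Q in Qs.items():
        if x.shape[dim] == 1 or x.shape[dim] == numel:
            continue
        P = Q if forward else Q.t()          # a transpose, not an inverse: no round trip
        # move `dim` last, apply P on that co-ordinate, move it back
        x = torch.matmul(x.transpose(dim, -1), P.t()).transpose(dim, -1)
    return x
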
@@ -648,7 +761,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
    numel = x.numel()
    for dim in range(x.ndim):
        size = x.shape[dim]
        if size == 1 or size == numel:
        if size == 1 or size == numel or (batch and dim == 0):
            continue
        Q = state[f"Q_{dim}"]
        if forward:
@@ -669,12 +782,13 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                      rank: int) -> Tensor:
    """
    Smooths and normalizes to mean=1 a tensor of diagonalized parameter rms values for one dimension
    of a tensor.  This will be used to construct a learning-rate matrix.
    of a tensor, or a batch of such diagonalized rms values.
    This will be used to construct a learning-rate matrix.

    Args:
      group: dict for configuration values
      rms: a tensor with one dimension, representing diagonalized parameter
           root-mean-square values.
      rms: Either a tensor with shape (size,) representing diagonalized parameter
           root-mean-square values, or a tensor of shape (batch_size, size).
      rank: the assumed rank of the covariance matrix from which rms was derived,
            used to decide how much smoothing to apply.  This is actually the
            minimal rank, if the parameter matrix were to stay fixed during training,
@@ -688,7 +802,10 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
    rms = rms ** param_pow
    smooth = (smooth0 +
              (smooth1 - smooth0) * size / (size + rank))
    mean = rms.mean()
    batch = (rms.ndim > 1)
    mean = _mean(rms,
                 exclude_dims=[0] if batch else [],
                 keepdim=True)
    rms += eps + smooth * mean
    new_mean = (eps + (smooth + 1) * mean)
    ans = rms / new_mean
@@ -697,12 +814,58 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
    # Apply max_rms
    max_rms = 5.0
    ans.clamp_(max=max_rms*2)
    ans /= ans.mean()
    ans /= _mean(ans, exclude_dims=[0] if batch else [],
                 keepdim=True)
    ans.clamp_(max=max_rms)
    ans /= ans.mean()
    ans /= _mean(ans, exclude_dims=[0] if batch else [],
                 keepdim=True)
    return ans
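
A tiny worked example of the core smoothing above (non-batch, omitting the max_rms clamping), with made-up values for smooth0, smooth1 and param_pow:

import torch

rms = torch.tensor([2.0, 1.0, 0.5, 0.1])
size, rank = rms.numel(), 16
smooth0, smooth1, param_pow, eps = 0.4, 0.2, 1.0, 1.0e-20
smooth = smooth0 + (smooth1 - smooth0) * size / (size + rank)
rms = rms ** param_pow
mean = rms.mean()
ans = (rms + eps + smooth * mean) / (eps + (smooth + 1) * mean)
print(ans, ans.mean())   # smoothed values, normalized so the mean is 1
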


def _diag(x: Tensor):
    """
    Like torch.diag(), but supports a batch dim, i.e. input of shape (B, M, M) returns
    output of shape (B, M).
    """
    if x.ndim == 3:
        (B, M, M2) = x.shape
        assert M == M2
        stride = x.stride()
        return x.as_strided(size=(B, M), stride=(stride[0], stride[1] + stride[2])).contiguous()
    else:
        return x.diag()
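
A quick sanity check of the batched branch: for input of shape (B, M, M) the strided view picks out each matrix's diagonal, matching torch.diagonal over the last two dims.

import torch

x = torch.randn(3, 4, 4)
assert torch.equal(_diag(x), torch.diagonal(x, dim1=-2, dim2=-1))
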

def _sum(x: Tensor,
         exclude_dims: List[int] = [],
         keepdim: bool = False):
    """
    Version of torch.sum where you specify dims to exclude, not dims to sum.
    """
    if len(exclude_dims) == 0:
        return x.sum(keepdim=keepdim)
    elif x.ndim == 1:
        return x  # if one dim is excluded, there are no dims to sum, and
                  # x.sum(dim=[]) sums all dims so we should not call sum().
    exclude_dims = [i if i >= 0 else i + x.ndim for i in exclude_dims]
    dims = [i for i in range(x.ndim) if i not in exclude_dims]
    return x.sum(dim=dims, keepdim=keepdim)

def _mean(x: Tensor,
          exclude_dims: List[int] = [],
          keepdim: bool = False):
    """
    Version of torch.mean where you specify dims to exclude, not dims to average over.
    """
    if len(exclude_dims) == 0:
        return x.mean(keepdim=keepdim)
    elif x.ndim == 1:
        return x  # if one dim is excluded, there are no dims to average over, and
                  # x.mean(dim=[]) averages all dims so we should not call mean().
    exclude_dims = [i if i >= 0 else i + x.ndim for i in exclude_dims]
    dims = [i for i in range(x.ndim) if i not in exclude_dims]
    return x.mean(dim=dims, keepdim=keepdim)
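
Example of how these helpers behave, which is how the batched code paths above keep the leading batch dim while reducing over everything else:

import torch

x = torch.arange(24.0).reshape(2, 3, 4)
print(_sum(x, exclude_dims=[0]).shape)                  # torch.Size([2])
print(_mean(x, exclude_dims=[0], keepdim=True).shape)   # torch.Size([2, 1, 1])
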