Committing partial work..

Daniel Povey 2022-07-10 15:46:32 -07:00
parent d25df4af5e
commit 27da50a1f6


@@ -119,8 +119,10 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                loss = closure()

        for group in self.param_groups:
            do_batch = False
            if len([p for p in group["params"] if p.grad is None]) != 0:
                self._init_group_state(group)

            for p in group["params"]:
                if p.grad is None:
@@ -136,77 +138,136 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                # State initialization
                if len(state) == 0:
                    self._init_state(group, p, state)

                if not do_batch:
                    self._step_one_param(group, p, state, batch=False)

            if do_batch:
                batches = self._get_batches(group)

        return loss
    def _get_batches(self,
                     group: dict):
        """
        Arrange the parameters in the group into batches, and return a "fake" version of
        group["params"] that contains the parameters stacked into batches, arranged by
        size.
        """
        is_initialized = False
        if "batches" not in group:
            self._init_batches(group)

    def _init_batches(self,
                      group: dict):
        """
        Arrange the parameters in this parameter group by dtype and shape, and group
        them into batches; this involves reallocating the tensors that are members
        of self.state[p] for each param, so that there can be an underlying group
        """
        batches = []
        for (dtype, shape), params in itertools.groupby(group["params"],
                                                        lambda p: (p.dtype, p.shape)):
            # (dtype, shape) is the "key" of the grouping, "params" is the list
            # of parameters that share the dtype and shape.  group["params"] is
            # a list of all the parameters in this group.
            #param_batches =
            pass
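
# Illustrative sketch (not part of this commit): how parameters could be grouped into
# per-(dtype, shape) batches.  Note that itertools.groupby only merges *consecutive*
# items with equal keys, so the parameter list must be sorted by the same key first.
# The helper name `group_params_by_shape` is hypothetical.
import itertools
import torch

def group_params_by_shape(params):
    # Sort so that groupby sees equal (dtype, shape) keys consecutively.
    key = lambda p: (str(p.dtype), tuple(p.shape))
    batches = []
    for _, ps in itertools.groupby(sorted(params, key=key), key):
        batches.append(list(ps))
    return batches

params = [torch.nn.Parameter(torch.randn(4, 8)),
          torch.nn.Parameter(torch.randn(4, 8)),
          torch.nn.Parameter(torch.randn(16))]
for ps in group_params_by_shape(params):
    stacked = torch.stack([p.detach() for p in ps])   # (batch_size, *shape)
    print(stacked.shape)   # e.g. torch.Size([2, 4, 8]) then torch.Size([1, 16])
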
    def _init_group_state(self,
                          group: dict,
                          batch: bool):
        """
        Initializes the state dict for parameter group `group`.  Expects all elements of the state
        to be uninitialized, i.e. you cannot have gradients for some parameters now and other
        parameters later; it must be all or nothing.  This is also a requirement for DDP,
        so it's not much of an inconvenience.

        Args:
           group: Parameter group (a dict; its "params" entry is the list of nn.Parameter)
                to initialize state for.
           batch: if True, we will arrange parameters into batches for more efficient
                optimization.
        """
        eps = group["eps"]
        size_update_period = group["size_update_period"]

        for p in group["params"]:
            assert p.grad is not None, "All parameters must have grads, or none."
            grad = p.grad
            if grad.is_sparse:
                raise RuntimeError(
                    "PrAdam optimizer does not support sparse gradients"
                )
            state = self.state[p]
            state["step"] = 0

            kwargs = {'device': p.device, 'dtype': p.dtype}

            # 'delta' implements conventional momentum.  There are
            # several different kinds of update going on, so rather than
            # compute "exp_avg" like in Adam, we store and decay a
            # parameter-change "delta", which combines all forms of
            # update.  This is equivalent to how it's done in Adam,
            # except for the first few steps.
            state["delta"] = torch.zeros_like(
                p, memory_format=torch.preserve_format
            )

            if p.numel() > 1:
                # "param_rms" just periodically records the scalar root-mean-square value of
                # the parameter tensor.
                param_rms = (p**2).mean().sqrt().add_(eps)
                state["param_rms"] = param_rms

                state["scale_exp_avg_sq"] = torch.zeros((), **kwargs)
                # scalar exp_avg_sq.  Not the same as scale_exp_avg_sq; this is used to
                # determine the magnitude of the regular step.
                state["scalar_exp_avg_sq"] = torch.zeros((), **kwargs)
                state["scale_grads"] = torch.zeros(size_update_period,
                                                   **kwargs)

            # exp_avg_sq is in the projected space.  It is periodically zeroed, and we
            # set zero_step.
            state["exp_avg_sq"] = torch.zeros_like(
                p, memory_format=torch.preserve_format
            )
            state["zero_step"] = 0

            for dim in range(p.ndim):
                size = p.shape[dim]
                if size == 1 or size == p.numel():
                    continue

                # Q_{dim} is the learning-rate matrix for this
                # dimension, a matrix indexed [diagonalized_coordinate, canonical_coordinate].
                # It is used twice in the update step, once transposed and once without transpose.
                state[f"Q_{dim}"] = torch.eye(size, **kwargs)

                # param_cov_{dim} is the averaged-over-time covariance of the parameter on this
                # dimension, treating all other dims as a batch axis.
                state[f"param_cov_{dim}"] = torch.zeros(size, size, **kwargs)

                # grad_cov_{dim} is the covariance of gradients on this axis (without
                # any co-ordinate changes), treating all other axes as a batch axis.
                # This is needed when we re-diagonalize, and to compute the
                # grad_rms_{dim}.  We store it as a decaying average, decaying with beta2,
                # only for purposes of computing the scalar factor on the
                # learning rate; and, as a result of this, it also contributes something
                # to the gradient w.r.t. f"Q_{dim}", which is one reason
                # why we allocate a variable to keep track of its moving average
                # instead of just using a temporary and smoothing the scalar factor.
                state[f"grad_cov_{dim}"] = torch.zeros(size, size, **kwargs)
    def _step_one_param(self,
                        group: dict,
                        p: Tensor,
                        state: dict,
                        batch: bool):
        lr = group["lr"]
        size_lr = lr * group["size_lr_scale"]
        beta1, beta2 = group["betas"]
@@ -221,33 +282,37 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
        step = state["step"]
        delta = state["delta"]
        delta.mul_(beta1)
        numel = p.numel() // (p.shape[0] if batch else 1)
        if numel > 1:
            # Update the size/scale of p, and set param_rms
            scale_grads = state["scale_grads"]
            scale_grads[step % size_update_period] = _sum(p * grad,
                                                          exclude_dims=[0] if batch else [])
            if step % size_update_period == size_update_period - 1:
                # this learns the overall scale on the parameter, by shrinking or
                # expanding it.
                param_rms = state["param_rms"]
                param_rms.copy_(_mean(p ** 2, exclude_dims=[0] if batch else []).sqrt().clamp_(min=eps))
                if step > 0:
                    self._size_update(p, state,
                                      scale_grads, param_rms,
                                      beta1, beta2, step, size_lr,
                                      param_min_rms, param_max_rms,
                                      batch)

        if numel == 1:
            # For parameters with very few elements we just use a form
            # of Adam with a scale factor to reflect the overall
            # parameter rms.  Updates delta.
            self._step_scalar(scalar_max, beta1, beta2,
                              eps, lr, p, grad, state,
                              batch)
        else:
            if step % lr_update_period == 0 and step > 0:
                self._accum_param_covs(group, p, state, batch)
                self._update_lrs(group, p, state, batch)
                self._zero_exp_avg_sq(state)
            self._step(group, p, grad, state, batch)

        p.add_(delta)
        state["step"] = step + 1
@@ -262,7 +327,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                     step: int,
                     size_lr: float,
                     param_min_rms: float,
                     param_max_rms: float,
                     batch: bool) -> None:
        """
        Called only where p.numel() > 1, this updates the scale of the parameter.
        If we imagine: p = underlying_param * scale.exp(), and we are doing
@@ -281,11 +347,15 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
           size_lr: The learning rate to apply to `scale`
           param_min_rms: User-specified minimum on the rms value of p, we will constrain it to >= this
           param_max_rms: User-specified maximum on the rms value of p, we will constrain it to <= this
           batch: If true, this is actually a batch of parameters stacked together,
               and the 1st dim is the batch dim.
        """
        if batch:
            batch_size, size_update_period = scale_grads.shape
        else:
            # scale_grads is a tensor of shape (size_update_period,) containing a list of
            # products (p * grad).sum() for the `size_update_period` most recent steps
            size_update_period, = scale_grads.shape

        # correct beta2 for the size update period: we will have
        # faster decay at this level.
@@ -293,7 +363,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
        scale_exp_avg_sq = state["scale_exp_avg_sq"]
        scale_exp_avg_sq.mul_(beta2_corr).add_(
            _mean(scale_grads ** 2, exclude_dims=[0] if batch else []),
            alpha=1-beta2_corr)

        # The 1st time we reach here is when size_step == 1.
@@ -305,7 +376,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
        eps = 1.0e-10  # this value is not critical.
        denom = scale_exp_avg_sq.sqrt() + eps

        scale_step = -size_lr * (bias_correction2 ** 0.5) * _sum(scale_grads,
                                                                 exclude_dims=[0] if batch else [],
                                                                 keepdim=True) / denom

        is_too_small = (param_rms < param_min_rms)
        is_too_large = (param_rms > param_max_rms)
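
# Illustrative sketch (not part of this commit): why (p * grad).sum() is the gradient
# used for the size update.  If p = u * s.exp() with a scalar log-scale s, then
# d(loss)/d(s) = (grad * dp/ds).sum() = (grad * p).sum().  Quick numerical check:
import torch

u = torch.randn(10, 20)
s = torch.zeros((), requires_grad=True)
p = u * s.exp()
loss = (p ** 3).sum()                        # any differentiable loss of p
loss.backward()
grad_p = 3 * (u * s.exp().detach()) ** 2     # d(loss)/dp evaluated at this p
print(torch.allclose(s.grad, (grad_p * p.detach()).sum()))   # True
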
@@ -320,7 +393,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
    def _accum_param_covs(self,
                          group: dict,
                          p: Tensor,
                          state: dict,
                          batch: bool) -> None:
        """
        Updates state[f"param_cov_{dim}"] for each dim of tensor p.
        """
@@ -329,20 +403,27 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
        step = state["step"]
        this_weight = (group["param_cov_freshness"] * lr_update_period /
                       (step + lr_update_period))

        for dim in range(p.ndim):
            size = p.shape[dim]
            if size == 1 or size == p.numel() or (batch and dim == 0):
                continue
            # M is the parameter matrix, of shape (-1, size) where the -1 covers
            # all other dimensions of the tensor.
            M_full = p.transpose(dim, -1)
            if batch:  # dim 0 is the batch dim
                M = M_full.reshape(p.shape[0], -1, size)
            else:
                M = M_full.reshape(-1, size)
            # will be a batched matrix multiply if batch == True
            this_param_cov = torch.matmul(M.transpose(-2, -1), M)
            # normalize scale of this param_cov, in case parameter scale
            # changes significantly during training, which would cause some
            # parts of the training timeline to be more highly weighted.
            this_param_cov /= _mean(_diag(this_param_cov),
                                    exclude_dims=[0] if batch else [],
                                    keepdim=True) + eps
            state[f"param_cov_{dim}"].mul_(1-this_weight).add_(this_param_cov,
                                                               alpha=this_weight)
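
# Illustrative sketch (not part of this commit): the per-dim covariance that gets
# accumulated above, for a single 3-D parameter and dim=1 (no batching).
import torch

p = torch.randn(4, 6, 8)
dim, size = 1, 6
M = p.transpose(dim, -1).reshape(-1, size)               # (4*8, 6): other dims flattened
this_param_cov = M.t() @ M                               # (6, 6) covariance over dim 1
this_param_cov /= this_param_cov.diag().mean() + 1e-20   # normalize overall scale
print(this_param_cov.shape, this_param_cov.diag().mean())  # torch.Size([6, 6]), ~1.0
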
@@ -359,7 +440,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
    def _update_lrs(self,
                    group: dict,
                    p: Tensor,
                    state: dict,
                    batch: bool) -> None:
        """
        Computes learning-rate matrices Q for each dim of this tensor.
        Args:
@@ -370,9 +452,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
          state: state dict for the current parameter
        """
        ndim = p.ndim
        numel = p.numel() // (p.shape[0] if batch else 1)
        if numel in (p[0].shape if batch else p.shape):
            return  # Nothing to do for this parameter matrix.  E.g. a bias or a scalar.

        scale_arr = [None] * ndim
@@ -382,21 +464,30 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
        eps = 1.0e-20
        cur_p = p + eps * torch.randn_like(p)

        for dim in range(ndim):
            size = p.shape[dim]
            if size == 1 or size == p.numel() or (batch and dim == 0):
                continue
            param_cov = state[f"param_cov_{dim}"]
            # if batch==True, U, S and V have an extra leading dim.
            U, S, V = _svd(param_cov)
            Q = state[f"Q_{dim}"]
            Q[:] = U.transpose(-1, -2)  # 3 dims if batch, else 2.
            M = cur_p.transpose(dim, -1)
            if batch:
                while U.ndim < M.ndim:
                    U = U.unsqueeze(1)
            N = torch.matmul(M, U)
            # cur_param_var is a diagonal parameter variance over dimension `dim`;
            # it will have shape something like [1, size, 1], or [batch_size, 1, size, 1]
            # if batch == True.
            cur_param_var = _mean(N**2,
                                  exclude_dims=[0, dim] if batch else [dim],
                                  keepdim=True)
            S = S.reshape(cur_param_var.shape)
            # OK, cur_param_var would have the same values as S if the variance stats
            # param_cov_{dim} were accumulated from this exact parameter matrix,
            # but actually they contain other versions of the parameter
            # covariance so they will, in general, be less extreme.  We scale
@@ -404,6 +495,12 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
            # ensure it doesn't have any too-small eigenvalues (where the
            # stats permit).
            scale = (S.clamp(min=eps) / cur_param_var.clamp(min=eps)).sqrt()
            if batch:
                batch_size = p.shape[0]
                scale = scale.reshape(batch_size, 1, size)
            else:
                scale = scale.reshape(size)
            N *= scale
            cur_p = N.transpose(dim, -1)
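
# Illustrative sketch (not part of this commit): the effect of Q[:] = U.transpose(-1,-2)
# above.  With param_cov symmetric PSD and U, S, V = svd(param_cov), the matrix Q = U.t()
# rotates into coordinates where param_cov is (approximately) diagonal.  `_svd` is not
# shown in this diff; torch.svd is assumed here as a stand-in.
import torch

size = 5
A = torch.randn(20, size)
param_cov = A.t() @ A                      # symmetric positive semi-definite
U, S, V = torch.svd(param_cov)
Q = U.t()
diagonalized = Q @ param_cov @ Q.t()       # ~diag(S) up to numerical error
off_diag = diagonalized - torch.diag(diagonalized.diag())
print(torch.allclose(diagonalized.diag(), S, atol=1e-3), off_diag.abs().max())
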
@@ -431,16 +528,17 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
        for i in range(4):  # for 4 iterations..
            for dim in range(ndim):
                size = p.shape[dim]
                if size == 1 or size == p.numel() or (batch and dim == 0):
                    continue
                M = cur_p.transpose(dim, -1)
                # correct for the fact that we have already normalized this dim in cur_p.
                if cur_scales[dim] is not None:
                    M *= cur_scales[dim]
                rms = _mean(M**2,
                            exclude_dims=[0, dim] if batch else [dim]).sqrt()
                rank = numel // size
                smoothed_rms = self._smooth_param_rms(group, rms, rank, batch)
                cur_scales[dim] = smoothed_rms
                M /= smoothed_rms  # normalize this dim..
@@ -525,12 +623,12 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
            Q[:] = torch.matmul(U.t(), Q)

    def _step(self,
              group: dict,
              p: Tensor,
              grad: Tensor,
              state: dict,
              batch: bool):
        """
        This function does the core update of self._step, in the case where the tensor
        has more than 1 element.  It multiplies the moving-average gradient tensor by a
@@ -542,6 +640,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
          p: The parameter to be updated
          grad: The grad of p
          state: The state-dict corresponding to parameter p
          batch: if true, the first dimension of p is a batch dim, i.e. a batch of
              parameter matrices.
        This function modifies p.
        """
@@ -557,17 +657,21 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
            # grad_cov_{dim}, which are for periodically updating the
            # learning-rate matrices.
            size = grad.shape[dim]
            if size == 1 or size == p.numel() or (batch and dim == 0):
                continue
            if step % grad_cov_period == 0 or step < grad_cov_period:
                grad_cov = state[f"grad_cov_{dim}"]  # shaped: (batch, size, size) if batch, else (size, size)
                if batch:
                    this_m = p.transpose(-1, dim).reshape(p.shape[0], -1, size)     # parameter matrix M (batch)
                    this_g = grad.transpose(-1, dim).reshape(p.shape[0], -1, size)  # M_grad (batch)
                else:
                    this_m = p.transpose(-1, dim).reshape(-1, size)     # parameter matrix M
                    this_g = grad.transpose(-1, dim).reshape(-1, size)  # M_grad
                # could perhaps accumulate grad_cov less frequently; it's only
                # needed when we rediagonalize, which is not that common.
                grad_cov.mul_(beta2).add_(torch.matmul(this_g.transpose(-2, -1), this_g))

        grad = self._project(grad, state, batch, forward=True)

        exp_avg_sq = state["exp_avg_sq"]
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
@@ -583,10 +687,12 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
            grad = grad / denom

        # and project back..
        grad = self._project(grad, state, batch, forward=False)

        scalar_exp_avg_sq = state["scalar_exp_avg_sq"]
        scalar_exp_avg_sq.mul_(beta2).add_(_mean(grad**2,
                                                 exclude_dims=[0] if batch else [],
                                                 keepdim=batch),
                                           alpha=1-beta2)
@@ -595,12 +701,15 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
        # use the same bias_correction2.
        scalar_exp_avg_sq = scalar_exp_avg_sq * (1.0 / bias_correction2)
        denom = scalar_exp_avg_sq.sqrt() + eps  # scalar, or [batch_size,1,1..]

        alpha = state["param_rms"] * (1-beta1) * -lr / denom
        delta = state["delta"]
        if batch:
            delta.add_(grad * alpha)
        else:
            delta.add_(grad, alpha=alpha)

    def _step_scalar(self,
@@ -611,19 +720,20 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                     lr: float,
                     p: Tensor,
                     grad: Tensor,
                     state: dict,
                     batch: bool):
        """
        A form of the core update for scalar tensors, where we cannot get a good
        estimate of the parameter rms.
        """
        exp_avg_sq = state["exp_avg_sq"]  # shape: (batch_size,) if batch
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
                                        value=1-beta2)

        # bias_correction2 is like in Adam.  Don't bother with bias_correction1;
        # slower update at the start will help stability anyway.
        bias_correction2 = 1 - beta2 ** (state["step"] + 1)
        denom = (exp_avg_sq / bias_correction2).sqrt() + eps
        alpha = -lr * (1-beta1) / denom
        delta = state["delta"]
@@ -633,6 +743,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
    def _project(self,
                 x: Tensor,
                 state: dict,
                 batch: bool,
                 forward: bool) -> Tensor:
        """
        Multiply a tensor x by proj_{dim} on each of its dimensions.
@@ -641,6 +752,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
        Args:
            x: The tensor to project, e.g. a gradient or a parameter change.
            state: dict to look up the projections
            batch: if True, the 1st dim of x is a batch dim and is not projected.
            forward: if True, go in the forward direction (from canonical to diagonalized
                co-ordinates); if False, the reverse.  They differ by a transpose,
                not an inverse, and do not make a round trip.
@@ -648,7 +761,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
        numel = x.numel()
        for dim in range(x.ndim):
            size = x.shape[dim]
            if size == 1 or size == numel or (batch and dim == 0):
                continue
            Q = state[f"Q_{dim}"]
            if forward:
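
# Illustrative sketch (not part of this commit; the body of this loop is outside the
# hunk shown): one plausible way to apply a (size, size) matrix Q along dimension `dim`
# of x, which is what the forward/backward projection amounts to.  `apply_along_dim`
# is a hypothetical helper.
import torch

def apply_along_dim(x: torch.Tensor, Q: torch.Tensor, dim: int, forward: bool) -> torch.Tensor:
    # Move `dim` last, right-multiply, move it back.  Forward and backward
    # directions differ only by a transpose of Q, not by an inverse.
    x = x.transpose(dim, -1)
    x = torch.matmul(x, Q.t() if forward else Q)
    return x.transpose(dim, -1)

x = torch.randn(3, 4, 5)
Q = torch.eye(4)                      # identity => projection is a no-op
print(torch.allclose(apply_along_dim(x, Q, dim=1, forward=True), x))   # True
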
@@ -669,12 +782,13 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                          rank: int) -> Tensor:
        """
        Smooths and normalizes to mean=1 a tensor of diagonalized parameter rms values for one dimension
        of a tensor, or a batch of such diagonalized rms values.
        This will be used to construct a learning-rate matrix.

        Args:
          group: dict for configuration values
          rms: Either a tensor with shape (size,) representing diagonalized parameter
              root-mean-square values, or a tensor of shape (batch_size, size).
          rank: the assumed rank of the covariance matrix from which rms was derived,
              used to decide how much smoothing to apply.  This is actually the
              minimal rank, if the parameter matrix were to stay fixed during training,
@@ -688,7 +802,10 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
        rms = rms ** param_pow
        smooth = (smooth0 +
                  (smooth1 - smooth0) * size / (size + rank))
        batch = (rms.ndim > 1)
        mean = _mean(rms,
                     exclude_dims=[0] if batch else [],
                     keepdim=True)
        rms += eps + smooth * mean
        new_mean = (eps + (smooth + 1) * mean)
        ans = rms / new_mean
@@ -697,12 +814,58 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
        # Apply max_rms
        max_rms = 5.0
        ans.clamp_(max=max_rms*2)
        ans /= _mean(ans, exclude_dims=[0] if batch else [],
                     keepdim=True)
        ans.clamp_(max=max_rms)
        ans /= _mean(ans, exclude_dims=[0] if batch else [],
                     keepdim=True)
        return ans
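
# Illustrative sketch (not part of this commit): the smoothing above on a toy rms
# vector, with made-up values smooth=0.5, param_pow=1.0, eps=1e-20, max_rms=5.0.
import torch

rms = torch.tensor([0.1, 1.0, 10.0])
smooth, eps, max_rms = 0.5, 1e-20, 5.0
mean = rms.mean()
rms = rms + eps + smooth * mean              # pull every entry toward the mean
ans = rms / (eps + (smooth + 1) * mean)      # renormalize so the mean is 1
ans = ans.clamp(max=max_rms * 2)
ans = ans / ans.mean()
ans = ans.clamp(max=max_rms)
ans = ans / ans.mean()
print(ans, ans.mean())                       # values pulled toward 1, mean ~= 1
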
def _diag(x: Tensor):
    """
    Like torch.diag(), but supports a batch dim: input of shape (B, M, M) gives
    output of shape (B, M).
    """
    if x.ndim == 3:
        (B, M, M2) = x.shape
        assert M == M2
        stride = x.stride()
        return x.as_strided(size=(B, M), stride=(stride[0], stride[1] + stride[2])).contiguous()
    else:
        return x.diag()
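
# Illustrative sketch (not part of this commit): assuming the `_diag` above, it should
# agree with torch.diagonal over the last two dims, which is an alternative way to get
# the batched diagonal.
import torch

x = torch.randn(3, 4, 4)
print(torch.equal(_diag(x), torch.diagonal(x, dim1=-2, dim2=-1)))   # True
print(torch.equal(_diag(x[0]), x[0].diag()))                        # True
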
def _sum(x: Tensor,
         exclude_dims: List[int] = [],
         keepdim: bool = False):
    """
    Version of torch.sum where you specify dims to exclude, not dims to sum over.
    """
    if len(exclude_dims) == 0:
        return x.sum(keepdim=keepdim)
    elif x.ndim == 1:
        return x  # if the one dim is excluded there are no dims left to sum, and
                  # x.sum(dim=[]) sums all dims, so we should not call sum().
    exclude_dims = [i if i >= 0 else i + x.ndim for i in exclude_dims]
    dims = [i for i in range(x.ndim) if i not in exclude_dims]
    return x.sum(dim=dims, keepdim=keepdim)
def _mean(x: Tensor,
          exclude_dims: List[int] = [],
          keepdim: bool = False):
    """
    Version of torch.mean where you specify dims to exclude, not dims to average over.
    """
    if len(exclude_dims) == 0:
        return x.mean(keepdim=keepdim)
    elif x.ndim == 1:
        return x  # if the one dim is excluded there are no dims left to average, and
                  # x.mean(dim=[]) averages all dims, so we should not call mean().
    exclude_dims = [i if i >= 0 else i + x.ndim for i in exclude_dims]
    dims = [i for i in range(x.ndim) if i not in exclude_dims]
    return x.mean(dim=dims, keepdim=keepdim)
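
# Illustrative sketch (not part of this commit): the behaviour of the exclude-dims
# helpers above, checked against explicit dim lists.
import torch

x = torch.randn(2, 3, 4)
print(torch.allclose(_sum(x, exclude_dims=[0]), x.sum(dim=(1, 2))))      # True
print(torch.allclose(_mean(x, exclude_dims=[0], keepdim=True),
                     x.mean(dim=(1, 2), keepdim=True)))                  # True
print(_mean(x, exclude_dims=[0], keepdim=True).shape)                    # torch.Size([2, 1, 1])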