In the middle of reworking for new idea

This commit is contained in:
Daniel Povey 2022-07-06 13:35:19 -07:00
parent 41368f6b63
commit e9e2a85c95

View File

@ -48,7 +48,10 @@ class LearnedGradient(Optimizer):
learning the scale on the parameters (we'll keep it >= this size) learning the scale on the parameters (we'll keep it >= this size)
param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
learning the scale on the parameters (we'll keep it <= this size) learning the scale on the parameters (we'll keep it <= this size)
lr_est_period: The periodicity, in steps, with which we update the learning-rate matrix. lr_est_period: The periodicity, in steps, with which we update the learning-rate matrices.
orthogonalize_period: How often we re-orthogonalize the gradient covariance; this
gets multiplied by lr_est_period, i.e. we orthogonalize less frequently
than we update the learning-rate matrixes.
max_step_scale: Value that determines limit on how large steps can be per element, max_step_scale: Value that determines limit on how large steps can be per element,
intended to prevent instability during early steps before the learning-rate intended to prevent instability during early steps before the learning-rate
matrices are well learned. matrices are well learned.
@ -58,12 +61,14 @@ class LearnedGradient(Optimizer):
params, params,
lr=3e-02, lr=3e-02,
size_lr_scale=0.1, size_lr_scale=0.1,
meta_lr_scale=1.0, meta_lr_scale=0.1,
betas=(0.9, 0.98), betas=(0.9, 0.98),
eps=1.0e-08, eps=1.0e-08,
param_min_rms=1.0e-05, param_min_rms=1.0e-05,
param_max_rms=2.0, param_max_rms=2.0,
lr_est_period=10, lr_est_period=10,
max_lr_eig=4.0,
orthogonalize_period=4,
max_step_scale=2.0 max_step_scale=2.0
): ):
@ -77,6 +82,8 @@ class LearnedGradient(Optimizer):
param_min_rms=param_min_rms, param_min_rms=param_min_rms,
param_max_rms=param_max_rms, param_max_rms=param_max_rms,
lr_est_period=lr_est_period, lr_est_period=lr_est_period,
max_lr_eig=max_lr_eig,
orthogonalize_period=orthogonalize_period,
max_step_scale=max_step_scale, max_step_scale=max_step_scale,
) )
@ -104,16 +111,15 @@ class LearnedGradient(Optimizer):
meta_lr_scale = group["meta_lr_scale"] meta_lr_scale = group["meta_lr_scale"]
size_lr = lr * group["size_lr_scale"] size_lr = lr * group["size_lr_scale"]
beta1, beta2 = group["betas"] beta1, beta2 = group["betas"]
max_lr_eig = group["max_lr_eig"]
eps = group["eps"] eps = group["eps"]
size_update_period = 4 size_update_period = 4
param_min_eps = group["param_min_rms"] param_min_eps = group["param_min_rms"]
param_max_eps = group["param_max_rms"] param_max_eps = group["param_max_rms"]
lr_est_period = group["lr_est_period"] lr_est_period = group["lr_est_period"]
orthogonalize_period = group["orthogonalize_period"]
max_step_scale = group["max_step_scale"] max_step_scale = group["max_step_scale"]
small_tensor_cutoff = 5 # cutoff below which we use Adam
for p in group["params"]: for p in group["params"]:
if p.grad is None: if p.grad is None:
continue continue
@ -134,60 +140,69 @@ class LearnedGradient(Optimizer):
kwargs = {'device':p.device, 'dtype':p.dtype} kwargs = {'device':p.device, 'dtype':p.dtype}
if p.numel() > 1: # 'delta' implements conventional momentum. There are
# we learn the scale on the parameters in a way that # several different kinds of update going on, so rather than
# also emulates Eve. scale_exp_avg_sq is the # compute "exp_avg" like in Adam, we store and decay a
# moving-average square of the gradient w.r.t. this # parameter-change "delta", which combines all forms of
# scalar scale. # update. this is equivalent to how it's done in Adam,
state["scale_exp_avg_sq"] = torch.zeros((), **kwargs) # except for the first few steps.
state["scale_exp_avg"] = torch.zeros((), **kwargs) state["delta"] = torch.zeros_like(
# for speed, we do the scale update only periodically,
# and meanwhile we store the derivs w.r.t. the grad scale
# as a list
state["scale_grads"] = torch.zeros(size_update_period, **kwargs)
state["exp_avg"] = torch.zeros_like(
p, memory_format=torch.preserve_format p, memory_format=torch.preserve_format
) )
if p.numel() > small_tensor_cutoff: if p.numel() > 1:
# moving_grad is a moving average of the raw gradient. # "param_rms" just periodically records the scalar root-mean-square value of
# We do the accumulation on the input side because it # the parameter tensor.
# makes it easier to get grads w.r.t. the learning rate
# matrices.
# "scale" is just a per-tensor scalar that determines the
# rate of learning (in terms of the rms value of the change
# in parameter)
param_rms = (p**2).mean().sqrt().add_(eps) param_rms = (p**2).mean().sqrt().add_(eps)
state["param_rms"] = param_rms state["param_rms"] = param_rms
# scalar exp_avg_sq. not the same as scale_exp_avg_sq
state["exp_avg_sq_scalar"] = torch.zeros_like((), **kwargs)
for dim in range(p.ndim): for dim in range(p.ndim):
size = p.shape[dim] size = p.shape[dim]
if size == 1 or size == p.numel(): if size == 1:
# if size == p.numel(), is_one_axis == True above # if size == p.numel(), is_one_axis == True above
continue continue
state[f"lr_{dim}"] = torch.eye(size, **kwargs) # exp_avg_sq is in the projected space.
# Some parts of the update -- the parts where we are state["exp_avg_sq"] = torch.zeros_like(
# doing to update lr_{dim} will require grad. these p, memory_format=torch.preserve_format
# parts will be done inside torch.enable_grad(). )
# The rest of this function is inside no_grad(), so
# it will ignore the requires_grad == True.
state[f"lr_{dim}"].requires_grad = True # proj_{dim} will be the learning-rate matrix for this
# dimension, times an orthogonal matrix that we'll periodically
# re-estimate when we "reorthogonalize" (this orthogonalizes
# the gradient covariance).
state[f"proj_{dim}"] = torch.eye(size, **kwargs)
# proj_grad_{dim} is the gradient w.r.t. proj_{grad}
# (technically w.r.t a matrix to the left of
# proj_{grad}, we'll project it later); we accumulate
# this for `lr_est_period` steps, and then update
# proj_{dim} and zero this matrix
state[f"proj_grad_{dim}"] = torch.zeros(size, size, **kwargs)
# grad_cov_{dim} is the covariance of gradients on this axis (without
# any co-ordinate changes), treating all other axes as as a batch axis.
# This is needed when we re-orthogonalize, and to compute the
# grad_rms_{dim}. We store it as a decaying average, decaying with beta2
# grad_cov_{dim} is the covariance of gradients on this axis,
# treating all other axes as as a batch axis. This is needed
# only for purposes of computing the scalar factor on the # only for purposes of computing the scalar factor on the
# learning rate; and, as a result of this, also contributes something # learning rate; and, as a result of this, also contributes something
# to the gradient w.r.t. f"lr_{dim}", which is one reason # to the gradient w.r.t. f"lr_{dim}", which is one reason
# why we allocate a variable to keep track of its moving average # why we allocate a variable to keep track of its moving average
# instead of just using a temporary and smoothing the scalar factor. # instead of just using a temporary and smoothing the scalar factor.
state[f"grad_cov_{dim}"] = torch.zeros(size, size, **kwargs) state[f"grad_cov_{dim}"] = torch.zeros(size, size, **kwargs)
else:
# size <= small_tensor_cutoff: do a form of Adam # proj_grad_var_{dim} is the variance of gradients of the thing we're
state["exp_avg_sq"] = torch.zeros_like(p, # going to multiply proj_{dim} by (between it and the "underlying
mamory_format=preserve_memory_format) # parameter matrix"), aggregated over each of its 2 dimensions
# and decayed over time with beta1 each time we do a step on proj_{dim},
# e.g. every lr_est_period steps. It's a factored representation
# of something like exp_avg_sq, used for estimating proj_{dim}; we avoid
# storing an exp_avg_sq for proj_{dim}, to save GPU memory.
state[f"proj_grad_var_{dim}"] = torch.ones(2, size, **kwargs)
step = state["step"] step = state["step"]
@ -206,14 +221,14 @@ class LearnedGradient(Optimizer):
beta1, beta2, step, size_lr, beta1, beta2, step, size_lr,
param_min_rms, param_max_rms) param_min_rms, param_max_rms)
if numel <= small_tensor_cutoff: if numel == 1:
# For parameters with very few elements we just use a form # For parameters with very few elements we just use a form
# of Adam with a scale factor to reflect the overall # of Adam with a scale factor to reflect the overall
# parameter rms. # parameter rms.
self._step_small(beta1, beta2, eps, lr, p, grad, state) self._step_scalar(beta1, beta2, eps, lr, p, grad, state)
else: else:
if step % lr_est_period == 0: if step % lr_est_period == 0:
self._update_lrs(group, grad, state) self._update_lrs(group, p, state)
self._step(group, p, grad, state) self._step(group, p, grad, state)
state["step"] = step + 1 state["step"] = step + 1
@ -289,28 +304,129 @@ class LearnedGradient(Optimizer):
def _update_lrs(self, def _update_lrs(self,
group: dict, group: dict,
grad: Tensor, p: Tensor,
state: dict) -> None: state: dict) -> None:
""" """
Updates the learning-rate matrices state["lr_{dim}"]. Updates the learning-rate matrices state["lr_{dim}"].
Args: Args:
group: dict to look up configuration values group: dict to look up configuration values
grad: current gradient for the parameter that we are updating p: parameter matrix that we are updating. The learning rate matrices
are actually factors of p, so p itself will change when we change
them.
state: state dict for the current parameter state: state dict for the current parameter
""" """
# meta_lr is the learning rate for the learning rate matrix. # meta_lr is the learning rate for the learning rate matrix. Including the factor
meta_lr = group["lr"] * group["meta_lr_scale"] # (group["lr_est_period"] ** 0.5) is intended to make it approximately invariant
# to the lr_est_period (close to convergence).
meta_lr = group["lr"] * group["meta_lr_scale"] * (group["lr_est_period"] ** 0.5)
beta1, beta2 = group["betas"] beta1, beta2 = group["betas"]
eps = group["eps"] eps = group["eps"]
moving_g = state["exp_avg"] moving_g = state["exp_avg"]
ndim = p.ndim ndim = p.ndim
# Update grad_cov. for dim in range(ndim):
self._store_grad_stats(grad, state, beta2) size = p.shape[dim]
if size == 1:
continue
p_t = p.transpose(dim, -1)
this_p = p_t.reshape(-1, size)
# proj_grad is a summed gradient, indexed [grad_idx][param_idx] (don't worry
# if this notation about the indexing is unclear, just keep reading).
#
# Suppose M is our parameter matrix p reshaped to (size, -1) with the -1 being
# all other dims treated as a batch dim, and M_grad is the loss derivative w.r.t. M,
#
# Then proj_grad is the accumulated value of M_grad M^T, which will be of shape
# (size, size). If we view the parameter matrix as actually (proj M) with
# proj being numerically equal to I, i.e. as
# matmul(proj, M), then instead of the loss being tr(M_grad^T M), we can write
# it as tr(M_grad^T proj M), which can also be written as tr(M M_grad^T proj)
# = tr( (M_grad M^T)^T proj), where we can write M_grad M^T == proj_grad.
#
# Now, proj_grad is convenient to compute but it's not in a very normalized
# space, we want to normalize it first. We're going to view M as
# M = T N,
# where T is our proj_{dim} matrix (see the code), and N == T^{-1} M.
# It's more convenient to have the projection between T and N, i.e. have proj2,
# again numerically equal to I, so
# M == T proj2 N.
# [Be careful because when we use T as a superscript it means transpose).
# We want the derivative w.r.t. proj2.
# So suppose the loss is
# tr(M_grad^T M) == tr(M_grad^T T proj2 N) == tr(N M_grad^T T proj2)
# == tr( (T^T M_grad N^T)^T proj2 )
# == tr( (T^T M_grad N^T)^T proj2 )
# and using N == T^{-1} M, so N^T = M^T T^{-T},
# == tr( (T^T M_grad M^T T^{-T})^T proj2 )
# so we can view the thing in parentheses in the above expression as proj2_grad, i.e.
# proj2_grad = T^T M_grad M^T T^{-T}
# proj2_grad = T^T proj_grad T^{-T}
#
# note: p = state[f"proj_{dim}"] actually corresponds to T^T, since it has a
# shape interpretable as (diagonalized_index, canonical_index).
#
# So proj2_grad = p proj_grad p^{-1} (eq.1)
#
# After computing proj2_grad we compute/update a factorized representation of
# its variance, factorized over its 2 axes, and do an Adam-type update
# for it, except without momentum at this level (momentum will be applied
# later). This gives us proj2_delta, according to the
# equation: proj2 = I + proj2_delta.
#
# This is the point at which we control the magnitude of the parameter change.
#
# We now have to update the factor T using proj2, as in:
# T <-- T (I + proj2_delta)
# T <-- T + T proj2_delta
# and since p == T^T, this becomes:
# p <-- p + proj2_delta^T p
# or equivalently:
# p_delta = proj2_delta^T p (eq.2)
#
# We also need to update M itself; this is easiest done by computing proj_delta
# (which is the deviation of proj from I). The relationship between proj and proj2
# is given by equating
# proj T N = T proj2 N,
# proj T = T proj2,
# proj = T proj2 T^{-1}
# proj = p^T proj2 p^{-T} .
# we can apply the above formula directly to proj2_delta, since the "I" part just
# becomes a no-op, i.e.
# proj_delta = p^T proj2_delta p^{-T}.
# ... and to turn this back into a delta on M, we'll do:
# M_delta = proj_delta M.
proj_grad = state[f"proj_grad_{dim}"]
p = state[f"proj_{dim}"]
# torch.linalg.solve(A, B) returns X such that AX=B,
# meaning X = A^{-1} B. We want to compute X = proj_grad p^{-1},
# so X^T = p^{-T} proj_grad^T.
X = torch.linalg.solve(p.t(), proj_grad.t()).t()
assert torch.allclose(proj_grad, torch.matmul(X, p)) # TEMP.
# See (eq.1), the next line implements proj_grad = p proj_grad p^{-1}
proj2_grad = torch.matmul(p, X)
# for now just estimate it from proj2_grad, later we will try aggregating it
# across iterations. This is like exp_avg_sq in Adam.
proj2_grad_var = self._get_matrix_var(proj_grad, group["eps"])
denom = proj2_grad_var.sqrt()
alpha = -meta_lr * (1-beta1)
# we'll take into account alpha later on.
proj2_delta = proj2_grad_var / denom
p_delta = torch.matmul(proj2_delta.t(), p)
with torch.enable_grad():
# To update the learning rate matrices, we are asking the question, "What if, # To update the learning rate matrices, we are asking the question, "What if,
# on the previous step, we had used a slightly different learning-rate matrix? # on the previous step, we had used a slightly different learning-rate matrix?
@ -345,7 +461,7 @@ class LearnedGradient(Optimizer):
# [something] as it contributes no gradient, becomes -alpha * # [something] as it contributes no gradient, becomes -alpha *
# (moving_d_norm * grad).sum(), and ignoring the scale alpha, that # (moving_d_norm * grad).sum(), and ignoring the scale alpha, that
# is -(moving_d_norm * grad).sum(), so we call this neg_loss as it's the # is -(moving_d_norm * grad).sum(), so we call this neg_loss as it's the
# negative of hte iss. # negative of the loss.
neg_loss = (grad * moving_d_norm).sum() neg_loss = (grad * moving_d_norm).sum()
# after the following, there will grads in state[f"lr_{dim}"] # after the following, there will grads in state[f"lr_{dim}"]
@ -357,6 +473,32 @@ class LearnedGradient(Optimizer):
if p.shape[dim] != 1: if p.shape[dim] != 1:
self._update_lr(state[f"lr_{dim}"], meta_lr) self._update_lr(state[f"lr_{dim}"], meta_lr)
def _get_matrix_var(self,
x: Tensor,
eps: float) -> Tensor:
"""
Estimate a variance of elements of x, which must be a tensor with exactly
2 diimensions, neither of which can have size == 1. This is done using
a structured representation, so the variance is a product over the 2
dimensions (otherwise, it wouldn't be possible to estimate a variance from
a single sample). Eventually we may do this via aggregation across
time as well; this approach seems to give a high variance estimate even
for largish dimension, e.g. if x.shape == (100,100) the 95th percentile
is from about 0.6 to 1.6.
Args:
x: tensor to estimate the variance of elements of. Require x.ndim == 2,
neither size must be 1.
eps: Epsilon to avoid zero variance: epsilon is dimensionally a rms value.
"""
assert x.ndim == 2 and 1 not in list(x.shape) # TODO: remove?
var0 = (x**2).mean(dim=1, keepdim=True) + eps*eps
var1 = ((x / var0.sqrt()) ** 2).mean(dim=0, keepdim=True) + eps*eps
# esetimate var0 one more time.
var0 = ((x / var1.sqrt()) ** 2).mean(dim=1, keepdim=True) + eps*eps
return var1 * var0
def _update_lr(self, def _update_lr(self,
@ -421,7 +563,7 @@ class LearnedGradient(Optimizer):
state: dict): state: dict):
""" """
This function does the core update of self._step, in the case where the tensor This function does the core update of self._step, in the case where the tensor
has more than about 5 elements. It multiplies the moving-average gradient tensor by a has more than 1 elements. It multiplies the moving-average gradient tensor by a
state["lr_{dim}"] on each non-trivial axis/dimension, does a length normalization, state["lr_{dim}"] on each non-trivial axis/dimension, does a length normalization,
and takes a step in that direction. and takes a step in that direction.
@ -434,55 +576,60 @@ class LearnedGradient(Optimizer):
This function modifies p. This function modifies p.
""" """
lr = group["lr"] lr = group["lr"]
beta1 = group["betas"][0] beta1, beta2 = group["betas"]
eps = group["eps"] eps = group["eps"]
max_step_scale = group["max_step_scale"]
moving_g = state["exp_avg"]
# include the current grad in moving_g
moving_g.mul_(beta1).add(grad, alpha=(1-beta1))
# get the parameter delta (up to a scalar factor) by multiplying by a for dim in range(grad.ndim):
# learning-rate matrix for each dimension. size = grad.shape[size]
moving_d = self._multiply_by_lr(moving_g) # accumulate some stats for learning the projections
proj_grad = state[f"proj_grad_{dim}"]
grad_cov = state[f"grad_cov_{dim}"]
this_p = p.transpose(-1, dim).reshape(-1, size)
this_g = g.transpose(-1, dim).reshape(-1, size)
proj_grad.add_(torch.matmul(this_g.t(), this_p))
# could perhaps accumulate grad_cov less frequently.
grad_cov.mul_(beta2).add_(torch.matmul(this_g.t(), this_g))
# moving_d contains a moving average of parameter change contributions grad = self._project(grad, forward=True, state)
# from multiple iterations with weights: (1-beta) * (1 + beta + beta^2 +
# beta^3, etc.), which sum to 1 for large step. Suppose the rms value
# of each individual gradient were r. Assuming the terms are
# independent, with rms=r, the variance of the elements of moving_d
# would be r^2 * (1-beta)^2 * (1 + beta^2 + beta^4 + ... ) which equals
# r^2 * (1-beta)^2 / (1-beta^2), so the measured rms of moving_d would
# be moving_d_rms = individual_d_rms * (1-beta) / sqrt(1-beta^2), so
# individual_d_rms = moving_d_rms * sqrt(1-beta^2) / (1-beta) ... this
# would be the root-mean-square value of the individual deltas, if we
# were to proces them individually (if they were independent), and this
# is what we want to normalize by for greater independence of the beta
# value, i.e. so beta only affects momentum and not the overall speed
moving_d_rms = (moving_d**2).mean().sqrt() + eps
# e.g. if beta1==0.98, this formula gives a factor of 9.95, reflecting exp_avg_sq = state["exp_avg_sq"]
# that the individual rms's of the un-scaled parameter deltas without exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
# momentum are larger than their moving average. see the justification value=(1-beta2))
# in a comment above.
d_rms = moving_d_rms * (((1 - beta1 * beta1) ** 0.5) / (1 - beta1))
param_rms = state["param_rms"]
alpha = -lr * param_rms / d_rms
# apply max_step_scale to limit magnitude of change on individual step = state["step"]
# elements, this should help stability. bias_correction2 = 1 - beta2 ** (step + 1)
limit = moving_d_rms * max_step_scale if bias_correction2 < 0.99:
moving_d_clamped = moving_d.clamp_(min=-limit, max=limit) # note: not in-place.
exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
denom = exp_avg_sq.sqrt() + eps
if random.random() < 0.001: grad = grad / denom
clamp_diff_rms = ((moving_d_clamped - moving_d) ** 2).mean().sqrt()
rel_clamp_diff = clamp_diff_rms / d_rms
logging.info(f"Clamping difference is {clamp_diff_rms.item():.3g} vs. {d_rms.item():.3g}: "
f"relative diff {rel_clamp_diff.item():.3g}, shape={tuple(p.shape)}")
p.add_(moving_d_clamped, alpha=alpha) # and project back..
grad = self._project(grad, forward=False, state)
def _step_small(self,
scalar_exp_avg_sq = state["scalar_exp_avg_sq"]
scalar_exp_avg_sq.mul_(beta2).add_((grad**2).mean(),
alpha=1-beta2)
if bias_correction2 < 0.99:
# note: not in-place.
scalar_exp_avg_sq = scalar_exp_avg_sq * (1.0 / bias_correction2)
denom = scalar_exp_avg_sq.sqrt() + eps # scalar.
alpha = state["param_rms"] * (1-beta1) * lr / denom
delta = state["delta"]
delta.mul_(beta1).add_(grad, alpha=alpha)
param.add_(delta)
def _step_scalar(self,
beta1: float, beta1: float,
beta2: float, beta2: float,
eps: float, eps: float,
@ -491,11 +638,11 @@ class LearnedGradient(Optimizer):
grad: Tensor, grad: Tensor,
state: dict): state: dict):
""" """
A form of the core update for tensors with a small number of elements, e.g. <5. This is A form of the core update for tensors with a small number of elements, e.g. scalars. This is
Adam where, if the numel() > 1, the learning rate is proportional to the parameter rms value. Adam where, if the numel() > 1, the learning rate is proportional to the parameter rms value.
""" """
exp_avg = state["exp_avg"] exp_avg = state["exp_avg"]
exp_avg_sq = state["exp_avg_sq"] exp_avg_sq = state["scalar_exp_avg_sq"]
exp_avg.mul_(beta1).add_(grad, alpha=1-beta1) exp_avg.mul_(beta1).add_(grad, alpha=1-beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
value=1-beta2) value=1-beta2)
@ -505,11 +652,40 @@ class LearnedGradient(Optimizer):
bias_correction2 = 1 - beta2 ** (state["step"] + 1) bias_correction2 = 1 - beta2 ** (state["step"] + 1)
denom = (exp_avg_sq.sum() / (exp_avg_sq.numel() * bias_correction2)).sqrt() + eps denom = (exp_avg_sq.sum() / (exp_avg_sq.numel() * bias_correction2)).sqrt() + eps
alpha = -lr / denom alpha = -lr / denom
if p.numel() > 1: # for now not supporting p.numel() > 1.
alpha = alpha * state["param_rms"] #if p.numel() > 1:
# alpha = alpha * state["param_rms"]
p.add_(exp_avg, alpha=alpha) p.add_(exp_avg, alpha=alpha)
def _project(self,
x: Tensor,
forward: bool,
state: dict) -> Tensor:
"""
Multiply a tensor x by proj_{dim} on each of its dimensions.
If forward == True, converts x from canonical to preconditioned/diagonalized
co-ordinates; if forward == False, does the reverse.
Args:
x: The tensor to project, e.g. a gradient or a parameter change.
forward: if True, go in the forward directin (from canonical to diagnonalized
co-ordinates; if False, the reverse. They differ by a transpose,
not an inverse, and do not make a round trip.
"""
for dim in range(x.ndim):
if x.shape[dim] == 1:
continue
proj = state[f"proj_{dim}"]
if forward:
# dimension 0 of `proj` is diagonalized co-ordinate.
proj = proj.t()
# TODO: could possibly somehow force the output format to be unchanged.
x = x.transpose(-1, dim)
x = torch.matmul(x, proj)
x = x.transpose(-1, dim)
return x
def _store_grad_stats(self, def _store_grad_stats(self,
grad: Tensor, grad: Tensor,
state: dict, state: dict,