Committing partial work.

Daniel Povey 2022-07-10 15:46:32 -07:00
parent d25df4af5e
commit 27da50a1f6


@@ -119,8 +119,10 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
loss = closure()
for group in self.param_groups:
eps = group["eps"]
size_update_period = group["size_update_period"]
do_batch = False
if len([p for p in group["params"] if p.grad is None]) != 0:
self._init_group_state(group)
for p in group["params"]:
if p.grad is None:
@@ -136,77 +138,136 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
# State initialization
if len(state) == 0:
state["step"] = 0
kwargs = {'device':p.device, 'dtype':p.dtype}
# 'delta' implements conventional momentum. There are
# several different kinds of update going on, so rather than
# compute "exp_avg" like in Adam, we store and decay a
# parameter-change "delta", which combines all forms of
# update. this is equivalent to how it's done in Adam,
# except for the first few steps.
state["delta"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
if p.numel() > 1:
# "param_rms" just periodically records the scalar root-mean-square value of
# the parameter tensor.
param_rms = (p**2).mean().sqrt().add_(eps)
state["param_rms"] = param_rms
state["scale_exp_avg_sq"] = torch.zeros((), **kwargs)
# scalar exp_avg_sq. not the same as scale_exp_avg_sq, this is used to
# determine the magnitude of the regular step
state["scalar_exp_avg_sq"] = torch.zeros((), **kwargs)
state["scale_grads"] = torch.zeros(size_update_period,
**kwargs)
# exp_avg_sq is in the projected space. It is periodically zeroed, and we
# set zero_step.
state["exp_avg_sq"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
state["zero_step"] = 0
for dim in range(p.ndim):
size = p.shape[dim]
if size == 1 or size == p.numel():
continue
# Q_{dim} is the learning-rate matrix for this
# dimension, a matrix indexed [diagonalized_coordinate, canonical_coordinate].
# this is used twice in the update step, once transposed and once without transpose.
state[f"Q_{dim}"] = torch.eye(size, **kwargs)
# param_cov_{dim} is the averaged-over-time covariance of the parameter on this dimension, treating
# all other dims as a batch axis.
state[f"param_cov_{dim}"] = torch.zeros(size, size, **kwargs)
# grad_cov_{dim} is the covariance of gradients on this axis (without
# any co-ordinate changes), treating all other axes as a batch axis.
# This is needed when we re-diagonalize, and to compute the
# grad_rms_{dim}. We store it as a decaying average, decaying with beta2
# only for purposes of computing the scalar factor on the
# learning rate; and, as a result of this, also contributes something
# to the gradient w.r.t. f"Q_{dim}", which is one reason
# why we allocate a variable to keep track of its moving average
# instead of just using a temporary and smoothing the scalar factor.
state[f"grad_cov_{dim}"] = torch.zeros(size, size, **kwargs)
self._step_one_param(group, p, state)
self._init_state(group, p, state)
if not do_batch:
self._step_one_param(group, p, state, batch=False)
if do_batch:
batches = self._get_batches(group)
return loss
def _get_batches(self,
group: dict):
"""
Arrange the parameters in the group into batches, and return a "fake" version of
group["params"] that contains the parameters stacked into batches, arranged by
size.
"""
is_initialized = False
if "batches" not in group:
self._init_batches(group)
def _init_batches(self,
group: dict):
"""
Arrange the parameters in this parameter group by dtype and shape, and group
them into batches; this involves reallocating the tensors that are members
of self.state[p] for each param, so that there can be an underlying group
"""
batches = []
for (dtype,shape), params in itertools.groupby(group["params"],
lambda p: (p.dtype, p.shape)):
# (dtype,shape) is the "key" of the grouping, "params" is the list
# of parameters that share the dtype and shape. group["params"] is
# a list of all the parameters in this group.
#param_batches =
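# Illustrative aside (hypothetical names, not this diff's final code): the
# batching idea is to stack parameters that share a dtype and shape along a new
# leading dim so one kernel call can update the whole batch.  Note that
# itertools.groupby only merges *consecutive* items, so the parameters are
# sorted by the grouping key first.
import itertools
import torch
params_demo = [torch.nn.Parameter(torch.randn(3, 4)) for _ in range(3)] + \
              [torch.nn.Parameter(torch.randn(7))]
sort_key = lambda p: (str(p.dtype), tuple(p.shape))
for key, ps in itertools.groupby(sorted(params_demo, key=sort_key), key=sort_key):
    stacked = torch.stack([p.detach() for p in ps])          # (batch_size, *shape)
# (end of aside)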
def _init_group_state(self,
group: dict,
batch: bool):
"""
Initializes state dict for parameter group `group`. Expects all elements of the state
to be uninitialized, i.e. you cannot have gradients for some parameters now and other
parameters later; it must be all or nothing. This is also a requirement for DDP,
so it's not much of an inconvenience.
Args:
group: Parameter group to initialize state for. A list of nn.Parameter.
batch: if True, we will arrange parameters into batches for more efficient
optimization.
"""
eps = group["eps"]
size_update_period = group["size_update_period"]
for p in group["params"]:
assert p.grad is not None, "All parameters must have grads, or none."
grad = p.grad
if grad.is_sparse:
raise RuntimeError(
"PrAdam optimizer does not support sparse gradients"
)
state = self.state[p]
state["step"] = 0
kwargs = {'device':p.device, 'dtype':p.dtype}
# 'delta' implements conventional momentum. There are
# several different kinds of update going on, so rather than
# compute "exp_avg" like in Adam, we store and decay a
# parameter-change "delta", which combines all forms of
# update. this is equivalent to how it's done in Adam,
# except for the first few steps.
state["delta"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
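# Illustrative aside (toy values, not part of the diff): how a decayed "delta"
# buffer plays the role of Adam's exp_avg.  Rather than keeping a moving
# average of gradients and multiplying by lr when applying the step, we keep a
# decayed sum of actual parameter changes and simply add it to p each step;
# after the first few steps the two formulations coincide.
import torch
p_demo = torch.zeros(3)
delta_demo = torch.zeros_like(p_demo)
beta1_demo, lr_demo = 0.9, 0.1
for g in [torch.ones(3), 0.5 * torch.ones(3)]:
    delta_demo.mul_(beta1_demo).add_(g, alpha=-lr_demo * (1 - beta1_demo))  # decayed parameter change
    p_demo.add_(delta_demo)  # all forms of update are folded into delta
# (end of aside)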
if p.numel() > 1:
# "param_rms" just periodically records the scalar root-mean-square value of
# the parameter tensor.
param_rms = (p**2).mean().sqrt().add_(eps)
state["param_rms"] = param_rms
state["scale_exp_avg_sq"] = torch.zeros((), **kwargs)
# scalar exp_avg_sq. not the same as scale_exp_avg_sq, this is used to
# determine the magnitude of the regular step
state["scalar_exp_avg_sq"] = torch.zeros((), **kwargs)
state["scale_grads"] = torch.zeros(size_update_period,
**kwargs)
# exp_avg_sq is in the projected space. It is periodically zeroed, and we
# set zero_step.
state["exp_avg_sq"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
state["zero_step"] = 0
for dim in range(p.ndim):
size = p.shape[dim]
if size == 1 or size == p.numel():
continue
# Q_{dim} is the learning-rate matrix for this
# dimension, a matrix indexed [diagonalized_coordinate, canonical_coordinate].
# this is used twice in the update step, once transposed and once without transpose.
state[f"Q_{dim}"] = torch.eye(size, **kwargs)
# param_cov_{dim} is the averaged-over-time covariance of the parameter on this dimension, treating
# all other dims as a batch axis.
state[f"param_cov_{dim}"] = torch.zeros(size, size, **kwargs)
# grad_cov_{dim} is the covariance of gradients on this axis (without
# any co-ordinate changes), treating all other axes as a batch axis.
# This is needed when we re-diagonalize, and to compute the
# grad_rms_{dim}. We store it as a decaying average, decaying with beta2
# only for purposes of computing the scalar factor on the
# learning rate; and, as a result of this, also contributes something
# to the gradient w.r.t. f"Q_{dim}", which is one reason
# why we allocate a variable to keep track of its moving average
# instead of just using a temporary and smoothing the scalar factor.
state[f"grad_cov_{dim}"] = torch.zeros(size, size, **kwargs)
def _step_one_param(self,
group: dict,
p: Tensor,
state: dict):
state: dict,
batch: bool):
lr = group["lr"]
size_lr = lr * group["size_lr_scale"]
beta1, beta2 = group["betas"]
@@ -221,33 +282,37 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
step = state["step"]
delta = state["delta"]
delta.mul_(beta1)
numel = p.numel()
numel = p.numel() // (p.shape[0] if batch else 1)
if numel > 1:
# Update the size/scale of p, and set param_rms
scale_grads = state["scale_grads"]
scale_grads[step % size_update_period] = (p * grad).sum()
scale_grads[step % size_update_period] = _sum(p * grad,
exclude_dims=[0] if batch else [])
if step % size_update_period == size_update_period - 1:
# this learns the overall scale on the parameter, by shrinking or
# expanding it.
param_rms = state["param_rms"]
param_rms.copy_((p ** 2).mean().sqrt().clamp_(min=eps))
param_rms.copy_(_mean(p ** 2, exclude_dims=[0] if batch else []).sqrt().clamp_(min=eps))
if step > 0:
self._size_update(p, state,
scale_grads, param_rms,
beta1, beta2, step, size_lr,
param_min_rms, param_max_rms)
param_min_rms, param_max_rms,
batch)
if numel == 1:
# For parameters with very few elements we just use a form
# of Adam with a scale factor to reflect the overall
# parameter rms. Updates delta.
self._step_scalar(scalar_max, beta1, beta2, eps, lr, p, grad, state)
self._step_scalar(scalar_max, beta1, beta2,
eps, lr, p, grad, state,
batch)
else:
if step % lr_update_period == 0 and step > 0:
self._accum_param_covs(group, p, state)
self._update_lrs(group, p, state)
self._accum_param_covs(group, p, state, batch)
self._update_lrs(group, p, state, batch)
self._zero_exp_avg_sq(state)
self._step(group, p, grad, state)
self._step(group, p, grad, state, batch)
p.add_(delta)
state["step"] = step + 1
@@ -262,7 +327,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
step: int,
size_lr: float,
param_min_rms: float,
param_max_rms: float) -> None:
param_max_rms: float,
batch: bool) -> None:
"""
Called only where p.numel() > 1, this updates the scale of the parameter.
If we imagine: p = underlying_param * scale.exp(), and we are doing
@@ -281,11 +347,15 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
size_lr: The learning rate to apply to `scale`
param_min_rms: User-specified minimum on the rms value of p, we will constrain it to >= this
param_max_rms: User-specified maximum on the rms value of p, we will constrain it to <= this
batch: If true, this is actually a batch of parameters stacked together,
and the 1st dim is the batch dim.
"""
# scale_grads is a tensor of shape (size_update_period,) containing a list of
# products (p * grad).sum() for the `size_update_period` most recent steps
size_update_period, = scale_grads.shape
if batch:
batch_size, size_update_period = scale_grads.shape
else:
# scale_grads is a tensor of shape (size_update_period,) containing a list of
# products (p * grad).sum() for the `size_update_period` most recent steps
size_update_period, = scale_grads.shape
# correct beta2 for the size update period: we will have
# faster decay at this level.
@@ -293,7 +363,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
scale_exp_avg_sq = state["scale_exp_avg_sq"]
scale_exp_avg_sq.mul_(beta2_corr).add_(
(scale_grads ** 2).mean(), alpha=1-beta2_corr)
_mean(scale_grads ** 2, exclude_dims=[0] if batch else []),
alpha=1-beta2_corr)
# The 1st time we reach here is when size_step == 1.
@@ -305,7 +376,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
eps = 1.0e-10 # this value is not critical.
denom = scale_exp_avg_sq.sqrt() + eps
scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum() / denom
scale_step = -size_lr * (bias_correction2 ** 0.5) * _sum(scale_grads,
exclude_dims=[0] if batch else [],
keepdim=True) / denom
is_too_small = (param_rms < param_min_rms)
is_too_large = (param_rms > param_max_rms)
@@ -320,7 +393,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
def _accum_param_covs(self,
group: dict,
p: Tensor,
state: dict) -> None:
state: dict,
batch: bool) -> None:
"""
Updates state[f"param_cov_{dim}"] for each dim of tensor p.
"""
@@ -329,20 +403,27 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
step = state["step"]
this_weight = (group["param_cov_freshness"] * lr_update_period /
(step + lr_update_period))
for dim in range(p.ndim):
size = p.shape[dim]
if size == 1 or size == p.numel():
if size == 1 or size == p.numel() or (batch and dim == 0):
continue
# M is the parameter matrix, of shape (-1, size) where the -1 covers
# all other dimensions of the tensor.
M_full = p.transpose(dim, -1)
M = M_full.reshape(-1, size)
if batch: # dim 0 is batch dim
M = M_full.reshape(p.shape[0], -1, size)
else:
M = M_full.reshape(-1, size)
# will be batched matrix multiply if batch==True
this_param_cov = torch.matmul(M.transpose(-2, -1), M)  # .t() would fail on the batched 3-D M
# normalize scale of this param_cov, in casre parameter scale
# normalize scale of this param_cov, in case parameter scale
# changes significantly during training, which would cause some
# parts of the training timeline to be more highly weighted.
this_param_cov /= (this_param_cov.diag().mean() + eps)
this_param_cov /= _mean(_diag(this_param_cov),
exclude_dims=[0] if batch else [],
keepdim=True) + eps
state[f"param_cov_{dim}"].mul_(1-this_weight).add_(this_param_cov,
alpha=this_weight)
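# Illustrative aside (made-up sizes, non-batched; not part of the diff): what
# this accumulation does for one dim of a 3-D parameter.  Every other dim is
# folded into rows, M^T M is formed, normalized by its mean diagonal so that
# parameter rescaling over training does not reweight the average, and then
# folded into the decaying estimate.
import torch
p_demo = torch.randn(4, 5, 6)
dim_demo, size_demo, eps_demo = 1, 5, 1.0e-8
M_demo = p_demo.transpose(dim_demo, -1).reshape(-1, size_demo)   # (4*6, 5)
cov_demo = torch.matmul(M_demo.t(), M_demo)
cov_demo = cov_demo / (cov_demo.diag().mean() + eps_demo)
param_cov_demo = torch.zeros(size_demo, size_demo)
weight_demo = 0.1
param_cov_demo.mul_(1 - weight_demo).add_(cov_demo, alpha=weight_demo)
# (end of aside)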
@@ -359,7 +440,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
def _update_lrs(self,
group: dict,
p: Tensor,
state: dict) -> None:
state: dict,
batch: bool) -> None:
"""
Computes learning-rate matrices Q for each dim of this tensor.
Args:
@@ -370,9 +452,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
state: state dict for the current parameter
"""
ndim = p.ndim
numel = p.numel()
numel = p.numel() // (p.shape[0] if batch else 1)
if numel in p.shape:
if numel in (p[0].shape if batch else p.shape):
return # Nothing to do for this parameter matrix. E.g. a bias or a scalar.
scale_arr = [None] * ndim
@@ -382,21 +464,30 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
eps = 1.0e-20
cur_p = p + eps * torch.randn_like(p)
dims_to_sum = list(range(ndim-1))
for dim in range(ndim):
size = p.shape[dim]
if size == 1 or size == p.numel():
if size == 1 or size == p.numel() or (batch and dim == 0):
continue
param_cov = state[f"param_cov_{dim}"]
# if batch==True, U, S and V have an extra leading dim.
U, S, V = _svd(param_cov)
Q = state[f"Q_{dim}"]
Q[:] = U.t()
Q[:] = U.transpose(-1,-2) # 3 dims if batch, else 2.
M = cur_p.transpose(dim, -1)
if batch:
while U.ndim < M.ndim:
U = U.unsqueeze(1)
N = torch.matmul(M, U)
cur_param_var = (N * N).mean(dim=dims_to_sum)
# OK, cur_param_var would be the same as S if the variance stats
# cur_param_var is a diagonal parameter variance over dimension `dim`,
# will have shape something like [1, size, 1]; or [batch_size, 1, size, 1]
# if batch == True.
cur_param_var = _mean(N**2,
exclude_dims=[0,dim] if batch else [dim],
keepdim=True)
S = S.reshape(cur_param_var.shape)
# OK, cur_param_var would have the same values as S if the variance stats
# param_cov_{dim} were accumulated from this exact parameter matrix,
# but actually they contain other versions of the parameter
# covariance so they will, in general, be less extreme. We scale
@@ -404,6 +495,12 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
# ensure it doesn't have any too-small eigenvalues (where the
# stats permit).
scale = (S.clamp(min=eps) / cur_param_var.clamp(min=eps)).sqrt()
if batch:
batch_size = p.shape[0]
scale = scale.reshape(batch_size, 1, size)
else:
scale = scale.reshape(size)
N *= scale
cur_p = N.transpose(dim, -1)
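# Illustrative aside (toy covariance, not part of the diff): the
# re-diagonalization step in miniature.  SVD the accumulated parameter
# covariance and use U^T as the basis change; in that basis the covariance is
# (approximately) diagonal, which is what makes per-coordinate learning rates
# meaningful.
import torch
size_demo = 4
M_demo = torch.randn(32, size_demo)
cov_demo = torch.matmul(M_demo.t(), M_demo) / 32
U_demo, S_demo, Vh_demo = torch.linalg.svd(cov_demo)
Q_demo = U_demo.t()                      # [diagonalized_coordinate, canonical_coordinate]
diag_cov = Q_demo @ cov_demo @ Q_demo.t()
assert torch.allclose(diag_cov, torch.diag(S_demo), atol=1e-4)
# (end of aside)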
@@ -431,16 +528,17 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
for i in range(4): # for 4 iterations..
for dim in range(ndim):
size = p.shape[dim]
if size == 1 or size == p.numel():
if size == 1 or size == p.numel() or (batch and dim == 0):
continue
M = cur_p.transpose(dim, -1)
# correct for the fact that we have already normalized this dim in cur_p.
if cur_scales[dim] is not None:
M *= cur_scales[dim]
rms = (M**2).mean(dim=dims_to_sum).sqrt()
rms = _mean(M**2,
exclude_dims=[0, dim] if batch else [dim]).sqrt()
rank = numel // size
smoothed_rms = self._smooth_param_rms(group, rms, rank)
smoothed_rms = self._smooth_param_rms(group, rms, rank, batch)
cur_scales[dim] = smoothed_rms
M /= smoothed_rms # normalize this dim..
@@ -525,12 +623,12 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
Q[:] = torch.matmul(U.t(), Q)
def _step(self,
group: dict,
p: Tensor,
grad: Tensor,
state: dict):
state: dict,
batch: bool):
"""
This function does the core update of self._step, in the case where the tensor
has more than one element. It multiplies the moving-average gradient tensor by a
@@ -542,6 +640,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
p: The parameter to be updated
grad: The grad of p
state: The state-dict corresponding to parameter p
batch: if true, the first dimension of p is a batch dim, i.e. a batch of
parameter matrices.
This function modifies p.
"""
@@ -557,17 +657,21 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
# grad_cov_{dim}, which are for periodically updating the
# learning-rate matrices.
size = grad.shape[dim]
if size == 1 or size == p.numel():
if size == 1 or size == p.numel() or (batch and dim == 0):
continue
if step % grad_cov_period == 0 or step < grad_cov_period:
grad_cov = state[f"grad_cov_{dim}"]
this_m = p.transpose(-1, dim).reshape(-1, size) # parameter matrix M
this_g = grad.transpose(-1, dim).reshape(-1, size) # M_grad
grad_cov = state[f"grad_cov_{dim}"] # shaped: (batch, size, size) if batch, else (size, size)
if batch:
this_m = p.transpose(-1, dim).reshape(p.shape[0], -1, size) # parameter matrix M (batch)
this_g = grad.transpose(-1, dim).reshape(p.shape[0], -1, size) # M_grad (batch)
else:
this_m = p.transpose(-1, dim).reshape(-1, size) # parameter matrix M
this_g = grad.transpose(-1, dim).reshape(-1, size) # M_grad
# could perhaps accumulate grad_cov less frequently; it's only
# needed when we rediagonalize which is not that common.
grad_cov.mul_(beta2).add_(torch.matmul(this_g.transpose(-2, -1), this_g))
grad = self._project(grad, state, forward=True)
grad = self._project(grad, state, batch, forward=True)
exp_avg_sq = state["exp_avg_sq"]
exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
@@ -583,10 +687,12 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
grad = grad / denom
# and project back..
grad = self._project(grad, state, forward=False)
grad = self._project(grad, state, batch, forward=False)
scalar_exp_avg_sq = state["scalar_exp_avg_sq"]
scalar_exp_avg_sq.mul_(beta2).add_((grad**2).mean(),
scalar_exp_avg_sq.mul_(beta2).add_(_mean(grad**2,
exclude_dims=[0] if batch else [],
keepdim=batch),
alpha=1-beta2)
@@ -595,12 +701,15 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
# use the same bias_correction2.
scalar_exp_avg_sq = scalar_exp_avg_sq * (1.0 / bias_correction2)
denom = scalar_exp_avg_sq.sqrt() + eps # scalar.
denom = scalar_exp_avg_sq.sqrt() + eps # scalar, or [batch_size,1,1..]
alpha = state["param_rms"] * (1-beta1) * -lr / denom
delta = state["delta"]
delta.add_(grad, alpha=alpha)
if batch:
delta.add_(grad * alpha)
else:
delta.add_(grad, alpha=alpha)
def _step_scalar(self,
@@ -611,19 +720,20 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
lr: float,
p: Tensor,
grad: Tensor,
state: dict):
state: dict,
batch: bool):
"""
A form of the core update for scalar tensors, where we cannot get a good
estimate of the parameter rms.
"""
exp_avg_sq = state["exp_avg_sq"]
exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,) if batch, else the (scalar) shape of p
exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
value=1-beta2)
# bias_correction2 is like in Adam. Don't bother with bias_correction1;
# slower update at the start will help stability anyway.
bias_correction2 = 1 - beta2 ** (state["step"] + 1)
denom = (exp_avg_sq.sum() / (exp_avg_sq.numel() * bias_correction2)).sqrt() + eps
denom = (exp_avg_sq / bias_correction2).sqrt() + eps
alpha = -lr * (1-beta1) / denom
delta = state["delta"]
@@ -633,6 +743,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
def _project(self,
x: Tensor,
state: dict,
batch: bool,
forward: bool) -> Tensor:
"""
Multiply a tensor x by Q_{dim} on each of its dimensions.
@@ -641,6 +752,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
Args:
x: The tensor to project, e.g. a gradient or a parameter change.
state: dict to look up the projections
batch: if True, the 1st dim of x is a batch dim and is not projected.
forward: if True, go in the forward direction (from canonical to diagonalized
co-ordinates); if False, the reverse. They differ by a transpose,
not an inverse, and do not make a round trip.
@@ -648,7 +761,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
numel = x.numel()
for dim in range(x.ndim):
size = x.shape[dim]
if size == 1 or size == numel:
if size == 1 or size == numel or (batch and dim == 0):
continue
Q = state[f"Q_{dim}"]
if forward:
@@ -669,12 +782,13 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
rank: int) -> Tensor:
"""
Smooths and normalizes to mean=1 a tensor of diagonalized parameter rms values for one dimension
of a tensor. This will be used to construct a learning-rate matrix.
of a tensor, or a batch of such diagonalized rms values.
This will be used to construct a learning-rate matrix.
Args:
group: dict for configuration values
rms: a tensor with one dimension, representing diagonalized parameter
root-mean-square values.
rms: Either a tensor with shape (size,) representing diagonalized parameter
root-mean-square values, or a tensor of shape (batch_size, size).
rank: the assumed rank of the covariance matrix from which rms was derived,
used to decide how much smoothing to apply. This is actually the
minimal rank, if the parameter matrix were to stay fixed during training,
@@ -688,7 +802,10 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
rms = rms ** param_pow
smooth = (smooth0 +
(smooth1 - smooth0) * size / (size + rank))
mean = rms.mean()
batch = (rms.ndim > 1)
mean = _mean(rms,
exclude_dims=[0] if batch else [],
keepdim=True)
rms += eps + smooth * mean
new_mean = (eps + (smooth + 1) * mean)
ans = rms / new_mean
@@ -697,12 +814,58 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
# Apply max_rms
max_rms = 5.0
ans.clamp_(max=max_rms*2)
ans /= ans.mean()
ans /= _mean(ans, exclude_dims=[0] if batch else [],
keepdim=True)
ans.clamp_(max=max_rms)
ans /= ans.mean()
ans /= _mean(ans, exclude_dims=[0] if batch else [],
keepdim=True)
return ans
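# Illustrative aside (toy numbers, not part of the diff): the effect of the
# smoothing above.  The per-coordinate rms is blended toward its own mean and
# renormalized, so the result keeps mean 1 but is less extreme than the raw
# ratios rms / rms.mean().
import torch
rms_demo = torch.tensor([0.1, 1.0, 4.0])
smooth_demo, eps_demo = 0.5, 1.0e-20
mean_demo = rms_demo.mean()
ans_demo = (rms_demo + eps_demo + smooth_demo * mean_demo) / (eps_demo + (smooth_demo + 1.0) * mean_demo)
# raw ratios ~= [0.06, 0.59, 2.35]; smoothed ~= [0.37, 0.73, 1.90], mean still 1.0
# (end of aside)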
def _diag(x: Tensor):
"""
like torch diag(), but supports batch dim, i.e. input of shape (B, M, M) returns
output of shape (B, M)
"""
if x.ndim == 3:
(B, M, M2) = x.shape
assert M == M2
stride = x.stride()
return x.as_strided(size=(B, M), stride=(stride[0], stride[1] + stride[2])).contiguous()
else:
return x.diag()
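# Usage check for _diag above (illustrative, not part of the diff): for a
# batched stack of square matrices it matches torch.diagonal over the last two
# dims, and falls back to Tensor.diag() for a single matrix.
import torch
x_demo = torch.randn(2, 3, 3)
assert torch.equal(_diag(x_demo), torch.diagonal(x_demo, dim1=-2, dim2=-1))
assert torch.equal(_diag(x_demo[0]), x_demo[0].diag())
# (end of aside)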
def _sum(x: Tensor,
exclude_dims: List[int] = [],
keepdim: bool = False):
"""
Version of torch.sum where you specify dims to exclude, not dims to sum.
"""
if len(exclude_dims) == 0:
dims = list(range(x.ndim))
return x.sum(dim=dims, keepdim=keepdim)
elif x.ndim == 1:
return x # if one dim is excluded, there are no dims to sum, and
# x.sum(dim=[]) sums all dims so we should not call sum().
exclude_dims = [i if i >= 0 else i + x.ndim for i in exclude_dims]
dims = [i for i in range(x.ndim) if i not in exclude_dims]
return x.sum(dim=dims, keepdim=keepdim)
def _mean(x: Tensor,
exclude_dims: List[int] = [],
keepdim: bool = False):
"""
Version of torch.mean where you specify dims to exclude, not dims to mean.
"""
if len(exclude_dims) == 0:
dims = list(range(x.ndim))
return x.mean(dim=dims, keepdim=keepdim)
elif x.ndim == 1:
return x # if one dim is excluded, there are no dims to mean, and
# x.mean(dim=[]) averages all dims so we should not call mean().
exclude_dims = [i if i >= 0 else i + x.ndim for i in exclude_dims]
dims = [i for i in range(x.ndim) if i not in exclude_dims]
return x.mean(dim=dims, keepdim=keepdim)
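# Usage check for the helpers above (illustrative, not part of the diff):
# excluding dim 0 reduces over everything else, which is the pattern used
# throughout the batch=True code paths.
import torch
x_demo = torch.randn(2, 3, 4)
m_demo = _mean(x_demo, exclude_dims=[0], keepdim=True)       # shape (2, 1, 1)
assert torch.allclose(m_demo.squeeze(), x_demo.reshape(2, -1).mean(dim=1))
s_demo = _sum(x_demo, exclude_dims=[-1])                     # reduces dims 0 and 1
assert torch.allclose(s_demo, x_demo.sum(dim=(0, 1)))
# (end of aside)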