Replace max_fullcov_size with max_block_size

Daniel Povey 2022-07-11 16:37:01 -07:00
parent 3468c3aa5a
commit 075a2e27d8


@@ -128,6 +128,10 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
size_update_period: The periodicity, in steps, with which we update the size (scale)
of the parameter tensor. This is provided to save a little time.
lr_update_period: The periodicity, in steps, with which we update the learning-rate matrices.
** This is important for the speed/optimization tradeoff. **
param_cov_period: The periodicity, in steps, with which we update the parameter covariance
stats.
max_block_size: The maximum block size in block-diagonal co-ordinate transformations.
"""
def __init__(
self,
@@ -147,7 +151,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
lr_update_period=200,
grad_cov_period=3,
param_cov_period=100,
max_fullcov_size=2048,
max_block_size=1024,
):
@@ -168,7 +172,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
lr_update_period=lr_update_period,
grad_cov_period=grad_cov_period,
param_cov_period=param_cov_period,
max_fullcov_size=max_fullcov_size,
max_block_size=max_block_size,
)
super(PrAdam, self).__init__(params, defaults)
@@ -226,7 +230,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
"""
eps = group["eps"]
size_update_period = group["size_update_period"]
max_fullcov_size = group["max_fullcov_size"]
max_block_size = group["max_block_size"]
state["step"] = 0
@@ -288,34 +292,51 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
if size == 1 or (size == numel and ignore_rank1_dims):
continue
if size <= max_fullcov_size:
# Q_{dim} is the learning-rate matrix for this
# dimension, a matrix indexed [diagonalized_coordinate, canonical_coordinate].
# This is used twice in the update step, once transposed and once without transpose.
state[f"Q_{dim}"] = torch.eye(size, **kwargs).unsqueeze(0).expand(
batch_size, size, size).contiguous()
# Most of the time, (num_blocks, block_size) will equal (1, size)
num_blocks, block_size = self._get_block_size(size, max_block_size)
# param_cov_{dim} is the averaged-over-time covariance of the parameters on this dimension, treating
# all other dims as a batch axis.
state[f"param_cov_{dim}"] = torch.zeros(batch_size, size, size, **kwargs)
# Q_{dim} is the learning-rate matrix for this
# dimension, a matrix indexed [batch_idx, block_idx, diagonalized_coordinate, canonical_coordinate].
# This is used twice in the update step, once transposed and once without transpose.
Q = torch.eye(block_size, block_size, **kwargs).unsqueeze(0).unsqueeze(0).expand(
batch_size, num_blocks, block_size, block_size).contiguous()
state[f"Q_{dim}"] = Q
# param_cov_{dim} is the averaged-over-time covariance of the parameters on this dimension, treating
# all other dims as a batch axis.
state[f"param_cov_{dim}"] = torch.zeros_like(Q)
# grad_cov_{dim} is the covariance of gradients on this axis (without
# any co-ordinate changes), treating all other axes as a batch axis.
# This is needed when we re-diagonalize, and to compute the
# grad_rms_{dim}. We store it as a decaying average, decaying with beta2
# grad_cov_{dim} is the covariance of gradients on this axis (without
# any co-ordinate changes), treating all other axes as a batch axis.
# This is needed when we re-diagonalize, and to compute the
# grad_rms_{dim}. We store it as a decaying average, decaying with beta2
# only for purposes of computing the scalar factor on the
# learning rate; and, as a result of this, also contributes something
# to the gradient w.r.t. f"Q_{dim}", which is one reason
# why we allocate a variable to keep track of its moving average
# instead of just using a temporary and smoothing the scalar factor.
state[f"grad_cov_{dim}"] = torch.zeros(batch_size, size, size, **kwargs)
else:
# diagonal-only Q and param_cov, no grad_cov needed because it is only
# needed for multiplying Q_{dim} by an orthogonal matrix, which is not
# applicable if Q_{dim} is diagonal.
state[f"Q_{dim}"] = torch.ones(batch_size, size, **kwargs)
state[f"param_cov_{dim}"] = torch.zeros(batch_size, size, **kwargs)
# only for purposes of computing the scalar factor on the
# learning rate; and, as a result of this, also contributes something
# to the gradient w.r.t. f"Q_{dim}", which is one reason
# why we allocate a variable to keep track of its moving average
# instead of just using a temporary and smoothing the scalar factor.
state[f"grad_cov_{dim}"] = torch.zeros_like(Q)
def _get_block_size(self, size: int, max_block_size: int) -> Tuple[int,int]:
"""
Returns information about the block size for a block-diagonal structure
of covariances and co-ordinate transformations.
size: The size of the parameter tensor on one of its dimensions, e.g.
a channel dim or kernel size
max_block_size: The maximum block size allowed (of the blocks in a
block-diagonal structure)
Returns:
(num_blocks, block_size)
where num_blocks * block_size == size and block_size <= max_block_size
"""
for num_blocks in range(1, size):
block_size = size // num_blocks
if block_size * num_blocks == size and block_size <= max_block_size:
return (num_blocks, block_size)
assert False, (size, max_block_size) # code error or e.g. negative or
# zero inputs
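
# A quick sanity check of the block-size search, as a sketch rather than part of
# the commit; get_block_size below is a hypothetical standalone copy of the same
# loop, and the sizes are made-up examples.
from typing import Tuple

def get_block_size(size: int, max_block_size: int) -> Tuple[int, int]:
    # Pick the smallest number of equal-sized blocks whose block size fits under the cap.
    for num_blocks in range(1, size):
        block_size = size // num_blocks
        if block_size * num_blocks == size and block_size <= max_block_size:
            return (num_blocks, block_size)
    raise ValueError((size, max_block_size))

assert get_block_size(100, 1024) == (1, 100)    # small dim: one full-size block
assert get_block_size(1536, 1024) == (2, 768)   # split into 2 blocks of 768
assert get_block_size(2048, 1024) == (2, 1024)  # exactly 2 maximal blocks
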
def _step_one_batch(self,
@@ -462,30 +483,23 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
except KeyError:
continue # e.g. size == 1 or size == numel
if param_cov.ndim == 3:
# param_cov shape: (batch_size, size, size)
# param_cov shape: (batch_size, num_blocks, block_size, block_size)
(batch_size, num_blocks, block_size, block_size) = param_cov.shape
# M is the parameter matrix, of shape (-1, size) where the -1 covers
# all other dimensions of the tensor.
M_full = p.transpose(dim, -1)
M = M_full.reshape(batch_size, -1, size)
# M: (batch_size, num_blocks, x, block_size)
M = p.transpose(dim, -1).reshape(batch_size, -1,
num_blocks, block_size).transpose(1, 2)
# will be batched matrix multiply. shape of this_param_cov: (batch_size, size, size)
this_param_cov = torch.matmul(M.transpose(1,2), M)
# normalize scale of this param_cov, in case parameter scale
# changes significantly during training, which would cause some
# parts of the training timeline to be more highly weighted.
# shape of this_param_cov
this_param_cov /= _mean(_diag(this_param_cov), # _diag(this_param_cov) has shape (batch_size, size)
exclude_dims=[0],
keepdim=True).unsqueeze(-1) + eps # shape: (batch_size, 1, 1)
else:
# this_param_cov dim: (batch_size, size)
this_param_cov = _mean(p**2,
exclude_dims=[0,dim])
this_param_cov /= _mean(this_param_cov,
exclude_dims=[0],
keepdim=True) + eps # shape: (batch_size, 1)
# will be batched matrix multiply. shape of this_param_cov:
# (batch_size, num_blocks, block_size, block_size)
this_param_cov = torch.matmul(M.transpose(2, 3), M)
# normalize scale of this param_cov, in case parameter scale
# changes significantly during training, which would cause some
# parts of the training timeline to be more highly weighted.
# shape of this_param_cov
# expression after /= has shape (batch_size, num_blocks, 1, 1)
this_param_cov /= _diag(this_param_cov).mean(dim=[0,1], keepdim=True).unsqueeze(-1) + eps
param_cov.mul_(1-this_weight).add_(this_param_cov,
alpha=this_weight)
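
# A small sketch, not part of the commit, checking that the blocked matmul above
# matches the per-block covariances you would get by slicing the flattened matrix
# into column blocks; the shapes are made up, and the covered dim is taken to be
# the last one so no transpose of dims is needed.
import torch

batch_size, other, size = 2, 5, 6
num_blocks, block_size = 2, 3      # what _get_block_size(6, 3) would return
p = torch.randn(batch_size, other, size)

# Blocked form, mirroring _update_param_cov for dim == last:
M = p.reshape(batch_size, other, num_blocks, block_size).transpose(1, 2)
blocked = torch.matmul(M.transpose(2, 3), M)   # (batch, num_blocks, block, block)

# Direct form: covariance of each column block of the flattened matrix.
for b in range(num_blocks):
    cols = p[..., b * block_size : (b + 1) * block_size]    # (batch, other, block)
    assert torch.allclose(blocked[:, b], torch.matmul(cols.transpose(1, 2), cols), atol=1e-5)
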
@@ -543,23 +557,30 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
assert size == 1 or size == numel, size
continue # e.g. size == 1 or size == numel:
if param_cov.ndim == 2:
# size > max_fullcov_size, we do not accumulate full covariance
# matrix for this dim and will use a diagonal Q.
# param_cov.shape is (batch_size, size).
Q[:] = 1.0
S = param_cov # (batch_size, size)
else:
# if batch==True, param_cov.shape == (batch_size, size, size), and
# U, S and V have an extra leading dim.
U, S, V = _svd(param_cov)
Q[:] = U.transpose(1, 2)
# param_cov has the same shape as Q
(batch_size, num_blocks, block_size, block_size) = Q.shape
# U,S,V have extra leading dimensions, i.e. of shape
# {U,V}.shape == (batch_size, num_blocks, block_size, block_size),
# S.shape == (batch_size, num_blocks, block_size).
U, S, V = _svd(param_cov)
Q[:] = U.transpose(2, 3)
M = cur_p.transpose(dim, -1) # (batch_size, x, y, z, size)
while U.ndim < M.ndim:
U = U.unsqueeze(1) # (batch_size, 1, 1, size, size)
M = torch.matmul(M, U) # (batch_size, x, y, z, size)
cur_p = M.transpose(dim, -1) # (batch_size, x, size, y, z)
M = cur_p.transpose(dim, -1)
# if p were of shape (batch_size, x, size, y, z),
# after the next line M will be of shape
# (batch_size, x, y, z, num_blocks, block_size)
M = M.reshape(batch_size, *M.shape[1:-1],
num_blocks, block_size)
M = _move_dim(M, -2, 1) # (batch_size, num_blocks, x, y, z, block_size)
while U.ndim < M.ndim:
U = U.unsqueeze(1)
# Now U is of shape (batch_size, num_blocks, 1, 1, block_size, block_size)
M = torch.matmul(M, U) # (batch_size, num_blocks, x, y, z, block_size)
M = _move_dim(M, 1, -2) # (batch_size, x, y, z, num_blocks, block_size)
M = M.reshape(*M.shape[:-2], size) # (batch_size, x, y, z, size)
cur_p = M.transpose(dim, -1) # (batch_size, x, size, y, z)
# cur_param_var is a diagonal parameter variance over dimension `dim`,
# of the current "slightly-whitened" parameter; it
@@ -578,8 +599,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
# (where the stats permit).
scale = (S.clamp(min=eps) / cur_param_var.clamp(min=eps)).sqrt()
logging.info(f"shape={p.shape}, dim={dim}, scale={scale[0].flatten()[::10]}")
# scale: (batch_size, 1, size, 1, 1) if dim==2
# scale shape: (batch_size, 1, size, 1, 1) if dim==2
cur_p *= scale
# OK, at this point we have a matrix cur_p that is (somewhat)
@@ -603,7 +624,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
# cur_scales for the other dims.
cur_scales = [None] * ndim
debug = True #(random.random() < 0.001)
for i in range(4): # for 4 iterations (this is quite arbitrary)
for dim in range(1, ndim):
@@ -637,12 +657,11 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
if cur_scales[dim] is not None:
size = p.shape[dim]
Q = state[f"Q_{dim}"]
(batch_size, num_blocks, block_size, block_size) = Q.shape
scale = cur_scales[dim].reshape(batch_size, size)
if Q.ndim == 3:
# Q is indexed [batch_index, diagonalized_coordinate, canonical_coordinate],
# want to multiply on the diagonalized co-ordinate.
scale = scale.unsqueeze(-1)
scale = cur_scales[dim].reshape(batch_size, num_blocks, block_size, 1)
# Q is indexed [batch_index, block_index, diagonalized_coordinate, canonical_coordinate],
# want to multiply on the diagonalized co-ordinate.
# else: Q is indexed [batch_index, canonical_coordinate].
state[f"Q_{dim}"] *= scale
@@ -658,15 +677,15 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
by left-multiplying the projections by an orthogonal matrix.
"""
batch_size = p.shape[0]
max_fullcov_size = group["max_fullcov_size"]
numel = p.numel() // batch_size
for dim in range(1, p.ndim): # dim 0 is batch dim
size = p.shape[dim]
try:
# A and grad_cov shape: (batch_size, num_blocks, block_size, block_size)
Q = state[f"Q_{dim}"]
grad_cov = state[f"grad_cov_{dim}"] # grad_cov shape: (batch_size, size_size)
grad_cov = state[f"grad_cov_{dim}"]
except KeyError:
assert size == 1 or size == numel or size > max_fullcov_size
assert size == 1 or size == numel
continue
# Suppose the actual parameter matrix p is M, of shape (-1, size), where
@@ -692,15 +711,15 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
# a ratio that will be larger if the eigenvalues are more dispersed.
return '%.3e' % ((v**2).mean() / (v.mean() ** 2)).item()
# N_grad_cov shape: (batch_size, size, size)
N_grad_cov = torch.matmul(Q, torch.matmul(grad_cov, Q.transpose(1, 2)))
N_grad_cov = N_grad_cov + N_grad_cov.transpose(1, 2) # ensure symmetric
# N_grad_cov shape: (batch_size, num_blocks, block_size, block_size)
N_grad_cov = torch.matmul(Q, torch.matmul(grad_cov, Q.transpose(2, 3)))
N_grad_cov = N_grad_cov + N_grad_cov.transpose(2, 3) # ensure symmetric
U, S, V = _svd(N_grad_cov)
if random.random() < 0.001:
logging.info(f"Diagonalizing, shape={tuple(p.shape)}, dim={dim}, dispersion "
f"changed from {dispersion(_diag(N_grad_cov))} to {dispersion(S)}")
# N_grad_cov is SPD, so
# N_grad_cov = U S U^T.
@@ -725,7 +744,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
#
# This is the only thing we have to do, as N is implicit
# and not materialized at this point.
Q[:] = torch.matmul(U.transpose(1, 2), Q)
Q[:] = torch.matmul(U.transpose(2, 3), Q)
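
# A single-block sketch, not part of the commit, of what this update achieves:
# once Q has been replaced by U^T Q, the projected gradient covariance
# Q grad_cov Q^T is (up to rounding) diagonal. The sizes and the direct use of
# torch.svd / torch.linalg.qr here are assumptions of the sketch.
import torch

block_size = 6
A = torch.randn(block_size, block_size, dtype=torch.float64)
grad_cov = A @ A.t()                                   # SPD stand-in for one block's grad_cov
Q = torch.linalg.qr(torch.randn(block_size, block_size, dtype=torch.float64))[0]

N = Q @ grad_cov @ Q.t()
N = N + N.t()                                          # symmetrize, as above
U, S, V = torch.svd(N)
Q = U.t() @ Q                                          # the update Q[:] = U^T Q

D = Q @ grad_cov @ Q.t()
off_diag = D - torch.diag(torch.diag(D))
assert off_diag.abs().max() < 1e-8 * D.diag().max()    # now essentially diagonal
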
def _update_grad_cov(self,
group: dict,
@@ -742,10 +761,24 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
size = grad.shape[dim]
name = f"grad_cov_{dim}"
if name not in state:
continue # e.g. size==1, size == numel, size > max_fullcov_size
grad_cov = state[name] # shaped: (batch_size, size, size)
this_g = grad.transpose(-1, dim).reshape(batch_size, -1, size)
grad_cov.mul_(beta2).add_(torch.matmul(this_g.transpose(1,2), this_g))
continue # e.g. size == 1 or size == numel
grad_cov = state[name] # shaped: (batch_size, num_blocks, block_size, block_size)
(batch_size, num_blocks, block_size, block_size) = grad_cov.shape
g = grad.transpose(-1, dim)
# if grad were of shape (batch_size, x, size, y, z),
# and g of shape (batch_size, x, y, z, size),
# after the next line g will be of shape
# (batch_size, x, y, z, num_blocks, block_size)
g = g.reshape(batch_size, *g.shape[1:-1],
num_blocks, block_size)
g = _move_dim(g, -2, 1) # (batch_size, num_blocks, x, y, z, block_size)
g = g.reshape(batch_size, num_blocks, -1, block_size)
# this_grad_cov: (batch_size, num_blocks, block_size, block_size)
this_grad_cov = torch.matmul(g.transpose(-2, -1), g)
grad_cov.mul_(beta2).add_(this_grad_cov, alpha=(1-beta2))
def _step(self,
group: dict,
@@ -771,7 +804,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
lr = group["lr"]
beta1, beta2 = group["betas"]
eps = group["eps"]
grad_cov_period = group["grad_cov_period"]
step = state["step"]
grad = self._project(grad, state, forward=True)
@@ -843,55 +875,53 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
p.add_(delta)
def _project(self,
x: Tensor,
M: Tensor,
state: dict,
forward: bool) -> Tensor:
"""
Multiply a tensor x by proj_{dim} on each of its dimensions for which we have
a projection Q.
If forward == True, converts x from canonical to preconditioned/diagonalized
co-ordinates; if forward == False, does the reverse.
Multiply a tensor M by a projection Q on each of its dimensions (for which we have
such a projection)
Args:
x: The tensor to project, e.g. a gradient or a parameter change.
M: The tensor to project, e.g. a gradient or a parameter change.
state: dict to look up the projections
forward: if True, go in the forward direction (from canonical to diagonalized
co-ordinates); if False, the reverse. They differ by a transpose,
not an inverse, and do not make a round trip.
"""
numel = x.numel()
for dim in range(1, x.ndim): # dim 0 is batch dim
numel = M.numel()
for dim in range(1, M.ndim): # dim 0 is batch dim (batch of parameter
# tensors of same shape)
try:
Q = state[f"Q_{dim}"] # shape: (batch_size, size, size)
Q = state[f"Q_{dim}"] # shape: (batch_size, num_blocks, block_size, block_size)
except KeyError:
continue # no projection for this dim
fullcov = (Q.ndim == 3) # (batch_index, diagonalized_index, canonical_index), else
# (batch_index, canonical_index).
(batch_size, num_blocks, block_size, block_size) = Q.shape
size = M.shape[dim] # == num_blocks * block_size
if fullcov:
if forward:
# Q is indexed [batch_index, diagonalized_index, canonical_index]; in the forward
# direction we want to change canonical to diagonalized index, so have
# to transpose.
Q = Q.transpose(1, 2)
# TODO: could possibly somehow force the output memory format to be
# unchanged.
while Q.ndim < x.ndim:
Q = Q.unsqueeze(1)
# now x might have shape (batch_size, 3, 4, 5, 6)
# and Q might have shape (batch_size, 1, 1, 6, 6)
# ... and the output will still have shape (batch_size, 3, 4, 5, 6)
x = x.transpose(-1, dim)
x = torch.matmul(x, Q)
x = x.transpose(-1, dim)
else:
shape = [1] * x.ndim
shape[0] = x.shape[0] # batch_size
shape[dim] = x.shape[dim] # size
Q = Q.reshape(shape) # e.g. (batch_size, 1, size, 1, 1)
x = x * Q
return x
if forward:
# Q is indexed
# [batch_idx, block_idx, diagonalized_idx, canonical_idx]; in
# the forward direction we want to change canonical to
# diagonalized index, so we transpose.
Q = Q.transpose(2, 3)
# assume M currently has shape (batch_size, x, size, y, z), and dim==2.
M = M.transpose(-1, dim) # (batch_size, x, y, z, size)
M = M.reshape(batch_size, *M.shape[1:-1],
num_blocks, block_size) # (batch_size, x, y, z, num_blocks, block_size)
M = _move_dim(M, -2, 1) # (batch_size, num_blocks, x, y, z, block_size)
while Q.ndim < M.ndim:
Q = Q.unsqueeze(2)
# now Q has shape (batch_size, num_blocks, 1, 1, block_size, block_size)
M = torch.matmul(M, Q) # (batch_size, num_blocks, x, y, z, block_size)
M = _move_dim(M, 1, -2) # (batch_size, x, y, z, num_blocks, block_size)
M = M.reshape(*M.shape[:-2], size) # (batch_size, x, y, z, size)
M = M.transpose(-1, dim) # (batch_size, x, size, y, z)
return M
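
# A sketch, not part of the commit, checking that the block-wise matmul in
# _project is equivalent to multiplying the flattened dimension by a single
# block-diagonal matrix; the batch and extra dims are stripped and the sizes
# are made up.
import torch

other, num_blocks, block_size = 7, 2, 3
size = num_blocks * block_size
x = torch.randn(other, size)
Q = torch.randn(num_blocks, block_size, block_size)    # one matrix per block

# Blocked form (forward direction, which uses Q^T):
M = x.reshape(other, num_blocks, block_size).transpose(0, 1)   # (num_blocks, other, block)
M = torch.matmul(M, Q.transpose(1, 2))
y_blocked = M.transpose(0, 1).reshape(other, size)

# Dense form: one block-diagonal matrix applied to the whole dimension.
y_dense = x @ torch.block_diag(*Q.transpose(1, 2).unbind(0))
assert torch.allclose(y_blocked, y_dense, atol=1e-5)
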
def _smooth_param_rms(self,
group: dict,
@@ -941,6 +971,31 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
return ans
def _move_dim(x: Tensor, orig_dim: int, new_dim: int) -> Tensor:
"""
Moves the position of one dimension in a Tensor, while keeping
the remaining dims in the same order. E.g. if x has
shape [10, 1, 2, 3], _move_dim(x, 3, 1) will have shape
[10, 3, 1, 2]. Returns the reshaped tensor which will be
a view of the original Tensor.
Args:
x: Tensor to reshape
orig_dim: Original dimension to move, with -x.ndim <= orig_dim < x.ndim
new_dim: New position of the original dimension, with
-x.ndim <= new_dim < x.ndim
Returns: a permuted view of x
"""
if orig_dim < 0:
orig_dim += x.ndim
if new_dim < 0:
new_dim += x.ndim
dims = list(range(x.ndim))
dims[orig_dim] = -1
dims.remove(-1)
dims.insert(new_dim, orig_dim)
return x.permute(dims)
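
# Illustrative usage, not part of the commit and assuming it runs in the same
# module so _move_dim is in scope: for a single dimension the helper behaves
# like torch.movedim.
import torch

x = torch.randn(10, 1, 2, 3)
assert _move_dim(x, 3, 1).shape == (10, 3, 1, 2)          # the docstring example
assert torch.equal(_move_dim(x, -2, 1), torch.movedim(x, -2, 1))
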
def _diag(x: Tensor):
"""
like torch diag(), but supports batch dim, i.e. input of shape (B, M, M) returns
@@ -951,6 +1006,11 @@ def _diag(x: Tensor):
assert M == M2
stride = x.stride()
return x.as_strided(size=(B, M), stride=(stride[0], stride[1] + stride[2])).contiguous()
elif x.ndim == 4:
(B, C, M, M2) = x.shape
assert M == M2
stride = x.stride()
return x.as_strided(size=(B, C, M), stride=(stride[0], stride[1], stride[2] + stride[3])).contiguous()
else:
return x.diag()
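
# A small check, not part of the commit and again assuming _diag is in scope:
# for 4-D input the strided view above agrees with torch.diagonal over the last
# two dims.
import torch

x = torch.randn(2, 3, 4, 4)
assert torch.equal(_diag(x), torch.diagonal(x, dim1=-2, dim2=-1).contiguous())
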
@@ -1676,7 +1736,7 @@ def _test_eve_cain():
if iter == 0: optim = Eve(m.parameters(), lr=0.003)
elif iter == 1: optim = Cain(m.parameters(), lr=0.03)
elif iter == 2: optim = PrAdam(m.parameters(), lr=0.03)
elif iter == 3: optim = PrAdam(m.parameters(), lr=0.03, max_fullcov_size=150)
elif iter == 3: optim = PrAdam(m.parameters(), lr=0.03, max_block_size=100)
scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)
start = timeit.default_timer()