mirror of https://github.com/k2-fsa/icefall.git

Replace max_fullcov_size with max_block_size

commit 075a2e27d8
parent 3468c3aa5a
@@ -128,6 +128,10 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
     size_update_period: The periodicity, in steps, with which we update the size (scale)
           of the parameter tensor.  This is provided to save a little time.
     lr_update_period: The periodicity, in steps, with which we update the learning-rate matrices.
+          ** This is important for the speed/optimizaton tradeoff. **
+    param_cov_period: The periodicity, in steps, with which we update the parameter covariance
+          stats.
+    max_block_size: The maximum block size in block-diagonal co-ordinate transformations.
     """
     def __init__(
             self,
@@ -147,7 +151,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             lr_update_period=200,
             grad_cov_period=3,
             param_cov_period=100,
-            max_fullcov_size=2048,
+            max_block_size=1024,
     ):
 
 
@@ -168,7 +172,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             lr_update_period=lr_update_period,
             grad_cov_period=grad_cov_period,
             param_cov_period=param_cov_period,
-            max_fullcov_size=max_fullcov_size,
+            max_block_size=max_block_size,
         )
 
         super(PrAdam, self).__init__(params, defaults)
@@ -226,7 +230,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         """
         eps = group["eps"]
         size_update_period = group["size_update_period"]
-        max_fullcov_size = group["max_fullcov_size"]
+        max_block_size = group["max_block_size"]
 
         state["step"] = 0
 
@@ -288,34 +292,51 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             if size == 1 or (size == numel and ignore_rank1_dims):
                 continue
 
-            if size <= max_fullcov_size:
-                # Q_{dim} is be the learning-rate matrix for this
-                # dimension, a matrix of indexed [diagonalized_coordinate, canonical_coordinate].
-                # this is used twice in the update step, once transposed and once without transpose.
-                state[f"Q_{dim}"] = torch.eye(size, **kwargs).unsqueeze(0).expand(
-                    batch_size, size, size).contiguous()
-
-                # param_cov_{dim} is the averaged-over-time gradient of parameters on this dimension, treating
-                # all other dims as a batch axis.
-                state[f"param_cov_{dim}"] = torch.zeros(batch_size, size, size, **kwargs)
-
-                # grad_cov_{dim} is the covariance of gradients on this axis (without
-                # any co-ordinate changes), treating all other axes as as a batch axis.
-                # This is needed when we re-diagonalize, and to compute the
-                # grad_rms_{dim}.  We store it as a decaying average, decaying with beta2
-
-                # only for purposes of computing the scalar factor on the
-                # learning rate; and, as a result of this, also contributes something
-                # to the gradient w.r.t. f"Q_{dim}", which is one reason
-                # why we allocate a variable to keep track of its moving average
-                # instead of just using a temporary and smoothing the scalar factor.
-                state[f"grad_cov_{dim}"] = torch.zeros(batch_size, size, size, **kwargs)
-            else:
-                # diagonal-only Q and param_cov, no grad_cov needed because it is only
-                # needed for multiplying Q_{dim} by an orthogonal matrix, which is not
-                # applicable if Q_{dim} is diagonal.
-                state[f"Q_{dim}"] = torch.ones(batch_size, size, **kwargs)
-                state[f"param_cov_{dim}"] = torch.zeros(batch_size, size, **kwargs)
-
-
+            # Most of the time, (num_blocks, block_size) will equal (1, size)
+            num_blocks, block_size = self._get_block_size(size, max_block_size)
+
+            # Q_{dim} is be the learning-rate matrix for this
+            # dimension, a matrix of indexed [batch_idx, block_idx, diagonalized_coordinate, canonical_coordinate].
+            # this is used twice in the update step, once transposed and once without transpose.
+            Q = torch.eye(block_size, block_size, **kwargs).unsqueeze(0).unsqueeze(0).expand(
+                batch_size, num_blocks, block_size, block_size).contiguous()
+            state[f"Q_{dim}"] = Q
+
+            # param_cov_{dim} is the averaged-over-time gradient of parameters on this dimension, treating
+            # all other dims as a batch axis.
+            state[f"param_cov_{dim}"] = torch.zeros_like(Q)
+
+            # grad_cov_{dim} is the covariance of gradients on this axis (without
+            # any co-ordinate changes), treating all other axes as as a batch axis.
+            # This is needed when we re-diagonalize, and to compute the
+            # grad_rms_{dim}.  We store it as a decaying average, decaying with beta2
+
+            # only for purposes of computing the scalar factor on the
+            # learning rate; and, as a result of this, also contributes something
+            # to the gradient w.r.t. f"Q_{dim}", which is one reason
+            # why we allocate a variable to keep track of its moving average
+            # instead of just using a temporary and smoothing the scalar factor.
+            state[f"grad_cov_{dim}"] = torch.zeros_like(Q)
+
+    def _get_block_size(self, size: int, max_block_size: int) -> Tuple[int,int]:
+        """
+        Returns information about the block size for a block-diagonal structure
+        of covariances and co-ordinate transformations.
+
+          size: The size of the parameter tensor on one of its dimensions, e.g.
+              a channel dim or kernel size
+          max_block_size: The maximum block size allowed (of the blocks in a
+              block-diagonal structure)
+        Returns:
+            (num_blocks, block_size)
+        where nun_blocks * block_size == size and block_size <= max_block_size
+        """
+        for num_blocks in range(1, size):
+            block_size = size // num_blocks
+            if block_size * num_blocks == size and block_size <= max_block_size:
+                return (num_blocks, block_size)
+        assert False, (size, max_block_size)  # code error or e.g. negative or
+                                              # zero inputs
+
+
     def _step_one_batch(self,
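For readers skimming the diff, here is a standalone sketch (not part of the commit) of the divisor search that the new `_get_block_size` helper performs. The function name and the example sizes below are made up for illustration; the selection rule mirrors the helper added above.

```python
from typing import Tuple


def get_block_size(size: int, max_block_size: int) -> Tuple[int, int]:
    """Smallest number of equal-sized blocks whose block_size fits under the cap."""
    for num_blocks in range(1, size):
        block_size = size // num_blocks
        if block_size * num_blocks == size and block_size <= max_block_size:
            return (num_blocks, block_size)
    raise ValueError((size, max_block_size))  # e.g. zero or negative inputs


# With the new default max_block_size=1024:
assert get_block_size(512, 1024) == (1, 512)    # small dims keep a single full block
assert get_block_size(2048, 1024) == (2, 1024)  # a 2048-wide dim splits into 2 blocks
assert get_block_size(1536, 1024) == (2, 768)   # block_size is always an exact divisor
```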
@@ -462,30 +483,23 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             except KeyError:
                 continue  # e.g. size == 1 or size == numel
 
-            if param_cov.ndim == 3:
-                # param_cov shape: (batch_size, size, size)
-
-                # M is the parameter matrix, of shape (-1, size) where the -1 covers
-                # all other dimensions of the tensor.
-                M_full = p.transpose(dim, -1)
-                M = M_full.reshape(batch_size, -1, size)
-
-                # will be batched matrix multiply.  shape of this_param_cov: (batch_size, size, size)
-                this_param_cov = torch.matmul(M.transpose(1,2), M)
-                # normalize scale of this param_cov, in case parameter scale
-                # changes significantly during training, which would cause some
-                # parts of the training timeline to be more highly weighted.
-                # shape of this_param_cov
-                this_param_cov /= _mean(_diag(this_param_cov), # _diag(this_param_cov) has shape (batch_size, size)
-                                        exclude_dims=[0],
-                                        keepdim=True).unsqueeze(-1) + eps  # shape: (batch_size, 1, 1)
-            else:
-                # this_param_cov dim: (batch_size, size)
-                this_param_cov = _mean(p**2,
-                                       exclude_dims=[0,dim])
-                this_param_cov /= _mean(this_param_cov,
-                                        exclude_dims=[0],
-                                        keepdim=True) + eps  # shape: (batch_size, 1)
+            # param_cov shape: (batch_size, num_blocks, block_size, block_size)
+            (batch_size, num_blocks, block_size, block_size) = param_cov.shape
+
+            # M: (batch_size, num_blocks, x, block_size)
+            M = p.transpose(dim, -1).reshape(batch_size, -1,
+                                             num_blocks, block_size).transpose(1, 2)
+
+            # will be batched matrix multiply.  shape of this_param_cov:
+            # (batch_size, num_blocks, block_size, block_size)
+            this_param_cov = torch.matmul(M.transpose(2, 3), M)
+
+            # normalize scale of this param_cov, in case parameter scale
+            # changes significantly during training, which would cause some
+            # parts of the training timeline to be more highly weighted.
+            # shape of this_param_cov
+            # expression after /= has shape (batch_size, num_blocks, 1, 1)
+            this_param_cov /= _diag(this_param_cov).mean(dim=[0,1], keepdim=True).unsqueeze(-1) + eps
 
             param_cov.mul_(1-this_weight).add_(this_param_cov,
                                                alpha=this_weight)
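To see what the new reshape does, here is a small self-contained shape check (sizes and variable names invented for illustration). It mirrors the accumulation path in the hunk above, without the smoothing and scale normalization.

```python
import torch

batch_size, x, size = 2, 5, 8      # a batch of parameter tensors of shape (x, size)
num_blocks, block_size = 2, 4      # what _get_block_size(8, max_block_size=4) would pick
dim = 2                            # the dimension whose covariance is being tracked

p = torch.randn(batch_size, x, size)
M = p.transpose(dim, -1).reshape(batch_size, -1, num_blocks, block_size).transpose(1, 2)
# M: (batch_size, num_blocks, x, block_size) == (2, 2, 5, 4)
this_param_cov = torch.matmul(M.transpose(2, 3), M)
# one small covariance per (batch element, block) instead of one (size x size) matrix:
print(this_param_cov.shape)        # torch.Size([2, 2, 4, 4])
```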
@@ -543,23 +557,30 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                 assert size == 1 or size == numel, size
                 continue  # e.g. size == 1 or size == numel:
 
-            if param_cov.ndim == 2:
-                # size > max_fullcov_size, we do not accumulate full covariance
-                # matrix for this dim and will use a diagonal Q.
-                # param_cov.shape is (batch_size, size).
-                Q[:] = 1.0
-                S = param_cov  # (batch_size, size)
-            else:
-                # if batch==True, param_cov.shape == (batch_size, size, size), and
-                # U, S and V have an extra leading dim.
-                U, S, V = _svd(param_cov)
-                Q[:] = U.transpose(1, 2)
-
-                M = cur_p.transpose(dim, -1)  # (batch_size, x, y, z, size)
-                while U.ndim < M.ndim:
-                    U = U.unsqueeze(1)  # (batch_size, 1, 1, size, size)
-                M = torch.matmul(M, U)  # (batch_size, x, y, z, size)
-                cur_p = M.transpose(dim, -1)  # (batch_size, x, size, y, z)
+            # param_cov has the same shape as Q
+            (batch_size, num_blocks, block_size, block_size) = Q.shape
+            # U,S,V have extra leading dimensions, i.e. of shape
+            # {U,V}.shape == (batch_size, num_blocks, block_size, block_size),
+            # S.shape == (batch_size, num_blocks, block_size).
+            U, S, V = _svd(param_cov)
+            Q[:] = U.transpose(2, 3)
+
+            M = cur_p.transpose(dim, -1)
+            # if p were of shape (batch_size, x, size, y, z),
+            # after the next line M will be of shape
+            # (batch_size, x, y, z, num_blocks, block_size)
+            M = M.reshape(batch_size, *M.shape[1:-1],
+                          num_blocks, block_size)
+            M = _move_dim(M, -2, 1)  # (batch_size, num_blocks, x, y, z, block_size)
+
+            while U.ndim < M.ndim:
+                U = U.unsqueeze(1)
+            # Now U is of shape (batch_size, num_blocks, 1, 1, block_size, block_size)
+
+            M = torch.matmul(M, U)  # (batch_size, num_blocks, x, y, z, block_size)
+            M = _move_dim(M, 1, -2)  # (batch_size, x, y, z, num_blocks, block_size)
+            M = M.reshape(*M.shape[:-2], size)  # # (batch_size, x, y, z, size)
+            cur_p = M.transpose(dim, -1)  # (batch_size, x, size, y, z)
 
             # cur_param_var is a diagonal parameter variance over dimension `dim`,
             # of the current "slightly-whitened" parameter; it
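The `_svd` call above now runs on a 4-D tensor. Assuming `_svd` is a thin wrapper around PyTorch's batched SVD (that wrapper is not shown in this hunk, so this is an assumption), the per-block behaviour looks like this sketch with invented sizes:

```python
import torch

batch_size, num_blocks, block_size = 2, 3, 4
param_cov = torch.randn(batch_size, num_blocks, block_size, block_size)
param_cov = param_cov @ param_cov.transpose(-2, -1)   # make each block SPD, like a covariance

# torch.linalg.svd batches over all leading dims, so every (batch, block) pair
# gets its own decomposition:
U, S, Vh = torch.linalg.svd(param_cov)
print(U.shape, S.shape)   # torch.Size([2, 3, 4, 4]) torch.Size([2, 3, 4])

Q = U.transpose(2, 3)     # one orthogonal co-ordinate change per (batch, block)
```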
@@ -578,8 +599,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             # (where the stats permit).
             scale = (S.clamp(min=eps) / cur_param_var.clamp(min=eps)).sqrt()
             logging.info(f"shape={p.shape}, dim={dim}, scale={scale[0].flatten()[::10]}")
-            # scale: (batch_size, 1, size, 1, 1) if dim==2
 
+            # scale shape: (batch_size, 1, size, 1, 1) if dim==2
             cur_p *= scale
 
             # OK, at this point we have a matrix cur_p that is (somewhat)
@@ -603,7 +624,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         # cur_scales for the other dims.
         cur_scales = [None] * ndim
 
-
         debug = True #(random.random() < 0.001)
         for i in range(4): # for 4 iterations (this is quite arbitrary)
             for dim in range(1, ndim):
@@ -637,12 +657,11 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                 if cur_scales[dim] is not None:
                     size = p.shape[dim]
                     Q = state[f"Q_{dim}"]
+                    (batch_size, num_blocks, block_size, block_size) = Q.shape
 
-                    scale = cur_scales[dim].reshape(batch_size, size)
-                    if Q.ndim == 3:
-                        # Q is indexed [batch_index, diagonalized_coordinate, canonical_coordinate],
-                        # want to multiply on the diagonalized co-ordinate.
-                        scale = scale.unsqueeze(-1)
+                    scale = cur_scales[dim].reshape(batch_size, num_blocks, block_size, 1)
+                    # Q is indexed [batch_index, block_index, diagonalized_coordinate, canonical_coordinate],
+                    # want to multiply on the diagonalized co-ordinate.
                     # else: Q is indexed [batch_index, canonical_coordinate].
                     state[f"Q_{dim}"] *= scale
 
@@ -658,15 +677,15 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
        by left-multiplying the projections by an orthogonal matrix.
        """
        batch_size = p.shape[0]
-       max_fullcov_size = group["max_fullcov_size"]
        numel = p.numel() // batch_size
        for dim in range(1, p.ndim): # dim 0 is batch dim
            size = p.shape[dim]
            try:
+               # A and grad_cov shape: (batch_size, num_blocks, block_size, block_size)
                Q = state[f"Q_{dim}"]
-               grad_cov = state[f"grad_cov_{dim}"]  # grad_cov shape: (batch_size, size_size)
+               grad_cov = state[f"grad_cov_{dim}"]
            except KeyError:
-               assert size == 1 or size == numel or size > max_fullcov_size
+               assert size == 1 or size == numel
                continue
 
            # Suppose the actual parameter matrix p is M, of shape (-1, size), where
@@ -692,15 +711,15 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                # a ratio that, if the eigenvalues are more dispersed, it will be larger.
                return '%.3e' % ((v**2).mean() / (v.mean() ** 2)).item()
 
-           # N_grad_cov shape: (batch_size, size, size)
-           N_grad_cov = torch.matmul(Q, torch.matmul(grad_cov, Q.transpose(1, 2)))
-           N_grad_cov = N_grad_cov + N_grad_cov.transpose(1, 2)  # ensure symmetric
+           # N_grad_cov shape: (batch_size, num_blocks, block_size, block_size)
+           N_grad_cov = torch.matmul(Q, torch.matmul(grad_cov, Q.transpose(2, 3)))
+           N_grad_cov = N_grad_cov + N_grad_cov.transpose(2, 3)  # ensure symmetric
            U, S, V = _svd(N_grad_cov)
            if random.random() < 0.001:
                logging.info(f"Diagonalizing, shape={tuple(p.shape)}, dim={dim}, dispersion "
                             f"changed from {dispersion(_diag(N_grad_cov))} to {dispersion(S)}")
 
 
           # N_grad_cov is SPD, so
           # N_grad_cov = U S U^T.
 
@@ -725,7 +744,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
           #
           # This is the only thing we have to do, as N is implicit
           # and not materialized at this point.
-          Q[:] = torch.matmul(U.transpose(1, 2), Q)
+          Q[:] = torch.matmul(U.transpose(2, 3), Q)
 
    def _update_grad_cov(self,
                         group: dict,
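A sketch of what the re-diagonalization in the two hunks above achieves, again with invented sizes and with torch.linalg.svd standing in for the file's `_svd` helper (an assumption): rotating Q by U^T makes the accumulated gradient covariance diagonal in the new basis.

```python
import torch

batch_size, num_blocks, block_size = 2, 3, 4
Q = torch.randn(batch_size, num_blocks, block_size, block_size)
grad_cov = torch.randn(batch_size, num_blocks, block_size, block_size)
grad_cov = grad_cov @ grad_cov.transpose(-2, -1)        # SPD, like an accumulated covariance

N_grad_cov = Q @ grad_cov @ Q.transpose(2, 3)           # covariance seen in the current basis
N_grad_cov = N_grad_cov + N_grad_cov.transpose(2, 3)    # ensure exact symmetry
U, S, Vh = torch.linalg.svd(N_grad_cov)
Q_new = U.transpose(2, 3) @ Q                           # the update Q[:] = U^T Q from the diff

# in the rotated basis the gradient covariance is diagonal up to numerical error:
D = U.transpose(2, 3) @ N_grad_cov @ U
off_diag = D - torch.diag_embed(torch.diagonal(D, dim1=-2, dim2=-1))
print(off_diag.abs().max() / D.abs().max())             # close to zero
```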
@@ -742,10 +761,24 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
            size = grad.shape[dim]
            name = f"grad_cov_{dim}"
            if name not in state:
-               continue  # e.g. size==1, size == numel, size > max_fullcov_size
-           grad_cov = state[name]  # shaped: (batch_size, size, size)
-           this_g = grad.transpose(-1, dim).reshape(batch_size, -1, size)
-           grad_cov.mul_(beta2).add_(torch.matmul(this_g.transpose(1,2), this_g))
+               continue  # e.g. size==1, size == numel, size > max_block_size
+           grad_cov = state[name]  # shaped: (batch_size, num_blocks, block_size, block_size)
+           (batch_size, num_blocks, block_size, block_size) = grad_cov.shape
+
+           g = grad.transpose(-1, dim)
+           # if grad were of shape (batch_size, x, size, y, z),
+           # and g of shape (batch_size, x, y, z, size),
+           # after the next line g will be of shape
+           # (batch_size, x, y, z, num_blocks, block_size)
+           g = g.reshape(batch_size, *g.shape[1:-1],
+                         num_blocks, block_size)
+           g = _move_dim(g, -2, 1)  # (batch_size, num_blocks, x, y, z, block_size)
+           g = g.reshape(batch_size, num_blocks, -1, block_size)
+
+           # this_grad_cov: (batch_size, num_blocks, block_size, block_size)
+           this_grad_cov = torch.matmul(g.transpose(-2, -1), g)
+
+           grad_cov.mul_(beta2).add_(this_grad_cov, alpha=(1-beta2))
 
    def _step(self,
              group: dict,
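The same reshape pattern, shown end-to-end on a toy gradient (shapes and names invented for illustration). torch.movedim is used here in place of the file's `_move_dim` helper; it performs the same permutation.

```python
import torch

batch_size, x, size, y = 2, 3, 6, 5     # toy gradient; precondition dim=2, which has size 6
num_blocks, block_size = 3, 2           # e.g. what _get_block_size(6, max_block_size=2) returns
dim = 2

g = torch.randn(batch_size, x, size, y)
g = g.transpose(-1, dim)                                            # (batch, x, y, size)
g = g.reshape(batch_size, *g.shape[1:-1], num_blocks, block_size)   # (batch, x, y, num_blocks, block_size)
g = torch.movedim(g, -2, 1)                                         # (batch, num_blocks, x, y, block_size)
g = g.reshape(batch_size, num_blocks, -1, block_size)               # flatten everything else per block
this_grad_cov = g.transpose(-2, -1) @ g                             # (batch, num_blocks, block_size, block_size)
print(this_grad_cov.shape)                                          # torch.Size([2, 3, 2, 2])
```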
@@ -771,7 +804,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
        lr = group["lr"]
        beta1, beta2 = group["betas"]
        eps = group["eps"]
-       grad_cov_period = group["grad_cov_period"]
        step = state["step"]
 
        grad = self._project(grad, state, forward=True)
@@ -843,55 +875,53 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
        p.add_(delta)
 
    def _project(self,
-                x: Tensor,
+                M: Tensor,
                 state: dict,
                 forward: bool) -> Tensor:
        """
-       Multiply a tensor x by proj_{dim} on each of its dimensions for which we have
-       a projection Q.
-       If forward == True, converts x from canonical to preconditioned/diagonalized
-       co-ordinates; if forward == False, does the reverse.
+       Multiply a tensor M by a projection Q on each of its dimensions (for which we have
+       such a projection)
 
       Args:
-         x: The tensor to project, e.g. a gradient or a parameter change.
+         M: The tensor to project, e.g. a gradient or a parameter change.
         state: dict to look up the projections
         forward: if True, go in the forward direction (from canonical to diagnonalized
             co-ordinates); if False, the reverse.  They differ by a transpose,
             not an inverse, and do not make a round trip.
       """
-       numel = x.numel()
-       for dim in range(1, x.ndim): # dim 0 is batch dim
+       numel = M.numel()
+       for dim in range(1, M.ndim): # dim 0 is batch dim (batch of parameter
+                                    # tensors of same shape)
           try:
-              Q = state[f"Q_{dim}"]  # shape: (batch_size, size, size)
+              Q = state[f"Q_{dim}"]  # shape: (batch_size, num_blocks, block_size, block_size)
          except KeyError:
              continue  # no projection for this dim
 
-          fullcov = (Q.ndim == 3)  # (batch_index, diagonalized_index, canonical_index), else
-                                   # (batch_index, canonical_index).
+          (batch_size, num_blocks, block_size, block_size) = Q.shape
+          size = M.shape[dim]  # == num_blocks * block_size
 
-          if fullcov:
-              if forward:
-                  # Q is indexed [batch_index, diagonalized_index, canonical_index]; in the forward
-                  # direction we want to change canonical to diagonalized index, so have
-                  # to transpose.
-                  Q = Q.transpose(1, 2)
-              # TODO: could possibly somehow force the output memory format to be
-              # unchanged.
-              while Q.ndim < x.ndim:
-                  Q = Q.unsqueeze(1)
-              # now x might have shape (batch_size, 3, 4, 5, 6)
-              # and Q might have shape (batch_size, 1, 1, 6, 6)
-              # ... and the output will still have shape (batch_size, 3, 4, 5, 6)
-              x = x.transpose(-1, dim)
-              x = torch.matmul(x, Q)
-              x = x.transpose(-1, dim)
-          else:
-              shape = [1] * x.ndim
-              shape[0] = x.shape[0]  # batch_size
-              shape[dim] = x.shape[dim]  # size
-              Q = Q.reshape(shape)  # e.g. (batch_size, 1, size, 1, 1)
-              x = x * Q
-       return x
+          if forward:
+              # Q is indexed
+              # [batch_idx, block_idx, diagonalized_idx, canonical_idx]; in
+              # the forward direction we want to change canonical to
+              # diagonalized index, so we transpose.
+              Q = Q.transpose(2, 3)
+
+          # assume M currently has shape (batch_size, x, size, y, z), and dim==2.
+          M = M.transpose(-1, dim)  # (batch_size, x, y, z, size)
+          M = M.reshape(batch_size, *M.shape[1:-1],
+                        num_blocks, block_size)  # (batch_size, x, y, z, num_blocks, block_size)
+          M = _move_dim(M, -2, 1)  # (batch_size, num_blocks, x, y, z, block_size)
+
+          while Q.ndim < M.ndim:
+              Q = Q.unsqueeze(2)
+          # now Q has shape (batch_size, num_blocks, 1, 1, block_size, block_size)
+
+          M = torch.matmul(M, Q)  # (batch_size, num_blocks, x, y, z, block_size)
+          M = _move_dim(M, 1, -2)  # (batch_size, x, y, z, num_blocks, block_size)
+          M = M.reshape(*M.shape[:-2], size)  # # (batch_size, x, y, z, size)
+          M = M.transpose(-1, dim)  # (batch_size, x, size, y, z)
+       return M
 
    def _smooth_param_rms(self,
                          group: dict,
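To make the reshuffling in the rewritten `_project` easier to follow, here is a self-contained sketch of the forward direction only, using torch.movedim in place of the `_move_dim` helper and made-up shapes. The function name and sizes are invented; the real method also handles the backward direction and reads Q from optimizer state.

```python
import torch


def project_forward(M: torch.Tensor, Q: torch.Tensor, dim: int) -> torch.Tensor:
    """Forward-direction projection of dimension `dim` of M by a block-diagonal Q."""
    (batch_size, num_blocks, block_size, _) = Q.shape
    size = M.shape[dim]                       # == num_blocks * block_size
    Qt = Q.transpose(2, 3)                    # canonical index -> diagonalized index
    out = M.transpose(-1, dim)                # move `dim` to the end
    out = out.reshape(*out.shape[:-1], num_blocks, block_size)
    out = torch.movedim(out, -2, 1)           # (batch, num_blocks, ..., block_size)
    while Qt.ndim < out.ndim:
        Qt = Qt.unsqueeze(2)                  # broadcast over the remaining dims
    out = torch.matmul(out, Qt)
    out = torch.movedim(out, 1, -2)
    out = out.reshape(*out.shape[:-2], size)
    return out.transpose(-1, dim)


M = torch.randn(2, 3, 8, 5)                   # project dim=2 (size 8) as 2 blocks of 4
Q = torch.randn(2, 2, 4, 4)
print(project_forward(M, Q, dim=2).shape)     # torch.Size([2, 3, 8, 5]), same as M
```

When num_blocks == 1 this reduces to the old full-covariance path, which is why a single code path can replace the previous fullcov/diagonal branches.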
@@ -941,6 +971,31 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
    return ans
 
 
+def _move_dim(x: Tensor, orig_dim: int, new_dim: int) -> Tensor:
+    """
+    Moves the position of one dimension in a Tensor, while keeping
+    the remaining dims in the same order.  E.g. if x has
+    shape [10, 1, 2, 3], _move_dim(x, 3, 1) will have shape
+    [10, 3, 1, 2].  Returns the reshaped tensor which will be
+    a view of the original Tensor.
+    Args:
+        x: Tensor to reshape
+        orig_dim: Original dimension to move, with -x.ndim <= orig_dim < x.ndim
+        new_dim: New position of the original dimension, with
+              -x.ndim <= new_dim < x.ndim
+    Returns: a permuted view of x
+    """
+    if orig_dim < 0:
+        orig_dim += x.ndim
+    if new_dim < 0:
+        new_dim += x.ndim
+    dims = list(range(x.ndim))
+    dims[orig_dim] = -1
+    dims.remove(-1)
+    dims.insert(new_dim, orig_dim)
+    return x.permute(dims)
+
 def _diag(x: Tensor):
    """
    like torch diag(), but supports batch dim, i.e. input of shape (B, M, M) returns
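A quick check (illustrative, not from the commit) that the new `_move_dim` behaves like torch.movedim, using the example from its own docstring:

```python
import torch


def _move_dim(x: torch.Tensor, orig_dim: int, new_dim: int) -> torch.Tensor:
    # same body as the helper added above
    if orig_dim < 0:
        orig_dim += x.ndim
    if new_dim < 0:
        new_dim += x.ndim
    dims = list(range(x.ndim))
    dims[orig_dim] = -1
    dims.remove(-1)
    dims.insert(new_dim, orig_dim)
    return x.permute(dims)


x = torch.randn(10, 1, 2, 3)
assert _move_dim(x, 3, 1).shape == (10, 3, 1, 2)                   # the docstring example
assert torch.equal(_move_dim(x, -2, 1), torch.movedim(x, -2, 1))   # matches torch.movedim
```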
@@ -951,6 +1006,11 @@ def _diag(x: Tensor):
        assert M == M2
        stride = x.stride()
        return x.as_strided(size=(B, M), stride=(stride[0], stride[1] + stride[2])).contiguous()
+    elif x.ndim == 4:
+        (B, C, M, M2) = x.shape
+        assert M == M2
+        stride = x.stride()
+        return x.as_strided(size=(B, C, M), stride=(stride[0], stride[1], stride[2] + stride[3])).contiguous()
    else:
        return x.diag()
 
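The new 4-D branch of `_diag`, checked against torch.diagonal on a made-up block-shaped input (a sketch, not part of the commit):

```python
import torch


def diag4(x: torch.Tensor) -> torch.Tensor:
    # mirrors the elif-branch added above for (B, C, M, M) input
    (B, C, M, M2) = x.shape
    assert M == M2
    stride = x.stride()
    return x.as_strided(size=(B, C, M),
                        stride=(stride[0], stride[1], stride[2] + stride[3])).contiguous()


x = torch.randn(2, 3, 4, 4)        # e.g. (batch_size, num_blocks, block_size, block_size)
assert torch.equal(diag4(x), torch.diagonal(x, dim1=-2, dim2=-1))
```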
@@ -1676,7 +1736,7 @@ def _test_eve_cain():
        if iter == 0: optim = Eve(m.parameters(), lr=0.003)
        elif iter == 1: optim = Cain(m.parameters(), lr=0.03)
        elif iter == 2: optim = PrAdam(m.parameters(), lr=0.03)
-        elif iter == 3: optim = PrAdam(m.parameters(), lr=0.03, max_fullcov_size=150)
+        elif iter == 3: optim = PrAdam(m.parameters(), lr=0.03, max_block_size=100)
        scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)
 
        start = timeit.default_timer()