mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
Replace max_fullcov_size with max_block_size
This commit is contained in:
parent
3468c3aa5a
commit
075a2e27d8
@ -128,6 +128,10 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
||||
size_update_period: The periodicity, in steps, with which we update the size (scale)
|
||||
of the parameter tensor. This is provided to save a little time.
|
||||
lr_update_period: The periodicity, in steps, with which we update the learning-rate matrices.
|
||||
** This is important for the speed/optimizaton tradeoff. **
|
||||
param_cov_period: The periodicity, in steps, with which we update the parameter covariance
|
||||
stats.
|
||||
max_block_size: The maximum block size in block-diagonal co-ordinate transformations.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
@ -147,7 +151,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
||||
lr_update_period=200,
|
||||
grad_cov_period=3,
|
||||
param_cov_period=100,
|
||||
max_fullcov_size=2048,
|
||||
max_block_size=1024,
|
||||
):
|
||||
|
||||
|
||||
@ -168,7 +172,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
||||
lr_update_period=lr_update_period,
|
||||
grad_cov_period=grad_cov_period,
|
||||
param_cov_period=param_cov_period,
|
||||
max_fullcov_size=max_fullcov_size,
|
||||
max_block_size=max_block_size,
|
||||
)
|
||||
|
||||
super(PrAdam, self).__init__(params, defaults)
|
||||
@ -226,7 +230,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
||||
"""
|
||||
eps = group["eps"]
|
||||
size_update_period = group["size_update_period"]
|
||||
max_fullcov_size = group["max_fullcov_size"]
|
||||
max_block_size = group["max_block_size"]
|
||||
|
||||
state["step"] = 0
|
||||
|
||||
@ -288,16 +292,18 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
||||
if size == 1 or (size == numel and ignore_rank1_dims):
|
||||
continue
|
||||
|
||||
if size <= max_fullcov_size:
|
||||
# Q_{dim} is be the learning-rate matrix for this
|
||||
# dimension, a matrix of indexed [diagonalized_coordinate, canonical_coordinate].
|
||||
# this is used twice in the update step, once transposed and once without transpose.
|
||||
state[f"Q_{dim}"] = torch.eye(size, **kwargs).unsqueeze(0).expand(
|
||||
batch_size, size, size).contiguous()
|
||||
# Most of the time, (num_blocks, block_size) will equal (1, size)
|
||||
num_blocks, block_size = self._get_block_size(size, max_block_size)
|
||||
|
||||
# Q_{dim} is be the learning-rate matrix for this
|
||||
# dimension, a matrix of indexed [batch_idx, block_idx, diagonalized_coordinate, canonical_coordinate].
|
||||
# this is used twice in the update step, once transposed and once without transpose.
|
||||
Q = torch.eye(block_size, block_size, **kwargs).unsqueeze(0).unsqueeze(0).expand(
|
||||
batch_size, num_blocks, block_size, block_size).contiguous()
|
||||
state[f"Q_{dim}"] = Q
|
||||
# param_cov_{dim} is the averaged-over-time gradient of parameters on this dimension, treating
|
||||
# all other dims as a batch axis.
|
||||
state[f"param_cov_{dim}"] = torch.zeros(batch_size, size, size, **kwargs)
|
||||
state[f"param_cov_{dim}"] = torch.zeros_like(Q)
|
||||
|
||||
# grad_cov_{dim} is the covariance of gradients on this axis (without
|
||||
# any co-ordinate changes), treating all other axes as as a batch axis.
|
||||
@ -309,13 +315,28 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
||||
# to the gradient w.r.t. f"Q_{dim}", which is one reason
|
||||
# why we allocate a variable to keep track of its moving average
|
||||
# instead of just using a temporary and smoothing the scalar factor.
|
||||
state[f"grad_cov_{dim}"] = torch.zeros(batch_size, size, size, **kwargs)
|
||||
else:
|
||||
# diagonal-only Q and param_cov, no grad_cov needed because it is only
|
||||
# needed for multiplying Q_{dim} by an orthogonal matrix, which is not
|
||||
# applicable if Q_{dim} is diagonal.
|
||||
state[f"Q_{dim}"] = torch.ones(batch_size, size, **kwargs)
|
||||
state[f"param_cov_{dim}"] = torch.zeros(batch_size, size, **kwargs)
|
||||
state[f"grad_cov_{dim}"] = torch.zeros_like(Q)
|
||||
|
||||
|
||||
def _get_block_size(self, size: int, max_block_size: int) -> Tuple[int,int]:
|
||||
"""
|
||||
Returns information about the block size for a block-diagonal structure
|
||||
of covariances and co-ordinate transformations.
|
||||
|
||||
size: The size of the parameter tensor on one of its dimensions, e.g.
|
||||
a channel dim or kernel size
|
||||
max_block_size: The maximum block size allowed (of the blocks in a
|
||||
block-diagonal structure)
|
||||
Returns:
|
||||
(num_blocks, block_size)
|
||||
where nun_blocks * block_size == size and block_size <= max_block_size
|
||||
"""
|
||||
for num_blocks in range(1, size):
|
||||
block_size = size // num_blocks
|
||||
if block_size * num_blocks == size and block_size <= max_block_size:
|
||||
return (num_blocks, block_size)
|
||||
assert False, (size, max_block_size) # code error or e.g. negative or
|
||||
# zero inputs
|
||||
|
||||
|
||||
def _step_one_batch(self,
|
||||
@ -462,30 +483,23 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
||||
except KeyError:
|
||||
continue # e.g. size == 1 or size == numel
|
||||
|
||||
if param_cov.ndim == 3:
|
||||
# param_cov shape: (batch_size, size, size)
|
||||
# param_cov shape: (batch_size, num_blocks, block_size, block_size)
|
||||
(batch_size, num_blocks, block_size, block_size) = param_cov.shape
|
||||
|
||||
# M is the parameter matrix, of shape (-1, size) where the -1 covers
|
||||
# all other dimensions of the tensor.
|
||||
M_full = p.transpose(dim, -1)
|
||||
M = M_full.reshape(batch_size, -1, size)
|
||||
# M: (batch_size, num_blocks, x, block_size)
|
||||
M = p.transpose(dim, -1).reshape(batch_size, -1,
|
||||
num_blocks, block_size).transpose(1, 2)
|
||||
|
||||
# will be batched matrix multiply. shape of this_param_cov:
|
||||
# (batch_size, num_blocks, block_size, block_size)
|
||||
this_param_cov = torch.matmul(M.transpose(2, 3), M)
|
||||
|
||||
# will be batched matrix multiply. shape of this_param_cov: (batch_size, size, size)
|
||||
this_param_cov = torch.matmul(M.transpose(1,2), M)
|
||||
# normalize scale of this param_cov, in case parameter scale
|
||||
# changes significantly during training, which would cause some
|
||||
# parts of the training timeline to be more highly weighted.
|
||||
# shape of this_param_cov
|
||||
this_param_cov /= _mean(_diag(this_param_cov), # _diag(this_param_cov) has shape (batch_size, size)
|
||||
exclude_dims=[0],
|
||||
keepdim=True).unsqueeze(-1) + eps # shape: (batch_size, 1, 1)
|
||||
else:
|
||||
# this_param_cov dim: (batch_size, size)
|
||||
this_param_cov = _mean(p**2,
|
||||
exclude_dims=[0,dim])
|
||||
this_param_cov /= _mean(this_param_cov,
|
||||
exclude_dims=[0],
|
||||
keepdim=True) + eps # shape: (batch_size, 1)
|
||||
# expression after /= has shape (batch_size, num_blocks, 1, 1)
|
||||
this_param_cov /= _diag(this_param_cov).mean(dim=[0,1], keepdim=True).unsqueeze(-1) + eps
|
||||
|
||||
param_cov.mul_(1-this_weight).add_(this_param_cov,
|
||||
alpha=this_weight)
|
||||
@ -543,22 +557,29 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
||||
assert size == 1 or size == numel, size
|
||||
continue # e.g. size == 1 or size == numel:
|
||||
|
||||
if param_cov.ndim == 2:
|
||||
# size > max_fullcov_size, we do not accumulate full covariance
|
||||
# matrix for this dim and will use a diagonal Q.
|
||||
# param_cov.shape is (batch_size, size).
|
||||
Q[:] = 1.0
|
||||
S = param_cov # (batch_size, size)
|
||||
else:
|
||||
# if batch==True, param_cov.shape == (batch_size, size, size), and
|
||||
# U, S and V have an extra leading dim.
|
||||
# param_cov has the same shape as Q
|
||||
(batch_size, num_blocks, block_size, block_size) = Q.shape
|
||||
# U,S,V have extra leading dimensions, i.e. of shape
|
||||
# {U,V}.shape == (batch_size, num_blocks, block_size, block_size),
|
||||
# S.shape == (batch_size, num_blocks, block_size).
|
||||
U, S, V = _svd(param_cov)
|
||||
Q[:] = U.transpose(1, 2)
|
||||
Q[:] = U.transpose(2, 3)
|
||||
|
||||
M = cur_p.transpose(dim, -1)
|
||||
# if p were of shape (batch_size, x, size, y, z),
|
||||
# after the next line M will be of shape
|
||||
# (batch_size, x, y, z, num_blocks, block_size)
|
||||
M = M.reshape(batch_size, *M.shape[1:-1],
|
||||
num_blocks, block_size)
|
||||
M = _move_dim(M, -2, 1) # (batch_size, num_blocks, x, y, z, block_size)
|
||||
|
||||
M = cur_p.transpose(dim, -1) # (batch_size, x, y, z, size)
|
||||
while U.ndim < M.ndim:
|
||||
U = U.unsqueeze(1) # (batch_size, 1, 1, size, size)
|
||||
M = torch.matmul(M, U) # (batch_size, x, y, z, size)
|
||||
U = U.unsqueeze(1)
|
||||
# Now U is of shape (batch_size, num_blocks, 1, 1, block_size, block_size)
|
||||
|
||||
M = torch.matmul(M, U) # (batch_size, num_blocks, x, y, z, block_size)
|
||||
M = _move_dim(M, 1, -2) # (batch_size, x, y, z, num_blocks, block_size)
|
||||
M = M.reshape(*M.shape[:-2], size) # # (batch_size, x, y, z, size)
|
||||
cur_p = M.transpose(dim, -1) # (batch_size, x, size, y, z)
|
||||
|
||||
# cur_param_var is a diagonal parameter variance over dimension `dim`,
|
||||
@ -578,8 +599,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
||||
# (where the stats permit).
|
||||
scale = (S.clamp(min=eps) / cur_param_var.clamp(min=eps)).sqrt()
|
||||
logging.info(f"shape={p.shape}, dim={dim}, scale={scale[0].flatten()[::10]}")
|
||||
# scale: (batch_size, 1, size, 1, 1) if dim==2
|
||||
|
||||
# scale shape: (batch_size, 1, size, 1, 1) if dim==2
|
||||
cur_p *= scale
|
||||
|
||||
# OK, at this point we have a matrix cur_p that is (somewhat)
|
||||
@ -603,7 +624,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
||||
# cur_scales for the other dims.
|
||||
cur_scales = [None] * ndim
|
||||
|
||||
|
||||
debug = True #(random.random() < 0.001)
|
||||
for i in range(4): # for 4 iterations (this is quite arbitrary)
|
||||
for dim in range(1, ndim):
|
||||
@ -637,12 +657,11 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
||||
if cur_scales[dim] is not None:
|
||||
size = p.shape[dim]
|
||||
Q = state[f"Q_{dim}"]
|
||||
(batch_size, num_blocks, block_size, block_size) = Q.shape
|
||||
|
||||
scale = cur_scales[dim].reshape(batch_size, size)
|
||||
if Q.ndim == 3:
|
||||
# Q is indexed [batch_index, diagonalized_coordinate, canonical_coordinate],
|
||||
scale = cur_scales[dim].reshape(batch_size, num_blocks, block_size, 1)
|
||||
# Q is indexed [batch_index, block_index, diagonalized_coordinate, canonical_coordinate],
|
||||
# want to multiply on the diagonalized co-ordinate.
|
||||
scale = scale.unsqueeze(-1)
|
||||
# else: Q is indexed [batch_index, canonical_coordinate].
|
||||
state[f"Q_{dim}"] *= scale
|
||||
|
||||
@ -658,15 +677,15 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
||||
by left-multiplying the projections by an orthogonal matrix.
|
||||
"""
|
||||
batch_size = p.shape[0]
|
||||
max_fullcov_size = group["max_fullcov_size"]
|
||||
numel = p.numel() // batch_size
|
||||
for dim in range(1, p.ndim): # dim 0 is batch dim
|
||||
size = p.shape[dim]
|
||||
try:
|
||||
# A and grad_cov shape: (batch_size, num_blocks, block_size, block_size)
|
||||
Q = state[f"Q_{dim}"]
|
||||
grad_cov = state[f"grad_cov_{dim}"] # grad_cov shape: (batch_size, size_size)
|
||||
grad_cov = state[f"grad_cov_{dim}"]
|
||||
except KeyError:
|
||||
assert size == 1 or size == numel or size > max_fullcov_size
|
||||
assert size == 1 or size == numel
|
||||
continue
|
||||
|
||||
# Suppose the actual parameter matrix p is M, of shape (-1, size), where
|
||||
@ -692,15 +711,15 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
||||
# a ratio that, if the eigenvalues are more dispersed, it will be larger.
|
||||
return '%.3e' % ((v**2).mean() / (v.mean() ** 2)).item()
|
||||
|
||||
# N_grad_cov shape: (batch_size, size, size)
|
||||
N_grad_cov = torch.matmul(Q, torch.matmul(grad_cov, Q.transpose(1, 2)))
|
||||
N_grad_cov = N_grad_cov + N_grad_cov.transpose(1, 2) # ensure symmetric
|
||||
|
||||
# N_grad_cov shape: (batch_size, num_blocks, block_size, block_size)
|
||||
N_grad_cov = torch.matmul(Q, torch.matmul(grad_cov, Q.transpose(2, 3)))
|
||||
N_grad_cov = N_grad_cov + N_grad_cov.transpose(2, 3) # ensure symmetric
|
||||
U, S, V = _svd(N_grad_cov)
|
||||
if random.random() < 0.001:
|
||||
logging.info(f"Diagonalizing, shape={tuple(p.shape)}, dim={dim}, dispersion "
|
||||
f"changed from {dispersion(_diag(N_grad_cov))} to {dispersion(S)}")
|
||||
|
||||
|
||||
# N_grad_cov is SPD, so
|
||||
# N_grad_cov = U S U^T.
|
||||
|
||||
@ -725,7 +744,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
||||
#
|
||||
# This is the only thing we have to do, as N is implicit
|
||||
# and not materialized at this point.
|
||||
Q[:] = torch.matmul(U.transpose(1, 2), Q)
|
||||
Q[:] = torch.matmul(U.transpose(2, 3), Q)
|
||||
|
||||
def _update_grad_cov(self,
|
||||
group: dict,
|
||||
@ -742,10 +761,24 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
||||
size = grad.shape[dim]
|
||||
name = f"grad_cov_{dim}"
|
||||
if name not in state:
|
||||
continue # e.g. size==1, size == numel, size > max_fullcov_size
|
||||
grad_cov = state[name] # shaped: (batch_size, size, size)
|
||||
this_g = grad.transpose(-1, dim).reshape(batch_size, -1, size)
|
||||
grad_cov.mul_(beta2).add_(torch.matmul(this_g.transpose(1,2), this_g))
|
||||
continue # e.g. size==1, size == numel, size > max_block_size
|
||||
grad_cov = state[name] # shaped: (batch_size, num_blocks, block_size, block_size)
|
||||
(batch_size, num_blocks, block_size, block_size) = grad_cov.shape
|
||||
|
||||
g = grad.transpose(-1, dim)
|
||||
# if grad were of shape (batch_size, x, size, y, z),
|
||||
# and g of shape (batch_size, x, y, z, size),
|
||||
# after the next line g will be of shape
|
||||
# (batch_size, x, y, z, num_blocks, block_size)
|
||||
g = g.reshape(batch_size, *g.shape[1:-1],
|
||||
num_blocks, block_size)
|
||||
g = _move_dim(g, -2, 1) # (batch_size, num_blocks, x, y, z, block_size)
|
||||
g = g.reshape(batch_size, num_blocks, -1, block_size)
|
||||
|
||||
# this_grad_cov: (batch_size, num_blocks, block_size, block_size)
|
||||
this_grad_cov = torch.matmul(g.transpose(-2, -1), g)
|
||||
|
||||
grad_cov.mul_(beta2).add_(this_grad_cov, alpha=(1-beta2))
|
||||
|
||||
def _step(self,
|
||||
group: dict,
|
||||
@ -771,7 +804,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
||||
lr = group["lr"]
|
||||
beta1, beta2 = group["betas"]
|
||||
eps = group["eps"]
|
||||
grad_cov_period = group["grad_cov_period"]
|
||||
step = state["step"]
|
||||
|
||||
grad = self._project(grad, state, forward=True)
|
||||
@ -843,55 +875,53 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
||||
p.add_(delta)
|
||||
|
||||
def _project(self,
|
||||
x: Tensor,
|
||||
M: Tensor,
|
||||
state: dict,
|
||||
forward: bool) -> Tensor:
|
||||
"""
|
||||
Multiply a tensor x by proj_{dim} on each of its dimensions for which we have
|
||||
a projection Q.
|
||||
If forward == True, converts x from canonical to preconditioned/diagonalized
|
||||
co-ordinates; if forward == False, does the reverse.
|
||||
Multiply a tensor M by a projection Q on each of its dimensions (for which we have
|
||||
such a projection)
|
||||
|
||||
Args:
|
||||
x: The tensor to project, e.g. a gradient or a parameter change.
|
||||
M: The tensor to project, e.g. a gradient or a parameter change.
|
||||
state: dict to look up the projections
|
||||
forward: if True, go in the forward direction (from canonical to diagnonalized
|
||||
co-ordinates); if False, the reverse. They differ by a transpose,
|
||||
not an inverse, and do not make a round trip.
|
||||
"""
|
||||
numel = x.numel()
|
||||
for dim in range(1, x.ndim): # dim 0 is batch dim
|
||||
numel = M.numel()
|
||||
|
||||
for dim in range(1, M.ndim): # dim 0 is batch dim (batch of parameter
|
||||
# tensors of same shape)
|
||||
try:
|
||||
Q = state[f"Q_{dim}"] # shape: (batch_size, size, size)
|
||||
Q = state[f"Q_{dim}"] # shape: (batch_size, num_blocks, block_size, block_size)
|
||||
except KeyError:
|
||||
continue # no projection for this dim
|
||||
|
||||
fullcov = (Q.ndim == 3) # (batch_index, diagonalized_index, canonical_index), else
|
||||
# (batch_index, canonical_index).
|
||||
(batch_size, num_blocks, block_size, block_size) = Q.shape
|
||||
size = M.shape[dim] # == num_blocks * block_size
|
||||
|
||||
if fullcov:
|
||||
if forward:
|
||||
# Q is indexed [batch_index, diagonalized_index, canonical_index]; in the forward
|
||||
# direction we want to change canonical to diagonalized index, so have
|
||||
# to transpose.
|
||||
Q = Q.transpose(1, 2)
|
||||
# TODO: could possibly somehow force the output memory format to be
|
||||
# unchanged.
|
||||
while Q.ndim < x.ndim:
|
||||
Q = Q.unsqueeze(1)
|
||||
# now x might have shape (batch_size, 3, 4, 5, 6)
|
||||
# and Q might have shape (batch_size, 1, 1, 6, 6)
|
||||
# ... and the output will still have shape (batch_size, 3, 4, 5, 6)
|
||||
x = x.transpose(-1, dim)
|
||||
x = torch.matmul(x, Q)
|
||||
x = x.transpose(-1, dim)
|
||||
else:
|
||||
shape = [1] * x.ndim
|
||||
shape[0] = x.shape[0] # batch_size
|
||||
shape[dim] = x.shape[dim] # size
|
||||
Q = Q.reshape(shape) # e.g. (batch_size, 1, size, 1, 1)
|
||||
x = x * Q
|
||||
return x
|
||||
# Q is indexed
|
||||
# [batch_idx, block_idx, diagonalized_idx, canonical_idx]; in
|
||||
# the forward direction we want to change canonical to
|
||||
# diagonalized index, so we transpose.
|
||||
Q = Q.transpose(2, 3)
|
||||
|
||||
# assume M currently has shape (batch_size, x, size, y, z), and dim==2.
|
||||
M = M.transpose(-1, dim) # (batch_size, x, y, z, size)
|
||||
M = M.reshape(batch_size, *M.shape[1:-1],
|
||||
num_blocks, block_size) # (batch_size, x, y, z, num_blocks, block_size)
|
||||
M = _move_dim(M, -2, 1) # (batch_size, num_blocks, x, y, z, block_size)
|
||||
|
||||
while Q.ndim < M.ndim:
|
||||
Q = Q.unsqueeze(2)
|
||||
# now Q has shape (batch_size, num_blocks, 1, 1, block_size, block_size)
|
||||
M = torch.matmul(M, Q) # (batch_size, num_blocks, x, y, z, block_size)
|
||||
M = _move_dim(M, 1, -2) # (batch_size, x, y, z, num_blocks, block_size)
|
||||
M = M.reshape(*M.shape[:-2], size) # (batch_size, x, y, z, size)
|
||||
M = M.transpose(-1, dim) # (batch_size, x, size, y, z)
|
||||
return M
|
||||
|
||||
def _smooth_param_rms(self,
|
||||
group: dict,
|
||||
@ -941,6 +971,31 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
||||
return ans
|
||||
|
||||
|
||||
def _move_dim(x: Tensor, orig_dim: int, new_dim: int) -> Tensor:
|
||||
"""
|
||||
Moves the position of one dimension in a Tensor, while keeping
|
||||
the remaining dims in the same order. E.g. if x has
|
||||
shape [10, 1, 2, 3], _move_dim(x, 3, 1) will have shape
|
||||
[10, 3, 1, 2]. Returns the reshaped tensor which will be
|
||||
a view of the original Tensor.
|
||||
Args:
|
||||
x: Tensor to reshape
|
||||
orig_dim: Original dimension to move, with -x.ndim <= orig_dim < x.ndim
|
||||
new_dim: New position of the original dimension, with
|
||||
-x.ndim <= new_dim < x.ndim
|
||||
Returns: a permuted view of x
|
||||
"""
|
||||
if orig_dim < 0:
|
||||
orig_dim += x.ndim
|
||||
if new_dim < 0:
|
||||
new_dim += x.ndim
|
||||
dims = list(range(x.ndim))
|
||||
dims[orig_dim] = -1
|
||||
dims.remove(-1)
|
||||
dims.insert(new_dim, orig_dim)
|
||||
return x.permute(dims)
|
||||
|
||||
|
||||
def _diag(x: Tensor):
|
||||
"""
|
||||
like torch diag(), but supports batch dim, i.e. input of shape (B, M, M) returns
|
||||
@ -951,6 +1006,11 @@ def _diag(x: Tensor):
|
||||
assert M == M2
|
||||
stride = x.stride()
|
||||
return x.as_strided(size=(B, M), stride=(stride[0], stride[1] + stride[2])).contiguous()
|
||||
elif x.ndim == 4:
|
||||
(B, C, M, M2) = x.shape
|
||||
assert M == M2
|
||||
stride = x.stride()
|
||||
return x.as_strided(size=(B, C, M), stride=(stride[0], stride[1], stride[2] + stride[3])).contiguous()
|
||||
else:
|
||||
return x.diag()
|
||||
|
||||
@ -1676,7 +1736,7 @@ def _test_eve_cain():
|
||||
if iter == 0: optim = Eve(m.parameters(), lr=0.003)
|
||||
elif iter == 1: optim = Cain(m.parameters(), lr=0.03)
|
||||
elif iter == 2: optim = PrAdam(m.parameters(), lr=0.03)
|
||||
elif iter == 3: optim = PrAdam(m.parameters(), lr=0.03, max_fullcov_size=150)
|
||||
elif iter == 3: optim = PrAdam(m.parameters(), lr=0.03, max_block_size=100)
|
||||
scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)
|
||||
|
||||
start = timeit.default_timer()
|
||||
|
Loading…
x
Reference in New Issue
Block a user