Update parameter dependent part of cov more slowly, plus bug fix.

commit 8db3b48edb
parent 198cf2635c
@@ -116,9 +116,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                    param covariance equals the dimension of the covariance matrix.
                    param_rms_smooth{0,1} determine the smoothing proportions for other
                    conditions.
-param_cov_freshness: Constant that determines how "recent" the parameter covariance
-                   matrix is.  1.0 corresponds to flat average over time; 2.0
-                   would mean weighting with a quadratic function, etc.
 eps:               An epsilon to prevent division by zero
 param_min_rms:     Minimum root-mean-square value of parameter tensor, for purposes of
                    learning the scale on the parameters (we'll keep it >= this size)
@@ -133,8 +130,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                    later in training.  One step of updating the learning rate matrices
                    can take as long as over 50 minibatches, because SVD on GPU is slow.
                    ** This is important for the speed/optimization tradeoff. **
-param_cov_period:  The periodicity, in steps, with which we update the parameter covariance
-                   stats.
 max_block_size:    The maximum block size in block-diagonal co-ordinate transformations.
     """
     def __init__(
@@ -286,6 +281,10 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             # "zero_step" being a member of state is the sign that this parameter has
             # at least one dim that has a projection.
             state["zero_step"] = 0
+            # last_param_scale_update records the last time we updated the part of the learning rate
+            # matrices that relates to the parameter covariance; we avoid doing this too often
+            # as it doesn't change very rapidly and the SVD is quite slow.
+            state["last_param_scale_update"] = -1
 
             for dim in range(1, p.ndim):
                 size = p.shape[dim]
@@ -301,9 +300,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                     Q = torch.eye(block_size, block_size, **kwargs).unsqueeze(0).unsqueeze(0).expand(
                         batch_size, num_blocks, block_size, block_size).contiguous()
                     state[f"Q_{dim}"] = Q
-                    # param_cov_{dim} is the averaged-over-time gradient of parameters on this dimension, treating
-                    # all other dims as a batch axis.
-                    state[f"param_cov_{dim}"] = torch.zeros_like(Q)
+                    # param_cov_{dim} is the averaged-over-time covariance of parameters on this dimension, treating
+                    # all other dims as a batch axis.  Also initialize as identity.
+                    state[f"param_cov_{dim}"] = Q.clone()
 
                     # grad_cov_{dim} is the covariance of gradients on this axis (without
                     # any co-ordinate changes), treating all other axes as a batch axis.
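For context, here is a standalone sketch of the shapes this hunk deals with; batch_size, num_blocks and block_size below are made-up example values, not taken from the file:

import torch

# Hypothetical example values; the real ones come from the parameter
# shape and _get_block_size().
batch_size, num_blocks, block_size = 3, 2, 4

# Q starts as a stack of identity matrices, one per (batch, block) pair:
# shape (batch_size, num_blocks, block_size, block_size).
Q = torch.eye(block_size, block_size).unsqueeze(0).unsqueeze(0).expand(
    batch_size, num_blocks, block_size, block_size).contiguous()

# After this commit, param_cov also starts as the identity (Q.clone())
# rather than zeros, so the first averaging step blends with a sane prior.
param_cov = Q.clone()
assert param_cov.shape == (3, 2, 4, 4)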
@@ -318,6 +317,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                     state[f"grad_cov_{dim}"] = torch.zeros_like(Q)
 
+
 
     def _get_block_size(self, size: int, max_block_size: int) -> Tuple[int,int]:
         """
         Returns information about the block size for a block-diagonal structure
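The diff only shows the signature of _get_block_size. One plausible reading, sketched below purely as an assumption (neither the body nor the return order is shown in this commit), is that it picks the largest divisor of size not exceeding max_block_size, so the blocks tile the dimension exactly:

from typing import Tuple

def _get_block_size_sketch(size: int, max_block_size: int) -> Tuple[int, int]:
    # Hypothetical reconstruction: largest divisor of `size` that is
    # <= max_block_size, so num_blocks * block_size == size exactly.
    for block_size in range(min(size, max_block_size), 0, -1):
        if size % block_size == 0:
            return block_size, size // block_size
    return 1, size  # unreachable: block_size=1 always divides size

# e.g. size=768 with max_block_size=512 would give two 384x384 blocks:
print(_get_block_size_sketch(768, 512))  # (384, 2)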
@@ -387,10 +387,14 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                 # Updates delta.
                 self._step_scalar(group, p, state)
             else:
-                if step % param_cov_period == 0:
-                    self._update_param_cov(group, p, state)
                 if self._is_lr_update_step(group, state):
-                    self._update_lrs(group, p, state)
+                    self._update_param_cov(group, p, state)
+                    if step > state["last_param_scale_update"] * 1.1 and state["last_param_scale_update"] != state["zero_step"]:
+                        # Only update the parameter-dependent part of the learning
+                        # rate matrices at most every other time we reach here, and
+                        # less frequently than that later in training.
+                        self._update_param_scales(group, p, state)
+                    self._diagonalize_grad_cov(group, p, state)
                     self._zero_exp_avg_sq(state)
                 if step % grad_cov_period == 0:
                     self._update_grad_cov(group, p, state)
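The new gate `step > last_param_scale_update * 1.1` is what makes the parameter-dependent refreshes sparser over training. A standalone simulation (not optimizer code; the every-200-steps schedule of lr-update steps is an assumed stand-in for _is_lr_update_step):

# Refreshes fire only once the step count has grown by more than 10%
# since the last refresh.
last_update = 200
refresh_steps = []
for step in range(400, 20001, 200):   # assumed lr-update steps
    if step > last_update * 1.1:
        refresh_steps.append(step)
        last_update = step
print(refresh_steps[:5], refresh_steps[-3:])
# Early refreshes happen at nearly every lr-update step; late ones are
# spaced by roughly 10% of the current step count.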
@@ -467,10 +471,19 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         """
         eps = group["eps"]
         param_cov_period = group["param_cov_period"]
 
+        # zero_step is always the last time we called _update_param_cov.
+        # Our aim is to compute the parameter covariance averaged over all time
+        # (in steps) from the start, so "this_weight" that we give to this
+        # new covariance is proportional to the interval of time it represents,
+        # i.e. the interval from zero_step until step, while the existing "param_cov"
+        # represents the interval from step zero until zero_step.
+        # The min of 0.5 is a special case to ensure that the "torch.eye" we initialized
+        # param_cov with gets some weight on the 1st time we call this.
+        zero_step = state["zero_step"]
         step = state["step"]
 
-        this_weight = (group["param_cov_freshness"] * param_cov_period /
-                       (step + param_cov_period))
+        this_weight = min(0.5, (step - zero_step) / step)
 
         batch_size = p.shape[0]
         numel = p.numel() // batch_size
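A standalone check (a sketch, not optimizer code) that the new weight implements the flat average the comment describes: each chunk of steps ends up weighted in proportion to its length. The 0.5 cap only bites on the first call, reserving weight for the identity initialization of param_cov.

intervals = [(0, 100), (100, 200), (200, 300)]  # (zero_step, step) pairs
values = [1.0, 3.0, 5.0]                        # stand-ins for this_param_cov

avg = values[0]                                 # state after the first chunk
for (zero_step, step), v in list(zip(intervals, values))[1:]:
    this_weight = min(0.5, (step - zero_step) / step)
    avg = (1 - this_weight) * avg + this_weight * v

# Flat time-average of the same data, for comparison:
flat = sum(v * (b - a) for (a, b), v in zip(intervals, values)) / 300
print(avg, flat)  # 3.0 3.0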
@@ -495,9 +508,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             # normalize scale of this param_cov, in case parameter scale
             # changes significantly during training, which would cause some
             # parts of the training timeline to be more highly weighted.
-            # shape of this_param_cov
-            # expression after /= has shape (batch_size, num_blocks, 1, 1)
-            this_param_cov /= _diag(this_param_cov).mean(dim=[0,1], keepdim=True).unsqueeze(-1) + eps
+            # shape of expression on r.h.s of "/=" has shape (batch_size, 1, 1, 1)
+            this_param_cov /= _mean(_diag(this_param_cov),
+                                    exclude_dims=[0], keepdim=True).unsqueeze(-1) + eps
 
             param_cov.mul_(1-this_weight).add_(this_param_cov,
                                                alpha=this_weight)
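This is the bug fix: the old normalizer averaged over the batch dimension too, producing one shared scale, while the new `_mean(..., exclude_dims=[0])` keeps one scale per batch element. The diff only shows that `_mean` takes exclude_dims and keepdim; `_mean_sketch` below is an assumed stand-in for it:

import torch
from torch import Tensor
from typing import List

def _mean_sketch(x: Tensor, exclude_dims: List[int], keepdim: bool) -> Tensor:
    # Mean over all dims except those listed in exclude_dims.
    dims = [d for d in range(x.ndim) if d not in exclude_dims]
    return x.mean(dim=dims, keepdim=keepdim)

diag = torch.rand(3, 2, 4)  # (batch_size, num_blocks, block_size)
# Old (buggy) normalizer: averages over the batch dim -> one shared scale.
old = diag.mean(dim=[0, 1], keepdim=True).unsqueeze(-1)
# New normalizer: one scale per batch element.
new = _mean_sketch(diag, exclude_dims=[0], keepdim=True).unsqueeze(-1)
print(old.shape, new.shape)  # torch.Size([1, 1, 4, 1]) torch.Size([3, 1, 1, 1])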
@@ -516,7 +529,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                             state: dict) -> bool:
         """
         Returns True if on this step we need to update the learning-rate matrices for this tensor
-        and False if not.  The periodicity with which we update them increases from
+        and False if not (note: this may just mean we update the gradient-dependent part).
+        The periodicity with which we update them increases from
         (by default) 200 at the start of training to 2000 later on.
         """
         try:
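To make the docstring's "200 at the start of training to 2000 later on" concrete, here is a hypothetical schedule; the linear ramp and the ramp_steps constant are assumptions for illustration, not the optimizer's actual formula:

def cur_update_period_sketch(step: int,
                             start_period: int = 200,
                             end_period: int = 2000,
                             ramp_steps: int = 20000) -> int:
    # Interpolate the update period from start_period to end_period.
    frac = min(1.0, step / ramp_steps)
    return int(start_period + frac * (end_period - start_period))

for s in (0, 5000, 10000, 40000):
    print(s, cur_update_period_sketch(s))  # 200, 650, 1100, 2000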
@@ -534,12 +548,15 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         return (step >= zero_step + cur_update_period)
 
 
-    def _update_lrs(self,
+    def _update_param_scales(self,
                     group: dict,
                     p: Tensor,
                     state: dict) -> None:
         """
-        Computes learning-rate matrices Q for each dim of this tensor.
+        Computes learning-rate matrices Q for each dim of this tensor: only the part that depends
+        on the parameter covariance; we will later add a rotation that depends on the gradient
+        covariance.
+
         Args:
           group: dict to look up configuration values
           p: parameter matrix that we are updating.  The learning rate matrices
@@ -704,8 +721,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             # want to multiply on the diagonalized co-ordinate.
             # else: Q is indexed [batch_index, canonical_coordinate].
             state[f"Q_{dim}"] *= scale
-
-        self._diagonalize_grad_cov(group, p, state)
+        state["last_param_scale_update"] = state["step"]
 
     def _diagonalize_grad_cov(self,
                               group: dict,
@@ -1033,7 +1049,8 @@ def _move_dim(x: Tensor, orig_dim: int, new_dim: int) -> Tensor:
 def _diag(x: Tensor):
     """
     like torch diag(), but supports batch dim, i.e. input of shape (B, M, M) returns
-    output of shape (B, M)
+    output of shape (B, M), or input of shape (A, B, M, M) returns output of shape
+    (A, B, M)
     """
     if x.ndim == 3:
         (B, M, M2) = x.shape
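One way to get the extended behaviour the docstring now promises, shown as an assumption rather than the file's actual implementation (the diff cuts off before the 4-dim branch), is torch.diagonal over the last two axes, which handles any number of leading batch dims:

import torch
from torch import Tensor

def _diag_sketch(x: Tensor) -> Tensor:
    # Diagonal of the trailing MxM block; result gets M as the last dim.
    return torch.diagonal(x, dim1=-2, dim2=-1)

print(_diag_sketch(torch.rand(5, 3, 3)).shape)     # torch.Size([5, 3])
print(_diag_sketch(torch.rand(2, 5, 3, 3)).shape)  # torch.Size([2, 5, 3])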
@@ -1793,7 +1810,7 @@ def _test_eve_cain():
             if epoch == 0 and n == 0:
                 avg_loss = loss.item()
             else:
-                avg_loss = 0.95 * avg_loss + 0.05 * loss.item()
+                avg_loss = 0.98 * avg_loss + 0.02 * loss.item()
             if n == 0 and epoch % 5 == 0:
                 #norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
                 #norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
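The test's running loss is an exponential moving average, so raising the decay from 0.95 to 0.98 lengthens the effective averaging window from roughly 1/(1-0.95) = 20 minibatches to 1/(1-0.98) = 50, giving a smoother reported curve. A standalone illustration (not part of the test):

def ema(values, decay):
    # Same update rule as the test: avg = decay * avg + (1 - decay) * v.
    avg = values[0]
    for v in values[1:]:
        avg = decay * avg + (1 - decay) * v
    return avg

noisy = [1.0 if i % 2 else 3.0 for i in range(200)]  # alternating "losses"
print(round(ema(noisy, 0.95), 3), round(ema(noisy, 0.98), 3))  # both near 2.0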
|
|||||||
Loading…
x
Reference in New Issue
Block a user