mirror of https://github.com/k2-fsa/icefall.git
Update parameter dependent part of cov more slowly, plus bug fix.

commit 8db3b48edb (parent 198cf2635c)
@@ -116,9 +116,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                            param covariance equals the dimension of the covariance matrix.
               param_rms_smooth{0,1} determine the smoothing proportions for other
                            conditions.
-       param_cov_freshness: Constant that determines how "recent" the parameter covariance
-                            matrix is.  1.0 corresponds to flat average over time; 2.0
-                            would mean weighting with a quadratic function, etc.
        eps:  An epsilon to prevent division by zero
        param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
                    learning the scale on the parameters (we'll keep it >= this size)
@@ -133,8 +130,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                    later in training.  One step of updating the learning rate matrices
                    can take as long as over 50 minibatches, because SVD on GPU is slow.
                    ** This is important for the speed/optimization tradeoff. **
        param_cov_period: The periodicity, in steps, with which we update the parameter covariance
                    stats.
        max_block_size: The maximum block size in block-diagonal co-ordinate transformations.
     """
     def __init__(
@@ -286,6 +281,10 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         # "zero_step" being a member of state is the sign that this parameter has
         # at least one dim that has a projection.
         state["zero_step"] = 0
+        # last_param_scale_update records the last time we updated the part of the learning rate
+        # matrices that relates to the parameter covariance; we avoid doing this too often
+        # as it doesn't change very rapidly and the SVD is quite slow.
+        state["last_param_scale_update"] = -1

         for dim in range(1, p.ndim):
             size = p.shape[dim]
@@ -301,9 +300,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                 Q = torch.eye(block_size, block_size, **kwargs).unsqueeze(0).unsqueeze(0).expand(
                     batch_size, num_blocks, block_size, block_size).contiguous()
                 state[f"Q_{dim}"] = Q
-                # param_cov_{dim} is the averaged-over-time gradient of parameters on this dimension, treating
-                # all other dims as a batch axis.
-                state[f"param_cov_{dim}"] = torch.zeros_like(Q)
+                # param_cov_{dim} is the averaged-over-time covariance of parameters on this dimension, treating
+                # all other dims as a batch axis.  Also initialize as identity.
+                state[f"param_cov_{dim}"] = Q.clone()

                 # grad_cov_{dim} is the covariance of gradients on this axis (without
                 # any co-ordinate changes), treating all other axes as a batch axis.
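A quick shape check of the identity initialization above, illustrative only and not part of this commit, using hypothetical sizes batch_size=2, num_blocks=3, block_size=4:

    import torch

    batch_size, num_blocks, block_size = 2, 3, 4
    # Build a (batch_size, num_blocks, block_size, block_size) stack of identity blocks,
    # the same shape Q_{dim} is given above.
    Q = torch.eye(block_size, block_size).unsqueeze(0).unsqueeze(0).expand(
        batch_size, num_blocks, block_size, block_size).contiguous()
    print(Q.shape)   # torch.Size([2, 3, 4, 4])
    # The bug fix initializes param_cov_{dim} as a copy of this identity (Q.clone())
    # rather than zeros, so the smoothed parameter covariance starts well-conditioned.
    param_cov = Q.clone()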
@@ -318,6 +317,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                 state[f"grad_cov_{dim}"] = torch.zeros_like(Q)


+
     def _get_block_size(self, size: int, max_block_size: int) -> Tuple[int,int]:
         """
         Returns information about the block size for a block-diagonal structure
@@ -387,10 +387,14 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             # Updates delta.
             self._step_scalar(group, p, state)
         else:
-            if step % param_cov_period == 0:
-                self._update_param_cov(group, p, state)
             if self._is_lr_update_step(group, state):
-                self._update_lrs(group, p, state)
+                self._update_param_cov(group, p, state)
+                if step > state["last_param_scale_update"] * 1.1 and state["last_param_scale_update"] != state["zero_step"]:
+                    # Only update the parameter-dependent part of the learning
+                    # rate matrices at most every other time we reach here, and
+                    # less frequently than that later in training.
+                    self._update_param_scales(group, p, state)
+                self._diagonalize_grad_cov(group, p, state)
+                self._zero_exp_avg_sq(state)
             if step % grad_cov_period == 0:
                 self._update_grad_cov(group, p, state)
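The gate step > last_param_scale_update * 1.1 is what makes the parameter-dependent (SVD-heavy) update sparser as training proceeds. A toy simulation of just that clause, illustrative only and not part of the commit, ignoring the zero_step check and assuming this branch is reached every 200 steps:

    # Each accepted update must be at least 10% later (in step count) than the
    # previous one, so updates are dense early on and roughly geometric later.
    last_update = 200            # hypothetical step of the first scale update
    updates = [last_update]
    for step in range(400, 20001, 200):
        if step > last_update * 1.1:
            last_update = step
            updates.append(step)
    print(updates[:12])
    # [200, 400, 600, 800, 1000, 1200, 1400, 1600, 1800, 2000, 2400, 2800]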
@@ -467,10 +471,19 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         """
         eps = group["eps"]
         param_cov_period = group["param_cov_period"]

+        # zero_step is always the last time we called _update_param_cov.
+        # Our aim is to compute the parameter covariance averaged over all time
+        # (in steps) from the start, so "this_weight" that we give to this
+        # new covariance is proportional to the interval of time it represents,
+        # i.e. the interval from zero_step until step, while the existing "param_cov"
+        # represents the interval from step 0 until zero_step.
+        # The min of 0.5 is a special case to ensure that the "torch.eye" we initialized
+        # param_cov with gets some weight on the 1st time we call this.
+        zero_step = state["zero_step"]
         step = state["step"]

-        this_weight = (group["param_cov_freshness"] * param_cov_period /
-                       (step + param_cov_period))
+        this_weight = min(0.5, (step - zero_step) / step)

         batch_size = p.shape[0]
         numel = p.numel() // batch_size
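Worked numbers for the new weighting, illustrative only (param_cov_period=100 is a hypothetical value): on the first call the raw ratio is 1.0 and min() clamps it to 0.5, so the identity initialization keeps half the weight; afterwards the weight is roughly param_cov_period / step, i.e. a flat average over training so far.

    param_cov_period = 100                       # hypothetical update period
    for step in [100, 200, 1000, 10000]:
        # last covariance update; 0 on the very first call after initialization
        zero_step = 0 if step == 100 else step - param_cov_period
        this_weight = min(0.5, (step - zero_step) / step)
        print(step, this_weight)
    # 100 0.5
    # 200 0.5
    # 1000 0.1
    # 10000 0.01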
@@ -495,9 +508,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             # normalize scale of this param_cov, in case parameter scale
             # changes significantly during training, which would cause some
             # parts of the training timeline to be more highly weighted.
-            # shape of this_param_cov
-            # expression after /= has shape (batch_size, num_blocks, 1, 1)
-            this_param_cov /= _diag(this_param_cov).mean(dim=[0,1], keepdim=True).unsqueeze(-1) + eps
+            # shape of expression on r.h.s of "/=" has shape (batch_size, 1, 1, 1)
+            this_param_cov /= _mean(_diag(this_param_cov),
+                                    exclude_dims=[0], keepdim=True).unsqueeze(-1) + eps

             param_cov.mul_(1-this_weight).add_(this_param_cov,
                                                alpha=this_weight)
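A rough shape check of the new per-tensor normalization, illustrative only: it assumes _diag() takes the diagonal of the trailing (M, M) dims and that _mean(..., exclude_dims=[0], keepdim=True) averages over every dim except dim 0, so the divisor ends up with shape (batch_size, 1, 1, 1) and each tensor in the batch is normalized independently.

    import torch

    batch_size, num_blocks, block_size = 2, 3, 4
    this_param_cov = torch.randn(batch_size, num_blocks, block_size, block_size) ** 2
    diag = torch.diagonal(this_param_cov, dim1=-2, dim2=-1)      # (2, 3, 4)
    denom = diag.mean(dim=[1, 2], keepdim=True).unsqueeze(-1)    # (2, 1, 1, 1)
    normalized = this_param_cov / (denom + 1e-8)
    print(denom.shape)   # torch.Size([2, 1, 1, 1])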
@@ -516,7 +529,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                              state: dict) -> bool:
         """
         Returns True if on this step we need to update the learning-rate matrices for this tensor
-        and False if not.  The periodicity with which we update them increases from
+        and False if not (note: this may just mean we update the gradient-dependent part).
+        The periodicity with which we update them increases from
         (by default) 200 at the start of training to 2000 later on.
         """
         try:
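The exact schedule lives in the (truncated) body of _is_lr_update_step; purely as an illustration of the behaviour the docstring describes, one way to grow an update period from 200 to 2000 over training is geometric interpolation on the step count (toy_update_period and ramp_steps below are hypothetical, not names from the repository):

    def toy_update_period(step, start=200, end=2000, ramp_steps=50000):
        # Fraction of the ramp completed, clamped to [0, 1].
        frac = min(1.0, step / ramp_steps)
        # Geometric interpolation between the starting and final period.
        return int(start * (end / start) ** frac)

    for s in [0, 10000, 25000, 50000, 100000]:
        print(s, toy_update_period(s))
    # 0 200, 10000 316, 25000 632, 50000 2000, 100000 2000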
@@ -534,12 +548,15 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         return (step >= zero_step + cur_update_period)


-    def _update_lrs(self,
-                    group: dict,
-                    p: Tensor,
-                    state: dict) -> None:
+    def _update_param_scales(self,
+                             group: dict,
+                             p: Tensor,
+                             state: dict) -> None:
         """
-        Computes learning-rate matrices Q for each dim of this tensor.
+        Computes learning-rate matrices Q for each dim of this tensor: only the part that depends
+        on the parameter covariance; we will later add a rotation that depends on the gradient
+        covariance.

         Args:
           group: dict to look up configuration values
           p: parameter matrix that we are updating.  The learning rate matrices
@@ -704,8 +721,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             # want to multiply on the diagonalized co-ordinate.
             # else: Q is indexed [batch_index, canonical_coordinate].
             state[f"Q_{dim}"] *= scale

-        self._diagonalize_grad_cov(group, p, state)
+        state["last_param_scale_update"] = state["step"]

     def _diagonalize_grad_cov(self,
                               group: dict,
@@ -1033,7 +1049,8 @@ def _move_dim(x: Tensor, orig_dim: int, new_dim: int) -> Tensor:
 def _diag(x: Tensor):
     """
     like torch diag(), but supports batch dim, i.e. input of shape (B, M, M) returns
-    output of shape (B, M)
+    output of shape (B, M), or input of shape (A, B, M, M) returns output of shape
+    (A, B, M)
     """
     if x.ndim == 3:
         (B, M, M2) = x.shape
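A minimal sketch of a batch-aware diagonal, illustrative only and not the repository's implementation (batch_diag is a hypothetical name): torch.diagonal over the trailing two dims covers both the (B, M, M) and (A, B, M, M) cases the docstring describes.

    import torch

    def batch_diag(x: torch.Tensor) -> torch.Tensor:
        # Take the diagonal of the last two (square) dims, keeping leading batch dims.
        assert x.shape[-1] == x.shape[-2]
        return torch.diagonal(x, dim1=-2, dim2=-1)

    print(batch_diag(torch.eye(3).expand(2, 3, 3)).shape)     # torch.Size([2, 3])
    print(batch_diag(torch.eye(3).expand(4, 2, 3, 3)).shape)  # torch.Size([4, 2, 3])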
@@ -1793,7 +1810,7 @@ def _test_eve_cain():
             if epoch == 0 and n == 0:
                 avg_loss = loss.item()
             else:
-                avg_loss = 0.95 * avg_loss + 0.05 * loss.item()
+                avg_loss = 0.98 * avg_loss + 0.02 * loss.item()
             if n == 0 and epoch % 5 == 0:
                 #norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
                 #norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
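Side note on the decay change above, illustrative only: an exponential moving average with decay d averages over roughly 1 / (1 - d) recent values, so moving from 0.95 to 0.98 lengthens the effective window from about 20 minibatches to about 50, giving a smoother logged loss.

    for d in (0.95, 0.98):
        print(d, round(1.0 / (1.0 - d)))   # effective averaging window, in minibatches
    # 0.95 20
    # 0.98 50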