diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 2886b17d3..974a61353 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -116,9 +116,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                      param covariance equals the dimension of the covaraince matrix.
                      param_rms_smooth{0,1} determine the smoothing proportions for other
                      conditions.
-    param_cov_freshness: Constant that determines how "recent" the parameter covariance
-                     matrix is.  1.0 corresponds to flat average over time; 2.0
-                     would mean weighting with a quadratic function, etc.
     eps: An epsilon to prevent division by zero
     param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
                    learning the scale on the parameters (we'll keep it >= this size)
@@ -133,8 +130,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                      later in training.  One step of updating the learning rate matrices can
                      take as long as over 50 minibatches, because SVD on GPU is slow.
                      ** This is important for the speed/optimizaton tradeoff. **
-    param_cov_period: The periodicity, in steps, with which we update the parameter covariance
-                     stats.
     max_block_size: The maximum block size in block-diagonal co-ordinate transformations.
    """
    def __init__(
@@ -286,6 +281,10 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             # "zero_step" being a member of state is the sign that this parameter has
             # at least one dim that has a projection.
             state["zero_step"] = 0
+            # last_param_scale_update records the last time we updated the part of the learning rate
+            # matrices that relates to the parameter covariance; we avoid doing this too often
+            # as it doesn't change very rapidly and the SVD is quite slow.
+            state["last_param_scale_update"] = -1
 
             for dim in range(1, p.ndim):
                 size = p.shape[dim]
@@ -301,9 +300,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                 Q = torch.eye(block_size, block_size, **kwargs).unsqueeze(0).unsqueeze(0).expand(
                     batch_size, num_blocks, block_size, block_size).contiguous()
                 state[f"Q_{dim}"] = Q
-                # param_cov_{dim} is the averaged-over-time gradient of parameters on this dimension, treating
-                # all other dims as a batch axis.
-                state[f"param_cov_{dim}"] = torch.zeros_like(Q)
+                # param_cov_{dim} is the averaged-over-time covariance of parameters on this dimension, treating
+                # all other dims as a batch axis.  Also initialize as identity.
+                state[f"param_cov_{dim}"] = Q.clone()
 
                 # grad_cov_{dim} is the covariance of gradients on this axis (without
                 # any co-ordinate changes), treating all other axes as as a batch axis.
@@ -318,6 +317,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                 state[f"grad_cov_{dim}"] = torch.zeros_like(Q)
 
 
+
     def _get_block_size(self, size: int, max_block_size: int) -> Tuple[int,int]:
         """
         Returns information about the block size for a block-diagonal structure
@@ -387,10 +387,14 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                 # Updates delta.
                 self._step_scalar(group, p, state)
             else:
-                if step % param_cov_period == 0:
-                    self._update_param_cov(group, p, state)
                 if self._is_lr_update_step(group, state):
-                    self._update_lrs(group, p, state)
+                    self._update_param_cov(group, p, state)
+                    if step > state["last_param_scale_update"] * 1.1 and state["last_param_scale_update"] != state["zero_step"]:
+                        # Only update the parameter-dependent part of the learning
+                        # rate matrices at most every other time we reach here, and
+                        # less frequently than that later in training.
+                        self._update_param_scales(group, p, state)
+                    self._diagonalize_grad_cov(group, p, state)
                     self._zero_exp_avg_sq(state)
                 if step % grad_cov_period == 0:
                     self._update_grad_cov(group, p, state)
@@ -467,10 +471,19 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         """
         eps = group["eps"]
         param_cov_period = group["param_cov_period"]
+
+        # zero_step is always the last time we called _update_param_cov.
+        # Our aim is to compute the parameter covariance averaged over all time
+        # (in steps) from the start, so "this_weight" that we give to this
+        # new covariance is proportional to the interval of time it represents,
+        # i.e. the interval from zero_step until step, while the existing "param_cov"
+        # represents the interval from the start until zero_step.
+        # The min of 0.5 is a special case to ensure that the "torch.eye" we initialized
+        # param_cov with gets some weight on the 1st time we call this.
+        zero_step = state["zero_step"]
         step = state["step"]
-        this_weight = (group["param_cov_freshness"] * param_cov_period /
-                       (step + param_cov_period))
+        this_weight = min(0.5, (step - zero_step) / step)
 
         batch_size = p.shape[0]
         numel = p.numel() // batch_size
 
@@ -495,9 +508,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             # normalize scale of this param_cov, in case parameter scale
             # changes significantly during training, which would cause some
             # parts of the training timeline to be more highly weighted.
-            # shape of this_param_cov
-            # expression after /= has shape (batch_size, num_blocks, 1, 1)
-            this_param_cov /= _diag(this_param_cov).mean(dim=[0,1], keepdim=True).unsqueeze(-1) + eps
+            # expression on r.h.s. of "/=" has shape (batch_size, 1, 1, 1)
+            this_param_cov /= _mean(_diag(this_param_cov),
+                                    exclude_dims=[0], keepdim=True).unsqueeze(-1) + eps
 
             param_cov.mul_(1-this_weight).add_(this_param_cov, alpha=this_weight)
 
@@ -516,7 +529,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         """
         Returns True if on this step we need to update the learning-rate matrices for this tensor
-        and False if not.  The periodicity with which we update them increases from
+        and False if not (note: this may just mean we update the gradient-dependent part).
+        The periodicity with which we update them increases from
         (by default) 200 at the start of training to 2000 later on.
         """
         try:
@@ -534,12 +548,15 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         return (step >= zero_step + cur_update_period)
 
 
-    def _update_lrs(self,
-                    group: dict,
-                    p: Tensor,
-                    state: dict) -> None:
+    def _update_param_scales(self,
+                             group: dict,
+                             p: Tensor,
+                             state: dict) -> None:
         """
-        Computes learning-rate matrices Q for each dim of this tensor.
+        Computes learning-rate matrices Q for each dim of this tensor: only the part that depends
+        on the parameter covariance; we will later add a rotation that depends on the gradient
+        covariance.
+
         Args:
           group: dict to look up configuration values
           p: parameter matrix that we are updating.  The learning rate matrices
@@ -704,8 +721,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                 # want to multiply on the diagonalized co-ordinate.
                 # else: Q is indexed [batch_index, canonical_coordinate].
                 state[f"Q_{dim}"] *= scale
-
-        self._diagonalize_grad_cov(group, p, state)
+        state["last_param_scale_update"] = state["step"]
 
     def _diagonalize_grad_cov(self,
                               group: dict,
@@ -1033,7 +1049,8 @@ def _move_dim(x: Tensor, orig_dim: int, new_dim: int) -> Tensor:
 def _diag(x: Tensor):
     """
     like torch diag(), but supports batch dim, i.e. input of shape (B, M, M) returns
-    output of shape (B, M)
+    output of shape (B, M), or input of shape (A, B, M, M) returns output of shape
+    (A, B, M)
     """
     if x.ndim == 3:
         (B, M, M2) = x.shape
@@ -1793,7 +1810,7 @@ def _test_eve_cain():
             if epoch == 0 and n == 0:
                 avg_loss = loss.item()
             else:
-                avg_loss = 0.95 * avg_loss + 0.05 * loss.item()
+                avg_loss = 0.98 * avg_loss + 0.02 * loss.item()
             if n == 0 and epoch % 5 == 0:
                 #norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
                 #norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
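
Note on the new weighting in _update_param_cov: the standalone sketch below (not part of
the patch; the step values and statistics are made-up numbers) shows why
this_weight = (step - zero_step) / step gives a flat average over time, with each set of
stats weighted by the number of steps in the interval it represents.  The min(0.5, ...)
guard and the identity initialization of param_cov are omitted here for clarity.

# Standalone illustration of the recurrence
#     param_cov <- (1 - w) * param_cov + w * this_param_cov,  w = (step - zero_step) / step
update_steps = [200, 500, 1200, 3000]   # hypothetical steps at which _update_param_cov runs
stats = [1.0, 2.0, 4.0, 8.0]            # hypothetical per-interval statistics (scalars here)

avg = 0.0
zero_step = 0
for step, x in zip(update_steps, stats):
    w = (step - zero_step) / step
    avg = (1.0 - w) * avg + w * x
    zero_step = step

# Direct flat average, weighting each measurement by the interval of steps it covers:
flat = sum((s - s0) * x for s, s0, x in
           zip(update_steps, [0] + update_steps[:-1], stats)) / update_steps[-1]

print(avg, flat)   # both print 6.0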
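
Note on the reshaped normalization in _update_param_cov: it uses the batched _diag()
helper (extended by this patch to 4-D input) and a _mean() helper that does not appear in
this diff.  The sketch below uses hypothetical stand-ins for both, purely to illustrate
the shapes: the divisor comes out as (batch_size, 1, 1, 1), so each tensor in the batch is
normalized by the mean of its own diagonal.

import torch

def _diag(x: torch.Tensor) -> torch.Tensor:
    # diagonal of the last two dims: (B, M, M) -> (B, M), (A, B, M, M) -> (A, B, M)
    return x.diagonal(dim1=-2, dim2=-1)

def _mean(x: torch.Tensor, exclude_dims, keepdim: bool) -> torch.Tensor:
    # assumed semantics: mean over all dims except those listed in exclude_dims
    dims = [d for d in range(x.ndim) if d not in exclude_dims]
    return x.mean(dim=dims, keepdim=keepdim)

batch_size, num_blocks, block_size = 2, 3, 4
eps = 1.0e-20
this_param_cov = torch.rand(batch_size, num_blocks, block_size, block_size)

denom = _mean(_diag(this_param_cov), exclude_dims=[0], keepdim=True).unsqueeze(-1) + eps
print(denom.shape)                      # torch.Size([2, 1, 1, 1])
this_param_cov = this_param_cov / denom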