From 4da4e69fba5be901edac3b11e9be98612240c757 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Fri, 22 Jul 2022 06:37:20 +0800
Subject: [PATCH] Draft of new way of smoothing param_rms, diagonalized by grad

---
 .../ASR/pruned_transducer_stateless7/optim.py | 102 +++++++++++++++++-
 1 file changed, 97 insertions(+), 5 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 15be2ba06..af4b1a1b9 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -142,10 +142,10 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         lr=3e-02,
         betas=(0.9, 0.98),
         size_lr_scale=0.1,
-        param_pow=0.75,
+        param_pow=1.0,
         param_rms_smooth0=0.75,
         param_rms_smooth1=0.25,
-        max_lr_factor=4.0,
+        max_lr_factor=10.0,
         eps=1.0e-08,
         param_min_rms=1.0e-05,
         param_max_rms=2.0,
@@ -593,10 +593,11 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             size = p.shape[dim]
             try:
                 Q = state[f"Q_{dim}"]
-                param_cov = state[f"param_cov_{dim}"]
             except KeyError:
                 assert size == 1 or size == numel, size
-                continue  # e.g. size == 1 or size == numel:
+                continue  # e.g. size == 1 or size == numel
+
+            param_cov = self._get_smoothed_param_cov(group, p, state, dim)

             # param_cov has the same shape as Q
             (batch_size, num_blocks, block_size, block_size) = Q.shape
@@ -683,7 +684,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                 # rms shape: (batch_size, 1, size, 1, 1)
                 rms = _mean(cur_p**2, exclude_dims=[0, dim], keepdim=True).sqrt()
                 rank = numel // size
-                smoothed_rms = self._smooth_param_rms(group, rms, rank)
+                # we did other kinds of smoothing in _get_smoothed_param_cov
+                # smoothed_rms = self._smooth_param_rms(group, rms, rank)
+                smoothed_rms = rms ** group["param_pow"]
                 cur_scales[dim] = smoothed_rms
                 cur_p /= smoothed_rms  # normalize/"whiten" cur_p on this dim..
@@ -727,6 +730,95 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                     state[f"Q_{dim}"] *= scale
             state["last_param_scale_update"] = state["step"]

+
+    def _get_smoothed_param_cov(self,
+                                group: dict,
+                                p: Tensor,
+                                state: dict,
+                                dim: int) -> Tensor:
+        """
+        This function returns a modified/smoothed version of the parameter covariance
+        state[f"param_cov_{dim}"], which is an estimate of the covariance of the
+        parameter p, averaged over time and taken over dimension `dim` of the tensor.
+
+        The smoothing done here limits the extent to which the parameter covariance
+        can be strongly "off-diagonal" with respect to the gradient covariance.  That
+        is: if the parameter covariance is just the gradient covariance to some power,
+        this function does no smoothing; but if it is highly off-diagonal we do more
+        smoothing.
+        """
+        param_cov = state[f"param_cov_{dim}"]  # (batch_size, num_blocks, block_size, block_size)
+        grad_cov = state[f"grad_cov_{dim}"]  # (batch_size, num_blocks, block_size, block_size)
+        (batch_size, num_blocks, block_size, block_size) = param_cov.shape
+        # U_g diagonalizes grad_cov, in the sense that U_g^T grad_cov U_g is diagonal.
+        U_g, _, _ = _svd(grad_cov)
+
+        # param_cov_proj is param_cov in a different orthonormal basis, one that
+        # diagonalizes grad_cov.
+        param_cov_proj = torch.matmul(U_g.transpose(2, 3), torch.matmul(param_cov, U_g))
+
+        # param_cov_eps is probably not critical; we don't expect to see super
+        # small values.  Apply it as a floor in case roundoff causes negative values.
+        param_cov_eps = 1.0e-05
+        param_rms = _diag(param_cov_proj).clamp_(min=param_cov_eps).sqrt()
+        param_cov_inv_scale = param_rms.unsqueeze(-1) * param_rms.unsqueeze(-2)
+
+        # param_cov_norm should have diagonal values close to 1.0 (only not
+        # exactly 1.0 due to param_cov_eps and roundoff)
+        param_cov_norm = param_cov_proj / param_cov_inv_scale
+
+        # OK, this is where we do smoothing.  Decompose param_cov_norm, which
+        # is symmetric, as U_p S U_p^T.
+        U_p, S, _ = _svd(param_cov_norm)
+
+        residual_rms = S.sqrt()
+
+        relative_rms_pow = 0.7
+        relative_rms_max = 4.0
+
+        residual_rms = residual_rms ** relative_rms_pow
+        residual_rms /= _mean(residual_rms, exclude_dims=[0], keepdim=True)
+
+        if True:
+            # smooth according to the rank of the observation..
+            size = p.shape[dim]
+            rank = p.numel() // (size * batch_size)
+            smooth0 = group["param_rms_smooth0"]
+            smooth1 = group["param_rms_smooth1"]
+            # We want an expression of the form: smooth = alpha * size / (beta*rank + size).
+            # From rank==0, we get smooth0 = alpha * size/size, so alpha = smooth0.
+            # From setting rank==size, we get smooth1 = alpha * size / (beta*size + size)
+            # = alpha/(1+beta), so smooth1 == smooth0 / (1+beta), so (1+beta) = smooth0/smooth1,
+            # so beta = smooth0/smooth1 - 1.
+            smooth = smooth0 * size / ((smooth0 / smooth1 - 1) * rank + size)
+
+            mean = _mean(residual_rms, exclude_dims=[0], keepdim=True)
+            residual_rms += group["eps"] + smooth * mean
+            residual_rms = residual_rms / _mean(residual_rms, exclude_dims=[0], keepdim=True)
+
+        # Apply the maximum via a softmin function, softmin(x,y) = 1/(1/x + 1/y).
+        residual_rms = 1. / (1. / residual_rms + 1. / relative_rms_max)
+
+        if random.random() < 0.1:
+            skip = 10 if S.shape[-1] > 40 else 1
+            logging.info(f"Smoothed param_rms from {S.sqrt()[0,0,::skip]} to "
+                         f"{residual_rms[0,0,::skip]}, param_rms={param_rms[0,0,::skip]}")
+
+        # U_p shape: (batch_size, num_blocks, block_size, block_size);
+        # residual_rms shape: (batch_size, num_blocks, block_size).
+        # In terms of matrix multiplication, we are computing
+        # X_p = matmul(U_p, residual_rms.diag()).
+        X_p = U_p * residual_rms.unsqueeze(-2)
+        param_cov_norm_smoothed = torch.matmul(X_p, X_p.transpose(2, 3))
+
+        # Undo the scaling by the diagonal of param_cov.
+        param_cov_proj_smoothed = param_cov_norm_smoothed * param_cov_inv_scale
+
+        # Undo the projection by U_g.
+        param_cov_smoothed = torch.matmul(U_g, torch.matmul(param_cov_proj_smoothed,
+                                                            U_g.transpose(2, 3)))
+        return param_cov_smoothed
+
     def _diagonalize_grad_cov(self,
                               group: dict,
                               p: Tensor,
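
As a quick check of the rank-dependent smoothing proportion derived in the
comments inside _get_smoothed_param_cov above, here is a minimal standalone
sketch.  The function name smoothing_proportion is ours, not part of the
patch; the defaults mirror param_rms_smooth0=0.75 and param_rms_smooth1=0.25.
It should give smooth0 at rank == 0 and smooth1 at rank == size:

    def smoothing_proportion(size: int, rank: int,
                             smooth0: float = 0.75,
                             smooth1: float = 0.25) -> float:
        # smooth = alpha * size / (beta * rank + size), with alpha = smooth0
        # and beta = smooth0 / smooth1 - 1, so that smooth == smooth0 at
        # rank == 0 and smooth == smooth1 at rank == size.
        beta = smooth0 / smooth1 - 1
        return smooth0 * size / (beta * rank + size)

    assert abs(smoothing_proportion(512, 0) - 0.75) < 1e-9
    assert abs(smoothing_proportion(512, 512) - 0.25) < 1e-9
    print(smoothing_proportion(512, 128))  # -> 0.5, between the two endpoints

Intermediate ranks interpolate monotonically between the two endpoints, which
is the intended behavior: lower-rank observations get proportionally more
smoothing.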
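
For readers who want the shape of the whole smoothing operation without the
batch/block bookkeeping, below is a self-contained sketch on plain 2-D
symmetric matrices.  It is an illustration, not the patch's implementation:
smooth_param_cov is our name, torch.linalg.eigh stands in for the patch's
_svd helper (both yield an orthonormal eigenbasis for a symmetric matrix),
and the rank-dependent term from the `if True:` block is omitted for brevity.

    import torch

    def smooth_param_cov(param_cov: torch.Tensor,
                         grad_cov: torch.Tensor,
                         rms_pow: float = 0.7,   # mirrors relative_rms_pow
                         rms_max: float = 4.0,   # mirrors relative_rms_max
                         eps: float = 1.0e-05) -> torch.Tensor:
        # Rotate param_cov into the orthonormal basis that diagonalizes grad_cov.
        _, U_g = torch.linalg.eigh(grad_cov)
        proj = U_g.t() @ param_cov @ U_g
        # Normalize so the diagonal is close to 1; the floor guards against roundoff.
        rms = proj.diagonal().clamp(min=eps).sqrt()
        inv_scale = rms.unsqueeze(-1) * rms.unsqueeze(-2)
        norm = proj / inv_scale
        # Smooth the eigenvalues of the normalized covariance: sqrt to get rms
        # values, apply a power < 1, renormalize, then cap via
        # softmin(x, y) = 1 / (1/x + 1/y).
        S, U_p = torch.linalg.eigh(norm)
        r = S.clamp(min=0).sqrt() ** rms_pow
        r = r / r.mean()
        r = 1.0 / (1.0 / r + 1.0 / rms_max)
        # Reconstruct X_p = U_p @ diag(r), square it, then undo the diagonal
        # scaling and the rotation.
        X_p = U_p * r.unsqueeze(-2)
        smoothed = (X_p @ X_p.t()) * inv_scale
        return U_g @ smoothed @ U_g.t()

The net effect is that the eigenvalue ratios of the normalized parameter
covariance are compressed (power 0.7) and bounded (softmin against 4.0), so
the result can never be too strongly off-diagonal relative to the gradient
covariance, which is the property the docstring above asks for.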