From dd10eb140ff816e55041386672855728b06b2c3f Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sat, 23 Jul 2022 06:32:56 +0800
Subject: [PATCH] First version after refactorization and changing the math,
 where optim.py runs

---
 .../ASR/pruned_transducer_stateless7/optim.py | 408 +++++++++++++-----
 1 file changed, 295 insertions(+), 113 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index af4b1a1b9..902a7d21c 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -142,10 +142,10 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         lr=3e-02,
         betas=(0.9, 0.98),
         size_lr_scale=0.1,
-        param_pow=1.0,
+        min_lr_factor=(0.05, 0.05, 0.05),
+        max_lr_factor=(10.0, 10.0, 10.0),
         param_rms_smooth0=0.75,
         param_rms_smooth1=0.25,
-        max_lr_factor=10.0,
         eps=1.0e-08,
         param_min_rms=1.0e-05,
         param_max_rms=2.0,
@@ -153,7 +153,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         size_update_period=4,
         lr_update_period=(200, 1000),
         grad_cov_period=3,
-        param_cov_period=100,
         max_block_size=1024,
     ):
@@ -162,10 +161,10 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         defaults = dict(
             lr=lr,
             size_lr_scale=size_lr_scale,
-            param_pow=param_pow,
+            min_lr_factor=min_lr_factor,
+            max_lr_factor=max_lr_factor,
             param_rms_smooth0=param_rms_smooth0,
             param_rms_smooth1=param_rms_smooth1,
-            max_lr_factor=max_lr_factor,
             betas=betas,
             eps=eps,
             param_min_rms=param_min_rms,
@@ -174,7 +173,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             size_update_period=size_update_period,
             lr_update_period=lr_update_period,
             grad_cov_period=grad_cov_period,
-            param_cov_period=param_cov_period,
             max_block_size=max_block_size,
         )
@@ -283,7 +281,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             return

         # "zero_step" being a member of state is the sign that this parameter has
-        # at least one dim that has a projection.
+        # at least one dim that has a projection.  It also records the most recent
+        # step on which we zeroed state["exp_avg_sq"]
         state["zero_step"] = 0
         # last_param_scale_update records the last time we updated the part of the learning rate
         # matrices that relates to the parameter covariance; we avoid doing this too often
@@ -359,7 +358,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         lr = group["lr"]
         size_update_period = group["size_update_period"]
         grad_cov_period = group["grad_cov_period"]
-        param_cov_period = group["param_cov_period"]
         eps = group["eps"]

         beta1 = group["betas"][0]
@@ -393,12 +391,15 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
            else:
                if self._is_lr_update_step(group, state):
                    self._update_param_cov(group, p, state)
-                   if step > state["last_param_scale_update"] * 1.1 and state["last_param_scale_update"] != state["zero_step"]:
-                       # Only update the parameter-dependent part of the learning
-                       # rate matrices at most every other time we reach here, and
-                       # less frequently than that later in training.
-                       self._update_param_scales(group, p, state)
-                   self._diagonalize_grad_cov(group, p, state)
+
+                   P_proj = self._compute_bases(group, p.shape, state)
+                   # Only update the parameter-dependent part of the learning
+                   # rate matrices at most every other time we reach here, and
+                   # less frequently than that later in training.
+                   self._update_param_scales(group, p, state, P_proj)
+
+                   # We won't be doing this any more.
+                   #self._diagonalize_grad_cov(group, p, state)
                    self._zero_exp_avg_sq(state)
            if step % grad_cov_period == 0:
                self._update_grad_cov(group, p, state)
@@ -474,7 +475,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         (except batch and trivial and rank-1 dims)
         """
         eps = group["eps"]
-        param_cov_period = group["param_cov_period"]

         # zero_step is always the last time we called _update_param_cov.
         # Our aim is to compute the parameter covariance averaged over all time
@@ -555,7 +555,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
     def _update_param_scales(self,
                              group: dict,
                              p: Tensor,
-                             state: dict) -> None:
+                             state: dict,
+                             P_proj: List[Optional[Tensor]]) -> None:
         """
         Computes learning-rate matrices Q for each dim of this tensor: only the part
         that depends on the parameter covariance, we will later add a rotation that depends on the gradient
@@ -567,6 +568,10 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
               are actually factors of p, so p itself will change when we change them.
           state: state dict for the current parameter
+          P_proj: a list indexed by dim, containing tensors of shape (batch_size, size)
+              where size == p.shape[dim], which represent the parameter covariance
+              projected by state[f"Q_{dim}"] (i.e. by the value of this variable
+              at entry to this function)
         """
         ndim = p.ndim
         batch_size = p.shape[0]
@@ -597,15 +602,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                assert size == 1 or size == numel, size
                continue  # e.g. size == 1 or size == numel

-           param_cov = self._get_smoothed_param_cov(group, p, state, dim)
-
-           # param_cov has the same shape as Q
            (batch_size, num_blocks, block_size, block_size) = Q.shape
-           # U,S,V have extra leading dimensions, i.e. of shape
-           # {U,V}.shape == (batch_size, num_blocks, block_size, block_size),
-           # S.shape == (batch_size, num_blocks, block_size).
-           U, S, V = _svd(param_cov)
-           Q[:] = U.transpose(2, 3)

            M = cur_p.transpose(dim, -1)  # if p were of shape (batch_size, x, size, y, z),
@@ -615,11 +612,12 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                                          num_blocks, block_size)
            M = _move_dim(M, -2, 1)  # (batch_size, num_blocks, x, y, z, block_size)
-           while U.ndim < M.ndim:
-               U = U.unsqueeze(2)
-           # Now U is of shape (batch_size, num_blocks, 1, 1, block_size, block_size)
-
-           M = torch.matmul(M, U)  # (batch_size, num_blocks, x, y, z, block_size)
+           while Q.ndim < M.ndim:
+               Q = Q.unsqueeze(2)
+           # Now Q is of shape (batch_size, num_blocks, 1, 1, block_size, block_size), with the
+           # last two dims indexed [diagonalized_coordinate, canonical_coordinate],
+           # so we need to transpose Q as we convert M to the diagonalized co-ordinate.
+           M = torch.matmul(M, Q.transpose(-2, -1))  # (batch_size, num_blocks, x, y, z, block_size)
            M = _move_dim(M, 1, -2)  # (batch_size, x, y, z, num_blocks, block_size)
            M = M.reshape(*M.shape[:-2], size)  # (batch_size, x, y, z, size)
            cur_p = M.transpose(dim, -1)  # (batch_size, x, size, y, z)
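The transpose/reshape/_move_dim sequence in the hunk above applies the per-block orthogonal matrices Q to one dimension of the parameter batch. The following standalone sketch, not part of the patch, shows the same pattern; it uses torch.movedim and torch.linalg.qr instead of the file's private helpers, and assumes size == num_blocks * block_size:

import torch

def project_dim(p: torch.Tensor, Q: torch.Tensor, dim: int) -> torch.Tensor:
    """Rotate dimension `dim` of a batch of tensors `p` into the per-block
    diagonalized basis given by Q, which is indexed
    [batch, block, diagonalized_coordinate, canonical_coordinate]."""
    batch_size, num_blocks, block_size, _ = Q.shape
    size = p.shape[dim]
    assert size == num_blocks * block_size
    M = p.transpose(dim, -1)                      # dim of interest goes last
    M = M.reshape(*M.shape[:-1], num_blocks, block_size)
    M = torch.movedim(M, -2, 1)                   # (batch, num_blocks, ..., block_size)
    Qt = Q.transpose(-2, -1)                      # [batch, block, canonical, diagonalized]
    while Qt.ndim < M.ndim:
        Qt = Qt.unsqueeze(2)                      # broadcast over the remaining dims
    M = torch.matmul(M, Qt)                       # rows now in diagonalized coordinates
    M = torch.movedim(M, 1, -2)
    M = M.reshape(*M.shape[:-2], size)
    return M.transpose(dim, -1)

# Toy usage: a batch of (4, 6, 5) tensors, dim=1 split into 2 blocks of size 3.
p = torch.randn(4, 6, 5)
Q = torch.linalg.qr(torch.randn(4, 2, 3, 3)).Q    # random orthogonal blocks
p_diag = project_dim(p, Q, dim=1)
print(torch.allclose(p.norm(), p_diag.norm(), atol=1e-5))  # rotations preserve the norm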
@@ -630,7 +628,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of

            cur_param_var = _mean(cur_p**2, exclude_dims=[0,dim], keepdim=True)  # (batch_size, 1, size, 1, 1) if dim==2
-           S = S.reshape(cur_param_var.shape)  # (batch_size, 1, size, 1, 1) if dim==2
+           smoothed_param_var = P_proj[dim]  # (batch_size, size)
+           S = smoothed_param_var.reshape(cur_param_var.shape)  # (batch_size, 1, size, 1, 1) if dim==2

            # OK, cur_param_var would have the values as S if the variance stats
            # param_cov_{dim} were accumulated from this exact parameter matrix,
@@ -639,7 +638,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
            # spectrum").  We scale p so that it matches the accumulated stats,
            # the idea is to ensure it doesn't have any too-small eigenvalues
            # (where the stats permit).
-           scale = (S.clamp(min=eps) / cur_param_var.clamp(min=eps)).sqrt()
+           scale = (S / cur_param_var.clamp(min=eps)).sqrt()

            if random.random() < 0.01:
                skip = 10 if size < 20 else 1
@@ -684,18 +683,16 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
            # rms shape: (batch_size, 1, size, 1, 1)
            rms = _mean(cur_p**2, exclude_dims=[0,dim], keepdim=True).sqrt()
            rank = numel // size
-           # we did other kinds of smoothing in _get_smoothed_param_cov
-           #smoothed_rms = self._smooth_param_rms(group, rms, rank)
-           smoothed_rms = rms ** group["param_pow"]
-           cur_scales[dim] = smoothed_rms
-           cur_p /= smoothed_rms  # normalize/"whiten" cur_p on this dim..
+           # TODO: consider more smoothing here???
+           cur_scales[dim] = rms
+           cur_p /= rms

            if debug:
                def _summarize(rms):
                    rms = rms[0]  # get rid of batch dim by selecting one example
                    rms = rms.flatten()
                    return rms[::10]  # subset one every ten items
-               logging.info(f"i={i} shape={tuple(p.shape)}, dim={dim}, rank={rank}, size={size}, rms={_summarize(rms)}, smoothed_rms={_summarize(smoothed_rms)}")
+               logging.info(f"i={i} shape={tuple(p.shape)}, dim={dim}, rank={rank}, size={size}, rms={_summarize(rms)}")

        # Apply the scales in `cur_scales` to Q for each dim; this reflects the
        # parameter rms values in the parameter-diagonalized space, that we have
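The scale = (S / cur_param_var.clamp(min=eps)).sqrt() step above rescales each direction of cur_p so that its variance matches the smoothed target S. A toy illustration with made-up shapes, not part of the patch:

import torch

eps = 1.0e-8
batch_size, size, other = 3, 8, 100
cur_p = torch.randn(batch_size, size, other)

# Per-direction variance over everything except the batch dim and the dim of interest
cur_param_var = (cur_p ** 2).mean(dim=-1, keepdim=True)     # (batch, size, 1)
S = torch.rand(batch_size, size, 1) + 0.5                   # target (smoothed) variances

scale = (S / cur_param_var.clamp(min=eps)).sqrt()
cur_p = cur_p * scale

new_var = (cur_p ** 2).mean(dim=-1, keepdim=True)
print(torch.allclose(new_var, S, rtol=1e-4))                # now matches the target spectrum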
@@ -731,92 +728,272 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of

        state["last_param_scale_update"] = state["step"]

-   def _get_smoothed_param_cov(self,
-                               group: dict,
-                               p: Tensor,
-                               state: dict,
-                               dim: int) -> Tensor:
+   def _compute_bases(self,
+                      group: dict,
+                      p_shape: torch.Size,
+                      state: dict) -> List[Optional[Tensor]]:
+       """
+       For each tensor dimension that we are preconditioning, this function sets
+       state[f"Q_{dim}"] to an orthogonal matrix (per batch element and
+       block); it will be indexed [batch_index, block_index, diagonalized_coordinate, canonical_coordinate].
+       This will be the matrix that diagonalizes Z,
+       where Z is the symmetric matrix satisfying:
+              Z G Z = P           (eqn:1)
+       with G being the gradient covariance grad_cov, P being the smoothed version of the parameter
+       covariance param_cov, and Z the symmetric positive definite learning-rate matrix.
+       We'll discuss later how we solve for Z.  Firstly, because we want
+       a basis that is as meaningful as possible for smoothing P, we will first put P and G in a basis
+       where G is diagonal.  That is, we do an SVD G = U G' U^T, and multiplying (eqn:1) on the
+       left and right by U^T and U respectively, and inserting some expressions equivalent to I,
+       we have:
+            (U^T Z U) (U^T G U) (U^T Z U) = (U^T P U)
+       or:
+              Z' G' Z' = P'       (eqn:2)
+       where Z' = U^T Z U and P' = U^T P U, and of course G' is diagonal.  We are actually going to
+       be solving (eqn:2), and then computing:
+              Z = U Z' U^T.
+
+       A solution to (eqn:2) is as follows.  We are going to use a Cholesky-based solution in
+       preference to one that requires SVD or eigenvalue decomposition, because it is much faster
+       (we first have to be careful that the input is not close to singular, though).
+
+       So, with the Cholesky decomposition of P' being:
+              P' = C C^T,
+       the solution is going to be:
+              Z' = C (C^T G' C)^{-0.5} C^T        (eqn:3)
+       [[ we can verify that this is a solution by multiplying (eqn:2) by C^{-1} and C^{-T} on the
+          left and right respectively, giving:
+              (C^T G' C)^{-0.5} C^T G' C (C^T G' C)^{-0.5} = I,
+          which we can immediately see is the case. ]]
+
+       Args:
+           group: dict to look up config values
+           p_shape: the shape of the batch of identical-sized tensors we are optimizing
+           state: dict to look up optimization state
+
+       Return:
+           This function returns a list of Tensors P_proj indexed by dim from 0..ndim-1, the tensors
+           being of shape (batch_size, size), which contain the diagonal of the smoothed parameter
+           covariance P after projection by the orthogonal matrix state[f"Q_{dim}"] that diagonalizes
+           Z (this is the space in which we do the actual update).
+       """
+       ndim = len(p_shape)
+       # numel is the number of elements in one tensor of the batch; dims that have no
+       # learning-rate matrix Q must either be trivial or span the whole tensor.
+       numel = p_shape.numel() // p_shape[0]
+       P_proj = [None] * ndim
+       for dim in range(1, ndim):
+           size = p_shape[dim]
+           try:
+               Q = state[f"Q_{dim}"]
+           except KeyError:
+               assert size == 1 or size == numel, size
+               continue  # e.g. size == 1 or size == numel
+
+           param_cov = state[f"param_cov_{dim}"]  # (batch_size, num_blocks, block_size, block_size)
+           grad_cov = state[f"grad_cov_{dim}"]  # (batch_size, num_blocks, block_size, block_size)
+           (batch_size, num_blocks, block_size, block_size) = param_cov.shape
+           # U_g diagonalizes grad_cov, in the sense that U_g^T grad_cov U_g is diagonal; its shape
+           # is the same as param_cov and grad_cov.
+           #
+           # G_prime is the diagonalized grad_cov, G' above, of shape (batch_size, num_blocks, block_size).
+           U_g, G_prime, _ = _svd(grad_cov)
+           # P_prime is P' above, which represents param_cov in the basis that diagonalizes G_prime.
+           # It is not smoothed yet.
+           P_prime = torch.matmul(U_g.transpose(2, 3), torch.matmul(param_cov, U_g))

+           P_prime = self._smooth_param_cov(group, p_shape, P_prime, G_prime)

+           C = P_prime.cholesky()  # P_prime = torch.matmul(C, C.transpose(2, 3))

+           # CGC = (C^T G' C), which would be torch.matmul(C.transpose(2, 3), torch.matmul(G, C))
+           # if G were formatted as a diagonal matrix, but G is just the diagonal.
+           CGC = torch.matmul(C.transpose(2, 3) * G_prime.unsqueeze(-2),
+                              C)
+           U, S, _ = _svd(CGC)  # Need SVD to compute CGC^{-0.5}
+           # Next we compute (eqn:3).  The thing in parentheses, (C^T G' C)^{-0.5},
+           # can be written as U S^{-0.5} U^T, so the whole thing is
+           # (C U) S^{-0.5} (C U)^T
+           CU = torch.matmul(C, U)
+           S_inv_sqrt = 1.0 / S.sqrt()
+           Z_prime = torch.matmul(CU * S_inv_sqrt.unsqueeze(-2),
+                                  CU.transpose(2, 3))

+           if True:
+               # A check.
+               # Z_prime is supposed to satisfy:  Z_prime G_prime Z_prime = P_prime   (eqn:2)
+               P_prime_check = torch.matmul(Z_prime * G_prime.unsqueeze(-2), Z_prime)
+               diff_ratio = (P_prime - P_prime_check).abs().sum() / P_prime.abs().sum()
+               if diff_ratio > 0.01:
+                   logging.warn(f"Z_prime does not satisfy its definition, diff_ratio = {diff_ratio}")
+           Z = torch.matmul(U_g, torch.matmul(Z_prime, U_g.transpose(2, 3)))
+           # OK, Z is the SPD transform that maps G to P, as in Z G Z = P.
+           # We just need the basis that diagonalizes this.
+           U_z, S, _ = _svd(Z)
+           if True:
+               skip = 10 if S.shape[-1] > 40 else 1
+               logging.info(f"Eigs of Z are: {S[0,0,::skip]}")

+           # state[f"Q_{dim}"] is indexed: [batch_idx, block_idx, diagonalized_coordinate, canonical_coordinate],
+           # so we need to transpose U_z, as U_z is indexed
+           # [batch_idx, block_idx, canonical_coordinate, diagonalized_coordinate].
+           Q[:] = U_z.transpose(2, 3)

+           # Work out the projected smoothed parameter covariance P_proj, which is P
+           # projected in the basis U_z.  Now,
+           #     P = U_g P' U_g^T,
+           # and P_proj = U_z^T P U_z,
+           # so P_proj = (U_z^T U_g) P' (U_z^T U_g)^T
+           U_prod = torch.matmul(U_z.transpose(2, 3), U_g)
+           # this_P_proj shape: (batch_size, num_blocks, block_size)
+           this_P_proj = _diag(torch.matmul(U_prod, torch.matmul(P_prime, U_prod.transpose(2, 3))))
+           P_proj[dim] = this_P_proj.reshape(batch_size, size)
+       return P_proj

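The Cholesky-based solution of (eqn:2) can be checked numerically. The following standalone sketch, not part of the patch, builds a random SPD P' and a diagonal G', forms Z' = C (C^T G' C)^{-0.5} C^T as in (eqn:3) using torch.linalg in place of the file's _svd helper, and confirms that Z' G' Z' equals P':

import torch

torch.manual_seed(0)
n = 6
A = torch.randn(n, n, dtype=torch.float64)
P_prime = A @ A.T + 0.1 * torch.eye(n, dtype=torch.float64)   # SPD "parameter covariance" P'
G_prime = torch.rand(n, dtype=torch.float64) + 0.1             # diagonal of the gradient covariance G'

C = torch.linalg.cholesky(P_prime)                              # P' = C C^T
CGC = C.T @ (G_prime.unsqueeze(-1) * C)                         # C^T G' C, with G' applied as a diagonal
S, U = torch.linalg.eigh(CGC)                                   # CGC = U diag(S) U^T, S > 0
CGC_inv_sqrt = (U / S.sqrt()) @ U.T                             # (C^T G' C)^{-0.5}
Z_prime = C @ CGC_inv_sqrt @ C.T                                # eqn:3

lhs = Z_prime @ (G_prime.unsqueeze(-1) * Z_prime)               # Z' G' Z'
print(torch.allclose(lhs, P_prime, atol=1e-8))                  # True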
+   def _smooth_param_cov(self,
+                         group: dict,
+                         p_shape: torch.Size,
+                         P_prime: Tensor,
+                         G_prime: Tensor) -> Tensor:
        """
        This function returns a modified/smoothed version of the parameter covariance
+       P_prime.
+       Args:
+          group: dict to look up config values
+          p_shape: The shape of the parameter we are optimizing
+          P_prime: a Tensor of shape (batch_size, num_blocks, block_size, block_size),
+               containing the parameter covariance in a basis that diagonalizes the
+               gradient covariance.
+          G_prime: the diagonalized gradient covariance, of shape (batch_size, num_blocks,
+               block_size)
+
        state[f"param_cov_{dim}"], which is an estimate of the covariance of the parameter p,
        averaged over time, and taken over dimension `dim` of the tensor.
        The smoothing done here limits the extent to which the parameter covariance
        can be strongly "off-diagonal" with respect to the gradient covariance.  That is:
        if the parameter covariance is just the gradient covariance to some power, this
        function does no smoothing; but if it is highly off-diagonal we do more smoothing.
        """
-       param_cov = state[f"param_cov_{dim}"]  # (batch_size, num_blocks, block_size, block_size)
-       grad_cov = state[f"grad_cov_{dim}"]  # (batch_size, num_blocks, block_size, block_size)
-       (batch_size, num_blocks, block_size, block_size) = param_cov.shape
-       U_g, _, _ = _svd(grad_cov)  # U_g diagonalizes grad_cov, in the sense that U_g^T grad_cov U_g is diagonal.
+       P_prime_diag = _diag(P_prime)  # (batch_size, num_blocks, block_size)
+       eps = 1.0e-10
+       P_prime_diag = (P_prime_diag + eps) / P_prime_diag.mean()
+       # Make sure no diagonal element is close to zero; we don't expect this to
+       # happen, and it is likely not important.  Note, this is just used for
+       # normalizing P prior to smoothing.
+       P_prime_diag.clamp_(min=0.01)
+       P_prime_rms = P_prime_diag.sqrt()
+       P_prime_scale = P_prime_rms.unsqueeze(-1) * P_prime_rms.unsqueeze(-2)

-       # param_cov_proj is param_cov in a different orthonormal basis, that diagonalizes
-       # grad_cov.
-       param_cov_proj = torch.matmul(U_g.transpose(2, 3), torch.matmul(param_cov, U_g))
+       # P_norm will have diagonal elements close to 1.  We do some smoothing
+       # in this space.
+       P_norm = P_prime / P_prime_scale

-       # param_cov_eps is probably not critical, I don't expect to see super
-       # small values. apply as floor in case roundoff causes negative values.
-       param_cov_eps = 1.0e-05
-       param_rms = _diag(param_cov_proj).clamp_(min=param_cov_eps).sqrt()
-       param_cov_inv_scale = param_rms.unsqueeze(-1) * param_rms.unsqueeze(-2)
+       # Now P_norm is as normalized as we can make it... do smoothing based on 'rank',
+       # which is intended to compensate for bad estimates of P.
+       batch_size = p_shape[0]
+       size = P_prime.shape[1] * P_prime.shape[-1]  # num_blocks * block_size: size of the dim we are concerned with right now
+       # `rank` is the rank of P_prime if we were to estimate it from just one
+       # parameter tensor.  We average it over time, but actually it won't be changing
+       # too much, so `rank` does tell us something.
+       rank = p_shape.numel() // (size * batch_size)
+       smooth0 = group["param_rms_smooth0"]
+       smooth1 = group["param_rms_smooth1"]
+       # We want the expr for the smoothing amount to be of the form: smooth = alpha * size / (beta*rank + size).
+       # param_rms_smooth{0,1} represents the user-specified desired amount of smoothing
+       # when rank==0*size and rank==1*size, respectively.
+       # From rank==0*size, we get smooth0 = alpha * size/size, so alpha = smooth0.
+       # From setting rank==size, we get smooth1 = alpha * size / (beta*size + size) = alpha/(1+beta),
+       # so smooth1 == smooth0 / (1+beta), so (1+beta) = smooth0/smooth1, so beta = smooth0/smooth1 - 1.
+       smooth = smooth0 * size / ((smooth0/smooth1 - 1) * rank + size)

-       # param_cov_norm should have diagonal values close to 1.0 (only not
-       # exactly 1.0 due to param_cov_eps and roundoff)
-       param_cov_norm = param_cov_proj / param_cov_inv_scale
+       # Add the rank-dependent smoothing amount to the diagonal of P_norm.  _diag() returns
+       # an aliased tensor.  We don't need to multiply `smooth` by anything, because at this
+       # point P_norm has diagonal elements close to 1.
+       _diag(P_norm).add_(smooth)

-       # OK, this is where we do smoothing.
-       # decompose param_cov_norm, which is symmetric, as U_p S U_p^T
-       U_p, S, _ = _svd(param_cov_norm)
+       P_norm = self._smooth_cov(P_norm,
+                                 group["min_lr_factor"][0],
+                                 group["max_lr_factor"][0])
+       # Remove the diagonal preconditioning on P_norm, giving us the stage-1-smoothed
+       # version of P_prime.
+       P_prime = P_norm * P_prime_scale

-       residual_rms = S.sqrt()
-       #
-       relative_rms_pow = 0.7
-       relative_rms_max = 4.0
+       # Make sure G_prime has unit mean and no eigenvalue is super small.  Note, G_prime
+       # is already diagonal.
+       G_prime_mean = _mean(G_prime, exclude_dims=[0], keepdim=True)
+       G_prime_smooth = 0.001
+       # make sure G_prime has no zero eigs, and is unit mean.
+       G_prime = ((G_prime + eps + G_prime_smooth * G_prime_mean) /
+                  (G_prime_mean * (1 + G_prime_smooth) + eps))
+       G_prime_rms = G_prime.sqrt()
+       G_prime_scale = G_prime_rms.unsqueeze(-1) * G_prime_rms.unsqueeze(-2)

-       residual_rms = residual_rms ** relative_rms_pow
-       residual_rms /= _mean(residual_rms, exclude_dims=[0], keepdim=True)
+       # P_gnorm is a version of P_prime that is scaled relative to G, i.e.
+       # scaled in such a way that would make G the unit matrix.
+       P_gnorm = P_prime / G_prime_scale
+       # Apply another round of smoothing "relative to G".
+       P_gnorm = self._smooth_cov(P_gnorm,
+                                  group["min_lr_factor"][1],
+                                  group["max_lr_factor"][1])
+       # Undo the scaling relative to G, so we have the stage-2-smoothed version of P_prime.
+       P_prime = P_gnorm * G_prime_scale
+       # Apply a 3rd round of smoothing.
+       P_prime = self._smooth_cov(P_prime,
+                                  group["min_lr_factor"][2],
+                                  group["max_lr_factor"][2])
+       return P_prime

-       if True:
-           # smooth according to the rank of the observation..
-           size = p.shape[dim]
-           rank = p.numel() // (size * batch_size)
-           smooth0 = group["param_rms_smooth0"]
-           smooth1 = group["param_rms_smooth1"]
-           # want expr to be of the form: smooth = alpha * size / (beta*rank + size)
-           # from rank==0, we get smooth0 = alpha * size/size, so alpha = smooth0.
-           # from setting rank==size, we get smooth1 = alpha * size / (beta*size * size) = alpha/(1+beta),
-           # so smooth1 == smooth0 / (1+beta), so (1+beta) = smooth0/smooth1, so beta=smooth0/smooth1 - 1
-           smooth = smooth0 * size / ((smooth0/smooth1 - 1) * rank + size)
-
-           mean = _mean(residual_rms, exclude_dims=[0], keepdim=True)
-           residual_rms += group["eps"] + smooth * mean
-           residual_rms = residual_rms / _mean(residual_rms, exclude_dims=[0], keepdim=True)
-
-
-       # apply the maximum via a softmin function, softmin(x,y) = 1/(1/x + 1/y)
-       residual_rms = 1. / (1. / residual_rms + 1. / relative_rms_max)
-
-       if random.random() < 0.1:
-           skip = 10 if S.shape[-1] > 40 else 1
-           logging.info(f"Smoothed param_rms from {S.sqrt()[0,0,::skip]} to {residual_rms[0,0,::skip]}, param_rms={param_rms[0,0,::skip]}")
-
-       # U shape: (batch_size, num_blocks, block_size, block_size),
-       # interpreted as
-       # residual_rms shape: (batch_size, num_blocks, block_size).
-       # so in terms of matrix multiplication, we are computing X_p = matmul(U_p, residual_rms.diag())
-       X_p = U_p * residual_rms.unsqueeze(-2)
-       param_cov_norm_smoothed = torch.matmul(X_p, X_p.transpose(2, 3))
-
-       # Undo the scaling by the diagonal of param_cov
-       param_cov_proj_smoothed = param_cov_norm_smoothed * param_cov_inv_scale
-
-       # Undo the projection by U.
-       param_cov_smoothed = torch.matmul(U_g, torch.matmul(param_cov_proj_smoothed,
-                                                           U_g.transpose(2, 3)))
-       return param_cov_smoothed

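The rank-dependent smoothing amount derived in the comments of _smooth_param_cov interpolates between param_rms_smooth0 (rank much smaller than size) and param_rms_smooth1 (rank equal to size). A tiny standalone sketch, not part of the patch:

def smoothing_amount(rank: int, size: int,
                     smooth0: float = 0.75, smooth1: float = 0.25) -> float:
    # smooth = alpha * size / (beta * rank + size), with alpha = smooth0 and
    # beta = smooth0 / smooth1 - 1, as derived in the comments above.
    beta = smooth0 / smooth1 - 1
    return smooth0 * size / (beta * rank + size)

size = 256
print(smoothing_amount(0, size))         # 0.75   (rank == 0: heavy smoothing)
print(smoothing_amount(size, size))      # 0.25   (rank == size)
print(smoothing_amount(4 * size, size))  # ~0.083 (well-estimated covariance: little smoothing)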
+   def _smooth_cov(self,
+                   X: Tensor,
+                   min_eig: float,
+                   max_eig: float,
+                   power: float = 1.0) -> Tensor:
+       """
+       Returns a `smoothed` version of a symmetric positive definite covariance matrix
+       [with block-diagonal structure, in a batch].  This is done without SVD (which
+       can be very slow).
+       The eigenvalues L will be transformed as:
+             L = L + min_eig * L.mean() + eps
+             L /= L.mean()
+             L = 1 / (1/L + 1/max_eig)   # soft-min between L and max_eig
+             L = L ** power    # need SVD for this, will get rid of the requirement later.
+             L /= L.mean()
+
+       # Note on approximation functions like x^0.75 for smallish x: on wolframalpha, type:
+       #   plot x^0.75 and 0.05 + (1.1x - 0.18 x^2 + 0.02 x^3)  for x from 0 to 10
+       # [this starts to diverge after 5 or so]
+
+       Args:
+          min_eig: minimum allowed eigenvalue of returned X
+          max_eig: maximum allowed eigenvalue of returned X
+          power: power to take eigenvalues to
+          X: the batch of symmetric positive definite tensors we are smoothing;
+             of shape (batch_size, num_blocks, block_size, block_size)
+       """
+       if power != 1.0:
+           U, S, _ = _svd(X)
+           eps = 1.0e-10
+           S_mean = _mean(S, exclude_dims=[0], keepdim=True)
+           S = S + min_eig * S_mean + eps
+           S_mean = S_mean * (1 + min_eig) + eps
+           S = S / S_mean
+           S = 1. / (1./S + 1./max_eig)
+           S = S ** power
+           S = S / _mean(S, exclude_dims=[0], keepdim=True)
+           return torch.matmul(U * S.unsqueeze(-2), U.transpose(2, 3))
+       else:
+           diag = _diag(X)  # Aliased with X
+           mean_eig = _mean(diag, exclude_dims=[0], keepdim=True)
+           eps = 1.0e-10  # prevent division by zero
+           diag += (mean_eig * min_eig + eps)
+           cur_diag_mean = mean_eig * (1 + min_eig) + eps
+           # The following statements will be equivalent to:
+           #   L /= L.mean()
+           #   L = 1 / (1/L + 1/max_eig)  # soft-min between L and max_eig
+           # if L is the eigenvalues.  Note: the correction term must be added to the
+           # diagonal of X.inverse() only, not broadcast to every element.
+           eye = torch.eye(X.shape[-1], device=X.device, dtype=X.dtype)
+           X = (X.inverse() + eye / (max_eig * cur_diag_mean.unsqueeze(-1))).inverse()
+           X /= _mean(_diag(X), exclude_dims=[0], keepdim=True).unsqueeze(-1)
+           return X

     def _diagonalize_grad_cov(self,
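A standalone check, not part of the patch, that the SVD-free branch of _smooth_cov realizes the eigenvalue transform described in its docstring (power == 1, a single unblocked matrix, float64, torch.linalg in place of the file's helpers):

import torch

torch.manual_seed(0)
n = 8
A = torch.randn(n, n, dtype=torch.float64)
X = A @ A.T + 0.1 * torch.eye(n, dtype=torch.float64)   # SPD input
min_eig, max_eig, eps = 0.05, 10.0, 1.0e-10

# Matrix-domain path, mirroring the power == 1.0 branch:
mean_eig = X.diagonal().mean()
Xs = X + (mean_eig * min_eig + eps) * torch.eye(n, dtype=torch.float64)
cur_diag_mean = mean_eig * (1 + min_eig) + eps
Xs = (Xs.inverse() + torch.eye(n, dtype=torch.float64) / (max_eig * cur_diag_mean)).inverse()
Xs = Xs / Xs.diagonal().mean()

# Eigenvalue-domain recipe from the docstring (with power == 1):
L = torch.linalg.eigvalsh(X)
L = L + min_eig * L.mean() + eps
L = L / L.mean()
L = 1 / (1 / L + 1 / max_eig)     # soft-min between L and max_eig
L = L / L.mean()

print(torch.allclose(torch.linalg.eigvalsh(Xs), L, rtol=1e-6))   # True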
@@ -1144,22 +1321,27 @@ def _move_dim(x: Tensor, orig_dim: int, new_dim: int) -> Tensor:

 def _diag(x: Tensor):
     """
-    like torch diag(), but supports batch dim, i.e. input of shape (B, M, M) returns
+    like torch diag(), this returns the diagonal of a matrix; but it returns an
+    aliased tensor, and it supports batch dims,
+    i.e. input of shape (B, M, M) returns
     output of shape (B, M), or input of shape (A, B, M, M) returns output of shape
-    (A, B, M)
+    (A, B, M).
     """
+    stride = x.stride()
     if x.ndim == 3:
         (B, M, M2) = x.shape
         assert M == M2
-        stride = x.stride()
-        return x.as_strided(size=(B, M), stride=(stride[0], stride[1] + stride[2])).contiguous()
+        ans = x.as_strided(size=(B, M), stride=(stride[0], stride[1] + stride[2]))
     elif x.ndim == 4:
         (B, C, M, M2) = x.shape
         assert M == M2
-        stride = x.stride()
-        return x.as_strided(size=(B, C, M), stride=(stride[0], stride[1], stride[2] + stride[3])).contiguous()
-    else:
-        return x.diag()
+        # No .contiguous() here: callers such as _smooth_param_cov and _smooth_cov rely on
+        # the returned diagonal being aliased with x so that they can modify it in place.
+        ans = x.as_strided(size=(B, C, M), stride=(stride[0], stride[1], stride[2] + stride[3]))
+    elif x.ndim == 2:
+        (M, M2) = x.shape
+        assert M == M2
+        ans = x.as_strided(size=(M,), stride=(stride[0] + stride[1],))
+    return ans

 def _sum(x: Tensor, exclude_dims: List[int] = [],
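The smoothing code above updates covariance diagonals in place through _diag(), which only works if _diag returns a view aliased with its input, as the rewritten docstring states; keeping .contiguous() on the returned tensor would silently break those in-place updates. A small demonstration of the aliasing, not part of the patch:

import torch

x = torch.zeros(2, 3, 4, 4)                       # (batch, blocks, block_size, block_size)
stride = x.stride()
diag = x.as_strided(size=(2, 3, 4),
                    stride=(stride[0], stride[1], stride[2] + stride[3]))
diag += 0.5                                       # in-place add through the view
print(torch.allclose(x[0, 0].diagonal(), torch.full((4,), 0.5)))   # True: x itself changed
print(x[0, 0].sum())                              # tensor(2.) -- only the diagonal was touched

# With .contiguous() the connection is lost; the add would not reach x:
diag_copy = x.as_strided(size=(2, 3, 4),
                         stride=(stride[0], stride[1], stride[2] + stride[3])).contiguous()
diag_copy += 1.0
print(x[0, 0, 0, 0].item())                       # still 0.5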