diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 33cf8f288..beff77c49 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -163,7 +163,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             lr=3e-02,
             betas=(0.9, 0.98),
             size_lr_scale=0.1,
-            cov_min=(0.025, 0.0025, 0.02, 0.0001),
+            cov_min=(0.025, 0.0, 0.02, 0.0001),
             cov_max=(10.0, 80.0, 5.0, 400.0),
             cov_pow=(1.0, 1.0, 1.0, 1.0),
             param_rms_smooth0=0.4,
@@ -669,7 +669,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         if True:
             _,S,_ = _svd(P)
-            logging.info(f"Eigs of P are {S}")
+            logging.info(f"Eigs of P are {S[0][0]}")

         # C will satisfy: P == torch.matmul(C, C.transpose(2, 3))
@@ -759,49 +759,16 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         P_diag_mean = _mean(P_diag, exclude_dims=[0], keepdim=True)
         P_diag += smooth * P_diag_mean

-        G = G.clone()
-        G_diag = _diag(G)  # aliased
-        G_diag *= 1.005  # improve its condition, for numerical reasons.
+        #G = G.clone()
+        #G_diag = _diag(G)  # aliased
         G = self._smooth_cov(G, group["cov_min"][3],
                              group["cov_max"][3],
                              group["cov_pow"][3])
-        # C C^T == G.
-        C = G.cholesky()
-        P_orig = P.clone()
-
-        # Treat the last dim of C as being in an arbitrary space; its next-to-last dim
-        # is the "canonical" one that we need to sum with the dims of P.
-        P_gnorm = torch.matmul(C.transpose(2, 3),
-                               torch.matmul(P, C))
-
-        P_gnorm = self._smooth_cov(P_gnorm,
-                                   group["cov_min"][1],
-                                   group["cov_max"][1],
-                                   group["cov_pow"][1])
-
-        # torch.triangular_solve(b, A, upper=True, transpose=False, unitriangular=False, *, out=None)
-        # In symbols, it solves AX = b, ... X = A^{-1} b,
-        # and assumes A is square upper-triangular (or lower-triangular if upper=False)
-        # and does not have zeros on the diagonal.
-        #
-        # We want to invert the normalization by C to recover P (if there weren't smoothing), so we want
-        # P = torch.matmul(C_inv.transpose(2, 3),
-        #                  torch.matmul(P_gnorm, C_inv))
-
-        # The following is equivalent to: P_temp = torch.matmul(C_inv.transpose(2, 3), P_gnorm),
-        # where C_inv = C.inverse().
-        P_temp = torch.triangular_solve(P_gnorm, C, upper=False, transpose=True)[0]
-        # .. now we want to do the same on the other axis; need transpose.
-        P = torch.triangular_solve(P_temp.transpose(2, 3), C,
-                                   upper=False, transpose=True)[0].transpose(2, 3)
-
-        if True:  # TEMP
-            C_inv = C.inverse()
-            P2 = torch.matmul(C_inv.transpose(2, 3),
-                              torch.matmul(P_gnorm, C_inv))
-            assert (P - P2).norm() <= 0.001 * P.norm()
+        P = self._apply_max_with_metric(P, G,
+                                        group["cov_max"][1])

         # Apply a 3rd round of smoothing in the canonical basis.
@@ -864,11 +831,16 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
        # have at this time.
        eig_ceil = X.shape[-1]

+       # The next statement slightly adjusts the target to be the same as
+       # what the baseline function, eig -> 1./(1./eig + 1./max_eig), would
+       # give; so "max_eig" is now the target of the function for arg ==
+       # eig_ceil.
+       max_eig = 1. / (1. / max_eig + 1. / eig_ceil)
+
        while max_eig <= eig_ceil * 0.5:
            # this is equivalent to the operation: l -> l - 0.5/eig_ceil*l*l
            # on eigenvalues, which maps eig_ceil to 0.5*eig_ceil and is monotonically
            # increasing from 0..eig_ceil.
-           logging.info(f"X={X}, eig_ceil={eig_ceil}")
            X = X - 0.5/eig_ceil * torch.matmul(X, X)
            eig_ceil = 0.5 * eig_ceil
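Note (an editorial aside, not part of the patch): the loop above is easier to follow on scalars. Each pass of l -> l - 0.5/eig_ceil * l*l halves the ceiling, and the adjusted max_eig makes the final quadratic send eig_ceil exactly where the baseline soft cap eig -> 1./(1./eig + 1./max_eig) would put it. A minimal standalone sketch; the function name cap_eigs and the test values are invented for illustration:

import torch

def cap_eigs(eigs: torch.Tensor, eig_ceil: float, max_eig: float) -> torch.Tensor:
    # Adjust the target so that eig_ceil maps to the same value the baseline
    # soft cap l -> 1/(1/l + 1/max_eig) would give it.
    max_eig = 1.0 / (1.0 / max_eig + 1.0 / eig_ceil)
    while max_eig <= eig_ceil * 0.5:
        # l -> l - 0.5/eig_ceil * l*l is monotonic on [0, eig_ceil] and maps
        # eig_ceil to 0.5*eig_ceil, so the ceiling halves on each pass.
        eigs = eigs - 0.5 / eig_ceil * eigs * eigs
        eig_ceil = 0.5 * eig_ceil
    if max_eig < eig_ceil:
        # One final quadratic maps eig_ceil exactly to max_eig; since
        # coeff <= 0.5/eig_ceil, the map stays monotonic on [0, eig_ceil].
        coeff = (eig_ceil - max_eig) / (eig_ceil * eig_ceil)
        eigs = eigs - coeff * eigs * eigs
    return eigs

print(cap_eigs(torch.tensor([0.0, 0.5, 1.0, 16.0]), eig_ceil=16.0, max_eig=4.0))
# -> approximately tensor([0.0000, 0.4657, 0.8687, 3.2000]); the ceiling 16.0
# lands exactly on 1/(1/4 + 1/16) = 3.2, while small eigenvalues shrink only mildly.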
- logging.info(f"X={X}, eig_ceil={eig_ceil}") X = X - 0.5/eig_ceil * torch.matmul(X, X) eig_ceil = 0.5 * eig_ceil @@ -889,6 +861,59 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of return X + def _apply_max_with_metric(self, + X: Tensor, + M: Tensor, + max_eig: float) -> Tensor: + """ + Smooths X with maximum eigenvalue (relative to the mean) relative to + metric M. Equivalent to applying + Y := M^{0.5} X M^{0.5} + Apply maximum to eigenvalues of Y, as done in _smooth_cov() + X := M^{-0.5} Y M^{-0.5} + + Args: + X: Batch of block-diagonal nonzero positive definite matrices to have max eigenvalue applied + M: Batch of positive definite block-diagonal matrices to use as metrics w.r.t. X. + """ + X = X.clone() + # mean eig of M^{0.5} X M^{0.5} ... + mean_eig = (X*M).sum(dim=(2,3), keepdim=True) / X.shape[-1] + # make sure eigs of M^{0.5} X M^{0.5} are average 1. this imposes limit on the max. + X /= mean_eig + + eig_ceil = X.shape[-1] + + # the next statement wslightly adjusts the target to be the same as + # what the baseline function, eig -> 1./(1./eig + 1./max_eig) would + # give. so "max_eig" is now the target of the function for arg == + # eig_ceil. + max_eig = 1. / (1. / max_eig + 1. / eig_ceil) + + while max_eig <= eig_ceil * 0.5: + # this is equivalent to the operation: l -> l - 0.5/eig_ceil*l*l + # on eigenvalues, which maps eig_ceil to 0.5*eig_ceil and is monotonically + # increasing from 0..eig_ceil. + #logging.info(f"X={X}, eig_ceil={eig_ceil}") + X = X - 0.5/eig_ceil * torch.matmul(X, torch.matmul(M, X)) + eig_ceil = 0.5 * eig_ceil + + # max_eig > eig_ceil * 0.5 + if max_eig < eig_ceil: + # map l -> l - coeff*l*l, if l==eig_ceil, this + # takes us to: + # eig_ceil - (eig_ceil-max_eig/(eig_ceil*eig_ceil))*eig_ceil*eig_ceil + # == max_eig + # .. and the fact that coeff <= 0.5/eig_ceil [since max_eig>eig_ceil*0.5] + # means that the function is monotonic on inputs from 0 to eig_ceil. + coeff = (eig_ceil - max_eig) / (eig_ceil*eig_ceil) + X = X - coeff * torch.matmul(X, torch.matmul(M, X)) + + # Normalize again. + X /= _mean(_diag(X), exclude_dims=[0], keepdim=True).unsqueeze(-1) + X = 0.5 * (X + X.transpose(-2, -1)) # make sure exactly symmetric. + return X + def _update_grad_cov(self, group: dict, @@ -1827,7 +1852,7 @@ def _test_eve_cain(): fix_random_seed(42) Linear = torch.nn.Linear if iter == 0 else ScaledLinear - hidden_dim = 300 + hidden_dim = 200 m = torch.nn.Sequential(Linear(E, hidden_dim), torch.nn.PReLU(), Linear(hidden_dim, hidden_dim),