diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index b7f010a83..b90bca057 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -107,9 +107,6 @@ class PrAdam(BatchedOptimizer):
                     scale of each parameter tensor.  If each parameter were decomposed
                     as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, size_lr_scale
                     is the scaling factor on the learning rate of p_scale.
-        param_pow: Power on the parameter covariance matrix, 1.0 means learn proportional
-                   to parameter rms (1.0 will be too much, should be between 0 and 1.)
-                   This is one of the most important tunable factors, along with param_cov_max.
 param_rms_smooth0: Limiting value of smoothing proportion for parameter matrix, as
                    assumed rank of param covariance [==product of sizes on the other tensor dims]
                    approaches 0.
@@ -117,18 +114,21 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                    param covariance equals the dimension of the covaraince matrix.
                    param_rms_smooth{0,1} determine the smoothing proportions for other
                    conditions.
-    param_cov_min: [IMPORTANT] A 3-tuple of minimums of the diagonal values of the parameter
-                   covariance, normalized in 3 different ways: relative to its own
-                   diagonal, scaled by the grad covariance, and in the canonical basis.
-                   With param_cov_max, defines how "aggressive" we allow our update to
-                   be.
-    param_cov_max: [IMPORTANT] A 3-tuple of maximums of the diagonal values of the parameter
-                   covariance, normalized in 3 different ways: relative to its own
-                   diagonal, scaled by the grad covariance, and in the canonical basis.
-        param_pow: This was mainly added for development and experimentation purposes;
-                   it allows you to smooth the parameter covariance by taking it to
-                   a power, but if this is not 1.0 it will cause a speed penalty because
-                   it requires SVD.
+  cov_min,cov_max: [IMPORTANT] 4-tuples of minimums and maximums of the diagonal values of
+                   covariance matrices, after normalizing to unit-mean.  The first 3 are
+                   for smoothing the parameter covariance, normalized in 3 different ways:
+                   (1) relative to its own diagonal (in a basis that diagonalizes the grad
+                       covariance);
+                   (2) multiplied by the grad covariance,
+                   (3) in the canonical basis.
+
+                   (4) is for smoothing the grad covariance used for (2)
+          cov_pow: This was mainly added for development and experimentation purposes;
+                   it allows you to smooth the parameter covariance matrices at the
+                   stages (1), (2), (3) of smoothing mentioned above, and also
+                   the gradient covariance matrix used in stage (2) of smoothing.  If the
+                   1st 3 values are not 1.0 it will cause a measurable speed penalty because it
+                   requires SVD.  Recommend to leave all these at 1.0.
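To make the new tuple layout concrete, here is a minimal, hypothetical usage sketch, assuming PrAdam as defined in this file; the values shown are simply the new defaults introduced below, and the model is a placeholder:

    import torch

    model = torch.nn.Linear(256, 256)  # placeholder module
    optimizer = PrAdam(
        model.parameters(),
        lr=3e-02,
        betas=(0.9, 0.98),
        # entries [0], [1], [2]: smoothing stages (1), (2), (3) of the parameter
        # covariance; entry [3]: smoothing of the grad covariance used in stage (2).
        cov_min=(0.05, 0.01, 0.04, 0.0001),
        cov_max=(10.0, 40.0, 5.0, 400.0),
        cov_pow=(1.0, 1.0, 1.0, 1.0),  # leave at 1.0 to avoid the SVD cost
    )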
              eps: A general-purpose epsilon to prevent division by zero
    param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
                   learning the scale on the parameters (we'll constrain the rms of each non-scalar
@@ -159,9 +159,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         lr=3e-02,
         betas=(0.9, 0.98),
         size_lr_scale=0.1,
-        param_cov_min=(0.05, 0.01, 0.04),
-        param_cov_max=(10.0, 40.0, 5.0),
-        param_pow=(1.0, 1.0, 1.0),
+        cov_min=(0.05, 0.01, 0.04, 0.0001),
+        cov_max=(10.0, 40.0, 5.0, 400.0),
+        cov_pow=(1.0, 1.0, 1.0, 1.0),
         param_rms_smooth0=0.4,
         param_rms_smooth1=0.2,
         eps=1.0e-08,
@@ -182,9 +182,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         defaults = dict(
             lr=lr,
             size_lr_scale=size_lr_scale,
-            param_cov_min=param_cov_min,
-            param_cov_max=param_cov_max,
-            param_pow=param_pow,
+            cov_min=cov_min,
+            cov_max=cov_max,
+            cov_pow=cov_pow,
             param_rms_smooth0=param_rms_smooth0,
             param_rms_smooth1=param_rms_smooth1,
             betas=betas,
@@ -289,7 +289,12 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                     p, memory_format=torch.preserve_format
                 )

-            ignore_rank1_dims = True   # a config value, can change this (TODO: tune)
+            # ignore_rank1_dims = True means we don't do our basis-changing update
+            # on tensors like biases that only have one non-trivial dimension.
+            # "rank1" refers to the rank of our estimate of the parameter
+            # covariance, if we were to estimate it just from the current parameter
+            # value.
+            ignore_rank1_dims = True

             trivial_update = True
             for dim in range(1, p.ndim):
@@ -325,7 +330,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                 Q = torch.eye(block_size, block_size, **kwargs).unsqueeze(0).unsqueeze(0).expand(
                     batch_size, num_blocks, block_size, block_size).contiguous()
                 state[f"Q_{dim}"] = Q
-                # param_cov_{dim} is the averaged-over-time covariance of parameters on this dimension, treating
+                # cov_{dim} is the averaged-over-time covariance of parameters on this dimension, treating
                 # all other dims as a batch axis.  Also initialize as identity.
                 state[f"param_cov_{dim}"] = Q.clone()

@@ -1046,23 +1051,28 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         _diag(P_norm).add_(smooth)
         P_norm = self._smooth_cov(P_norm,
-                                  group["param_cov_min"][0],
-                                  group["param_cov_max"][0],
-                                  group["param_pow"][0])
+                                  group["cov_min"][0],
+                                  group["cov_max"][0],
+                                  group["cov_pow"][0])

         # Remove the diagonal preconditioning on P_norm, giving us stage-1-smoothed
         # version of P_prime.
         P_prime = P_norm * P_prime_scale

-        # Make sure G_prime has unit mean and no eigenvalue is super small.  Note, G_prime
-        # is already diagonalized, the variable G_prime is just the tensor of eigenvalues.
-        G_prime_mean = _mean(G_prime, exclude_dims=[0], keepdim=True)
-        G_prime_smooth = 0.0001
-        # make sure G_prime has no zero eigs, and is unit mean.
-        G_prime = ((G_prime + eps + G_prime_smooth * G_prime_mean) /
-                   (G_prime_mean * (1+G_prime_smooth) + eps))
-        # it now has unit mean..
-        G_prime_max = 400.0
-        G_prime = 1. / (1./G_prime + 1./G_prime_max)  # apply max
+        if True:
+            # This block smooths G_prime.
+            # Make sure G_prime has unit mean and no eigenvalue is super small.  Note, G_prime
+            # is already diagonalized, the variable G_prime is just the tensor of eigenvalues.
+            G_prime_mean = _mean(G_prime, exclude_dims=[0], keepdim=True)
+            G_prime_min = group["cov_min"][3]
+            # make sure G_prime has no zero eigs, and is unit mean.
+            G_prime = ((G_prime + eps + G_prime_min * G_prime_mean) /
+                       (G_prime_mean * (1+G_prime_min) + eps))
+            # it now has unit mean..
+            G_prime_max = group["cov_max"][3]
+            G_prime = 1. / (1./G_prime + 1./G_prime_max)  # apply max
+            G_prime_pow = group["cov_pow"][3]
+            if G_prime_pow != 1.0:
+                G_prime = G_prime ** G_prime_pow

         G_prime_rms = G_prime.sqrt()
         G_prime_scale = G_prime_rms.unsqueeze(-1) * G_prime_rms.unsqueeze(-2)
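The G_prime block above combines a floor, a soft ceiling, and an optional power, all after normalizing the eigenvalues to unit mean.  A standalone sketch of that arithmetic (the shapes and the helper name are assumptions for illustration; the real code uses the batched _mean helper):

    import torch

    def smooth_grad_eigs(eigs: torch.Tensor,
                         g_min: float = 0.0001,   # group["cov_min"][3]
                         g_max: float = 400.0,    # group["cov_max"][3]
                         g_pow: float = 1.0,      # group["cov_pow"][3]
                         eps: float = 1.0e-08) -> torch.Tensor:
        # eigs: (batch_size, size), one row of grad-covariance eigenvalues per batch element.
        mean = eigs.mean(dim=1, keepdim=True)
        # floor: no eigenvalue below ~g_min times the mean; result has unit mean.
        eigs = (eigs + eps + g_min * mean) / (mean * (1 + g_min) + eps)
        # soft ceiling via softmin(x, g_max) = 1 / (1/x + 1/g_max).
        eigs = 1. / (1. / eigs + 1. / g_max)
        if g_pow != 1.0:
            eigs = eigs ** g_pow
        return eigs

    print(smooth_grad_eigs(torch.tensor([[0.0, 1.0, 1000.0]])))
    # ~[1.0e-04, 3.1e-03, 3.0]: the zero eigenvalue is floored away from zero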
@@ -1073,17 +1083,17 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         P_gnorm = P_prime * G_prime_scale
         # Apply another round of smoothing "relative to G"
         P_gnorm = self._smooth_cov(P_gnorm,
-                                   group["param_cov_min"][1],
-                                   group["param_cov_max"][1],
-                                   group["param_pow"][1])
+                                   group["cov_min"][1],
+                                   group["cov_max"][1],
+                                   group["cov_pow"][1])
         # Undo the scaling relative to G, so we have stage-2-smoothed version of P_prime.
         P_prime = P_gnorm / G_prime_scale

         # Apply a 3rd round of smoothing in the canonical basis.
         P_prime = self._smooth_cov(P_prime,
-                                   group["param_cov_min"][2],
-                                   group["param_cov_max"][2],
-                                   group["param_pow"][2])
+                                   group["cov_min"][2],
+                                   group["cov_max"][2],
+                                   group["cov_pow"][2])
         return P_prime

     def _smooth_cov(self,
@@ -1378,66 +1388,15 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         while Q.ndim < M.ndim:
             Q = Q.unsqueeze(2)
         # now Q has shape (batch_size, num_blocks, 1, 1, block_size, block_size)
-        M = torch.matmul(M, Q)  # (batch_size, num_blocks, x, y, z, block_size)
+        if M.ndim < Q.ndim:  # special case where M shape is e.g. (batch_size, num_blocks, x)
+            M = torch.matmul(M.unsqueeze(-2), Q).squeeze(-2)
+        else:
+            M = torch.matmul(M, Q)  # (batch_size, num_blocks, x, y, z, block_size)
         M = _move_dim(M, 1, -2)  # (batch_size, x, y, z, num_blocks, block_size)
         M = M.reshape(*M.shape[:-2], size)  # (batch_size, x, y, z, size)
         M = M.transpose(-1, dim)  # (batch_size, x, size, y, z)
         return M

-
-    def _smooth_param_rms(self,
-                          group: dict,
-                          rms: Tensor,
-                          rank: int) -> Tensor:
-        """
-        Smooths and normalizes to mean=1 a tensor of shape something like (batch_size, 1, size, 1, 1),
-        where for the nontrivial dim `size` it contains dagonalized parameter rms values; this is for
-        one dimension of a multiple-dimensional tensor.
-
-        This will be used to construct a learning-rate matrix.
-
-        Args:
-            group: dict for configuration values
-              rms: A tensor of shape (batch_size, [1,1,..], size, [1,1,1,..]), representing
-                   (for each batch element) a list of root-mean-square values, one per
-                   dimension of the space of size `size`.
-             rank: the assumed rank of the covariance matrix from which rms was derived,
-                   used to decide how much smoothing to apply.  This is actually the
-                   minimal rank, if the parameter matrix were stay fixed during training,
-                   but still relevant to know how robust the parameter covariance estimate is.
-        Returns:
-            a Tensor with the same shape as `rms` but with some smoothing applied so there
-            are no values too close to zero.
-        """
-        smooth0 = group["param_rms_smooth0"]
-        smooth1 = group["param_rms_smooth1"]
-        param_cov_max = group["param_cov_max"]
-        param_pow = group["param_pow"]
-        eps = group["eps"]
-        batch_size = rms.shape[0]
-        size = rms.numel() // batch_size
-        rms = rms ** param_pow
-
-
-        # want expr to be of the form: smooth = alpha * size / (beta*rank + size)
-        # from rank==0, we get smooth0 = alpha * size/size, so alpha = smooth0.
-        # from setting rank==size, we get smooth1 = alpha * size / (beta*size * size) = alpha/(1+beta),
-        # so smooth1 == smooth0 / (1+beta), so (1+beta) = smooth0/smooth1, so beta=smooth0/smooth1 - 1
-        smooth = smooth0 * size / ((smooth0/smooth1 - 1) * rank + size)
-
-        mean = _mean(rms, exclude_dims=[0], keepdim=True)
-        rms += eps + smooth * mean
-        new_mean = (eps + (smooth + 1) * mean)  # mean of modified rms.
-        ans = rms / new_mean
-
-
-        # Apply a `soft min` of param_cov_max via the formula
-        # softmin(x,y) = 1/(1/x + 1/y).
-        ans = 1. / (1. / ans + 1. / param_cov_max)
-        # and renormalize to mean=1.
-        ans /= _mean(ans, exclude_dims=[0], keepdim=True)
-
-        return ans


 def _move_dim(x: Tensor, orig_dim: int, new_dim: int) -> Tensor:
     """
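The new M.ndim < Q.ndim branch in the hunk above handles tensors with a single non-trivial dim, where a plain batched matmul would fail to broadcast.  A quick shape check under assumed, illustrative sizes:

    import torch

    batch_size, num_blocks, block_size = 2, 3, 4
    Q = torch.randn(batch_size, num_blocks, block_size, block_size)
    M = torch.randn(batch_size, num_blocks, block_size)  # no trailing dims: ndim 3 < 4

    # torch.matmul(M, Q) would raise here (batch dims (2,) vs (2, 3) don't broadcast),
    # so treat M as a row vector per (batch, block):
    out = torch.matmul(M.unsqueeze(-2), Q).squeeze(-2)
    print(out.shape)  # torch.Size([2, 3, 4])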
@@ -2197,9 +2156,8 @@ def _test_eve_cain():
     for iter in [3, 2]:
         fix_random_seed(42)
         Linear = torch.nn.Linear if iter == 0 else ScaledLinear
-        # TODO: find out why this is not converging...
-        hidden_dim = 512
+        hidden_dim = 768
         m = torch.nn.Sequential(Linear(E, hidden_dim),
                                 torch.nn.PReLU(),
                                 Linear(hidden_dim, hidden_dim),
                                 torch.nn.PReLU(),
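Finally, the removed _smooth_param_rms above documented the interpolation that still motivates the param_rms_smooth{0,1} defaults (0.4 and 0.2): smooth = smooth0 * size / ((smooth0/smooth1 - 1) * rank + size).  A quick numeric check of that algebra (the function name here is just for illustration):

    def smooth_proportion(smooth0: float, smooth1: float,
                          rank: int, size: int) -> float:
        # alpha = smooth0, beta = smooth0/smooth1 - 1, as derived in the removed comments.
        return smooth0 * size / ((smooth0 / smooth1 - 1) * rank + size)

    # rank == 0 recovers smooth0; rank == size recovers smooth1.
    assert abs(smooth_proportion(0.4, 0.2, rank=0, size=512) - 0.4) < 1e-12
    assert abs(smooth_proportion(0.4, 0.2, rank=512, size=512) - 0.2) < 1e-12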