From b47433b77af68acc90821cc0ad2ecf3f42f19d1e Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sat, 23 Jul 2022 09:06:03 +0800
Subject: [PATCH] Fix bug in smooth_cov, for power==1.0

---
 .../ASR/pruned_transducer_stateless7/optim.py | 34 ++++++++++++++-----
 1 file changed, 26 insertions(+), 8 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index debc41890..566efb23f 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -144,6 +144,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         size_lr_scale=0.1,
         min_lr_factor=(0.01, 0.01, 0.01),
         max_lr_factor=(10.0, 10.0, 10.0),
+        #param_pow=(0.99999, 0.99999, 0.99999),
+        param_pow=(1.0, 1.0, 1.0),
         param_rms_smooth0=0.75,
         param_rms_smooth1=0.25,
         eps=1.0e-08,
@@ -163,6 +165,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             size_lr_scale=size_lr_scale,
             min_lr_factor=min_lr_factor,
             max_lr_factor=max_lr_factor,
+            param_pow=param_pow,
             param_rms_smooth0=param_rms_smooth0,
             param_rms_smooth1=param_rms_smooth1,
             betas=betas,
@@ -678,7 +681,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         # cur_scales for the other dims.
         cur_scales = [None] * ndim
 
-        debug = (random.random() < 0.001)
+        debug = (random.random() < 0.1)
         for i in range(4):  # for 4 iterations (this is quite arbitrary)
             for dim in range(1, ndim):
                 size = p.shape[dim]
@@ -949,7 +952,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
 
         P_norm = self._smooth_cov(P_norm,
                                   group["min_lr_factor"][0],
-                                  group["max_lr_factor"][0])
+                                  group["max_lr_factor"][0],
+                                  group["param_pow"][0])
         # Remove the diagonal preconditioning on P_norm, giving us stage-1-smoothed
         # version of P_prime.
         P_prime = P_norm * P_prime_scale
@@ -969,14 +973,16 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         # Apply another round of smoothing "relative to G"
         P_gnorm = self._smooth_cov(P_gnorm,
                                    group["min_lr_factor"][1],
-                                   group["max_lr_factor"][1])
+                                   group["max_lr_factor"][1],
+                                   group["param_pow"][1])
         # Undo the scaling relative to G, so we have stage-2-smoothed version of P_prime.
         P_prime = P_gnorm * G_prime_scale
 
         # Apply a 3rd round of smoothing
         P_prime = self._smooth_cov(P_prime,
                                    group["min_lr_factor"][2],
-                                   group["max_lr_factor"][2])
+                                   group["max_lr_factor"][2],
+                                   group["param_pow"][2])
         return P_prime
 
     def _smooth_cov(self,
@@ -1007,9 +1013,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
           X: the batch of symmetric positive definite tensors we are smoothing;
              of shape (batch_size, num_blocks, block_size, block_size)
         """
+        eps = 1.0e-10
         if power != 1.0:
             U, S, _ = _svd(X)
-            eps = 1.0e-10
             S_mean = _mean(S, exclude_dims=[0], keepdim=True)
             S = S + min_eig * S_mean + eps
             S_mean = S_mean * (1 + min_eig) + eps
@@ -1019,17 +1025,18 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             S = S / _mean(S, exclude_dims=[0], keepdim=True)
             return torch.matmul(U * S.unsqueeze(-2), U.transpose(2, 3))
         else:
-            X = X.clone()  # may be
+            X = X.clone()
             diag = _diag(X)  # Aliased with X
             mean_eig = _mean(diag, exclude_dims=[0], keepdim=True)
-            eps = 1.0e-10  # prevent division by zero
             diag += (mean_eig * min_eig + eps)
             cur_diag_mean = mean_eig * (1 + min_eig) + eps
             # The following 2 statements will be equivalent to:
             #  L /= L.mean()
             #  L = 1 / (1/L + 1/max_eig)  # soft-min between L and max_eig
             # if L is the eigenvalues
-            X = (X.inverse() + 1/(max_eig * cur_diag_mean.unsqueeze(-1))).inverse()
+            X_inv = X.inverse()
+            _diag(X_inv).add_(1. / (max_eig * cur_diag_mean))
+            X = X_inv.inverse()
             X /= _mean(_diag(X), exclude_dims=[0], keepdim=True).unsqueeze(-1)
             return X
 
@@ -2107,6 +2114,17 @@ def _test_eve_cain():
         elif iter == 3:
             optim = PrAdam(m.parameters(), lr=0.03, max_block_size=100)
         scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)
+        #TEMP
+        if iter == 3:
+            a = torch.randn(5, 10, 5, 5)
+            a = torch.matmul(a, a.transpose(2, 3))
+            b1 = optim._smooth_cov(a, 0.1, 4.0, 0.999999999)
+            b2 = optim._smooth_cov(a, 0.1, 4.0, 1.0)
+            diff = (b1 - b2)
+            ratio = (diff**2).sqrt().mean() / (b1**2).sqrt().mean()
+            logging.info(f"ratio = {ratio}")
+            assert ratio < 0.01
+
         start = timeit.default_timer()
         avg_loss = 0.0
         for epoch in range(150):
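
For context (not part of the patch): the bug fixed above is in the power == 1.0
branch of _smooth_cov, which applies a soft-min between the eigenvalues of X and
max_eig. If L denotes the eigenvalues, the intended map is
L -> 1 / (1/L + 1/max_eig), which is realized by adding (1/max_eig) * I to
X^{-1}, i.e. adding the scalar only to the diagonal of the inverse. The old line
added the scalar to every element of X^{-1}, which also perturbs the
eigenvectors. Below is a minimal standalone sketch of the corrected operation,
omitting the eps and mean-normalization details of the real code; soft_min_eigs
is a hypothetical name, torch.Tensor.diagonal stands in for the file's _diag
helper, and the self-check uses torch.linalg.eigh.

import torch

def soft_min_eigs(X: torch.Tensor, max_eig: float) -> torch.Tensor:
    # Soft-min between the eigenvalues of symmetric positive definite X
    # and max_eig: the result has the same eigenvectors as X, with each
    # eigenvalue L mapped to 1 / (1/L + 1/max_eig).
    X_inv = X.inverse()
    # Add (1/max_eig) * I, i.e. touch only the diagonal of the inverse.
    X_inv.diagonal(dim1=-2, dim2=-1).add_(1.0 / max_eig)
    return X_inv.inverse()

# Self-check against the eigenvalue formula:
X = torch.randn(5, 5)
X = X @ X.t() + 5.0 * torch.eye(5)  # make X symmetric positive definite
Y = soft_min_eigs(X, max_eig=2.0)
L, U = torch.linalg.eigh(X)
Y_ref = (U * (1.0 / (1.0 / L + 1.0 / 2.0))) @ U.t()  # U diag(f(L)) U^T
assert torch.allclose(Y, Y_ref, atol=1.0e-4)

With max_eig = 2.0, an eigenvalue of 5 maps to 1/(1/5 + 1/2) ~= 1.43, while
eigenvalues much smaller than max_eig are nearly unchanged; because only the
diagonal of X^{-1} is touched, the eigenvectors of X are preserved exactly,
which the element-wise addition in the old code did not guarantee.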