diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 135c266ef..33cf8f288 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -856,13 +856,34 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         mean_eig = _mean(diag, exclude_dims=[0], keepdim=True)
         diag += (mean_eig * min_eig + eps)
         cur_diag_mean = mean_eig * (1 + min_eig) + eps
-        # The following 2 statements will be equivalent to:
-        #  L /= L.mean()
-        #  L = 1 / (1/L + 1/max_eig)   # soft-min between L and max_eig
-        # if L is the eigenvalues
-        X_inv = X.inverse()
-        _diag(X_inv).add_(1. / (max_eig * cur_diag_mean))
-        X = X_inv.inverse()
+        X /= cur_diag_mean.unsqueeze(-1)
+        # OK, now the mean of the diagonal of X is 1 (or less than 1, in
+        # which case X is extremely tiny).
+
+        # eig_ceil is the maximum eigenvalue that X could possibly
+        # have at this time.
+        eig_ceil = X.shape[-1]
+
+        while max_eig <= eig_ceil * 0.5:
+            # this is equivalent to the operation: l -> l - 0.5/eig_ceil*l*l
+            # on eigenvalues, which maps eig_ceil to 0.5*eig_ceil and is monotonically
+            # increasing on [0, eig_ceil].
+            logging.info(f"X={X}, eig_ceil={eig_ceil}")
+            X = X - 0.5/eig_ceil * torch.matmul(X, X)
+            eig_ceil = 0.5 * eig_ceil
+
+        # max_eig > eig_ceil * 0.5
+        if max_eig < eig_ceil:
+            # map l -> l - coeff*l*l; if l==eig_ceil, this
+            # takes us to:
+            #  eig_ceil - ((eig_ceil-max_eig)/(eig_ceil*eig_ceil))*eig_ceil*eig_ceil
+            # == max_eig
+            # .. and the fact that coeff <= 0.5/eig_ceil [since max_eig > eig_ceil*0.5]
+            # means that the function is monotonic on inputs from 0 to eig_ceil.
+            coeff = (eig_ceil - max_eig) / (eig_ceil*eig_ceil)
+            X = X - coeff * torch.matmul(X, X)
+
+        # Normalize again.
         X /= _mean(_diag(X), exclude_dims=[0], keepdim=True).unsqueeze(-1)
         X = 0.5 * (X + X.transpose(-2, -1))  # make sure exactly symmetric.
         return X
@@ -1120,7 +1141,7 @@ def _mean(x: Tensor,
         # the interface changes in future.
         return x.mean(dim=tuple(range(x.ndim)), keepdim=keepdim)
     elif x.ndim == 1:
-        assert exclude_dim == [0] or exclude_dim == [-1]
+        assert exclude_dims == [0] or exclude_dims == [-1]
         return x  # if one dim is excluded, there are no dims to mean, and
                   # x.mean(dim=[]) means all dims so we should not call mean().
     exclude_dims_norm = [i if i >= 0 else i + x.ndim for i in exclude_dims]
@@ -1806,7 +1827,7 @@ def _test_eve_cain():
         fix_random_seed(42)
         Linear = torch.nn.Linear if iter == 0 else ScaledLinear

-        hidden_dim = 200
+        hidden_dim = 300
         m = torch.nn.Sequential(Linear(E, hidden_dim),
                                 torch.nn.PReLU(),
                                 Linear(hidden_dim, hidden_dim),
@@ -1823,16 +1844,6 @@ def _test_eve_cain():
         elif iter == 3: optim = PrAdam(m.parameters(), lr=0.03, max_block_size=256)
         scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)

-        #TEMP
-        if iter == 3:
-            a = torch.randn(5, 10, 5, 5)
-            a = torch.matmul(a, a.transpose(2, 3))
-            b1 = optim._smooth_cov(a, 0.1, 4.0, 0.999999999)
-            b2 = optim._smooth_cov(a, 0.1, 4.0, 1.0)
-            diff = (b1 - b2)
-            ratio = (diff**2).sqrt().mean() / (b1**2).sqrt().mean()
-            logging.info(f"ratio = {ratio}")
-            assert ratio < 0.01

         start = timeit.default_timer()
         avg_loss = 0.0
@@ -1891,12 +1902,22 @@ def _test_svd():
     X2 = torch.matmul(U*S, V.t())
     assert torch.allclose(X2, X, atol=0.001)

+
+def _test_smooth_cov():
+    b = (torch.arange(10)*0.5).exp().diag()
+    b = b.unsqueeze(0).unsqueeze(0)
+    c = PrAdam._smooth_cov(None, b, 0.0, 10.0)
+    logging.info(f"c[noinv] = {_diag(c)}")
+    c = PrAdam._smooth_cov(None, b, 0.0, 10.0, 1.00001)
+    logging.info(f"c[svd] = {_diag(c)}")
+
 if __name__ == "__main__":
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
     logging.getLogger().setLevel(logging.INFO)
     import subprocess
     s = subprocess.check_output("git status -uno .; git log -1", shell=True)
+    _test_smooth_cov()
     logging.info(s)
     #_test_svd()
     _test_eve_cain()
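For reference, the sketch below illustrates the inverse-free eigenvalue soft-cap that the first hunk introduces in `_smooth_cov`, outside the optimizer. It is not part of the patch: the helper name `soft_cap_eigs` and the test scaffolding are invented for illustration, and it assumes `X` is symmetric positive semi-definite with its diagonal mean normalized to 1, which is what the patched code arranges just before the loop (so the largest eigenvalue is bounded by the dimension).

```python
import torch


def soft_cap_eigs(X: torch.Tensor, max_eig: float) -> torch.Tensor:
    """Monotonically map the eigenvalues of a symmetric PSD matrix X
    (with diagonal mean ~= 1) into [0, max_eig], using only the
    polynomial map l -> l - coeff * l**2 and no matrix inverse."""
    # After normalization, trace(X) == dim, so no eigenvalue exceeds dim.
    eig_ceil = float(X.shape[-1])
    while max_eig <= eig_ceil * 0.5:
        # l -> l - 0.5/eig_ceil * l^2 halves the ceiling and is
        # monotonically increasing on [0, eig_ceil].
        X = X - (0.5 / eig_ceil) * (X @ X)
        eig_ceil *= 0.5
    if max_eig < eig_ceil:
        # Final step maps eig_ceil exactly onto max_eig; coeff <= 0.5/eig_ceil
        # because max_eig > 0.5 * eig_ceil here, so the map stays monotone.
        coeff = (eig_ceil - max_eig) / (eig_ceil * eig_ceil)
        X = X - coeff * (X @ X)
    return X


if __name__ == "__main__":
    d = 10
    A = torch.randn(d, d)
    X = A @ A.t()                    # symmetric positive semi-definite
    X = X / X.diagonal().mean()      # normalize so the diagonal mean is 1.0
    Y = soft_cap_eigs(X, max_eig=4.0)
    print(torch.linalg.eigvalsh(X).max().item(),
          torch.linalg.eigvalsh(Y).max().item())
```

Compared with the `X.inverse()`-based soft-min it replaces, this costs only a few matrix multiplies (roughly `log2(dim / max_eig)` of them) and avoids inverting a possibly ill-conditioned matrix.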