From 64f7166545bed0f7b2fa5725ddd12aeab8e11db2 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 18 Aug 2022 07:03:06 +0800 Subject: [PATCH] Some cleanups --- .../ASR/pruned_transducer_stateless7/optim.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index 25cd4556d..578a390b6 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -889,12 +889,17 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of # rms of eigenvalues, or spectral 2-norm. return _mean(_diag(Y), exclude_dims=[0], keepdim=True).unsqueeze(-1) - def rms_eig(Y): - # rms of eigenvalues, or spectral 2-norm. - return (_sum(Y*Y.transpose(2,3), exclude_dims=[0], keepdim=True) / size).sqrt() + def rms_eig(Y, M): + # rms of eigenvalues, or spectral 2-norm, of A = M^{0.5} Y M^{0.5}. + # We can compute this as (tr(A A) / size).sqrt(), since A is symmetric. + # tr(A A) == tr(M^{0.5} Y M^{0.5} M^{0.5} Y M^{0.5}) = tr(M Y M Y) = tr(Y M Y M) = tr(Z Z) if Z == Y M. + # tr(Z Z) can be computed as sum(Z*Z.t()) + Z = torch.matmul(Y, M) + return (_sum(Z*Z.transpose(2,3), exclude_dims=[0], keepdim=True) / size).sqrt() - # make sure eigs of M^{0.5} X M^{0.5} have rms value 1. This imposes limit on the max. - X /= rms_eig(torch.matmul(X, M)) + # make sure eigs of M^{0.5} X M^{0.5} have rms value 1. This imposes a + # limit on the max. + X /= rms_eig(X, M) # eig_ceil is the maximum possible eigenvalue that X could possibly # have at this time, equal to num_blocks * block_size. @@ -928,6 +933,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of X = 0.5 * (X + X.transpose(-2, -1)) # make sure exactly symmetric. if min_eig != 0.0: + # conceptually we want tr(M^{0.5} X M^{0.5}), but this is the same as tr(X M). 
+ # we can compute this as the sum of the elements of X * M^T, because if Y=matmul(X, M), + # Y_{ii} is the sum over j of X_{ij} * M_{ji}. mean_XM_eig = _sum(X * M.transpose(2, 3), exclude_dims=[0], keepdim=True) / size XM_over_X = mean_XM_eig / mean_eig(X) X /= mean_XM_eig