Some cleanups

This commit is contained in:
Daniel Povey 2022-08-18 07:03:06 +08:00
parent a5b9b7b974
commit 64f7166545

View File

@ -889,12 +889,17 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
# rms of eigenvalues, or spectral 2-norm.
return _mean(_diag(Y), exclude_dims=[0], keepdim=True).unsqueeze(-1)
def rms_eig(Y):
# rms of eigenvalues, or spectral 2-norm.
# Summing Y * Y^T (transpose of dims 2 and 3) over the two matrix dims gives
# tr(Y Y), because (Y Y)_{ii} = sum_j Y_{ij} * Y_{ji}; tr(Y Y) is the sum of
# the squared eigenvalues of Y, so dividing by `size` and taking the square
# root yields their rms.
# NOTE(review): `_sum` and `size` are not defined in this function; presumably
# `_sum` sums over every dim except dim 0 and `size` is the matrix dimension
# taken from the enclosing scope -- confirm.
return (_sum(Y*Y.transpose(2,3), exclude_dims=[0], keepdim=True) / size).sqrt()
def rms_eig(Y, M):
# rms of eigenvalues, or spectral 2-norm, of A = M^{0.5} Y M^{0.5}.
# We can compute this as (tr(A A) / size).sqrt(), since A is symmetric.
# tr(A A) == tr(M^{0.5} Y M^{0.5} M^{0.5} Y M^{0.5}) == tr(M Y M Y)
#         == tr(Y M Y M)   (cyclic property of the trace)
#         == tr(Z Z) if Z == Y M.
# tr(Z Z) can be computed as sum(Z * Z.t()), because
# (Z Z)_{ii} = sum_j Z_{ij} * Z_{ji}.
# NOTE(review): `_sum` and `size` are not defined in this function; presumably
# `_sum` sums over every dim except dim 0 and `size` is the matrix dimension
# taken from the enclosing scope -- confirm.
Z = torch.matmul(Y, M)
return (_sum(Z*Z.transpose(2,3), exclude_dims=[0], keepdim=True) / size).sqrt()
# make sure eigs of M^{0.5} X M^{0.5} have rms value 1. This imposes limit on the max.
X /= rms_eig(torch.matmul(X, M))
# make sure eigs of M^{0.5} X M^{0.5} have rms value 1. This imposes a
# limit on the max.
X /= rms_eig(X, M)
# eig_ceil is the largest eigenvalue that X could have at this time,
# equal to num_blocks * block_size.
@ -928,6 +933,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
X = 0.5 * (X + X.transpose(-2, -1)) # make sure exactly symmetric.
if min_eig != 0.0:
# Conceptually we want tr(M^{0.5} X M^{0.5}), but by the cyclic property of the trace this is the same as tr(X M).
# We can compute this as the sum of the elements of X * M^T, because if Y = matmul(X, M),
# then Y_{ii} is the sum over j of X_{ij} * M_{ji}.
mean_XM_eig = _sum(X * M.transpose(2, 3), exclude_dims=[0], keepdim=True) / size
XM_over_X = mean_XM_eig / mean_eig(X)
X /= mean_XM_eig