Swap the order of applying min and max in smoothing operations

This commit is contained in:
Daniel Povey 2022-08-02 11:55:43 +08:00
parent 9473c7e23d
commit e9f4ada1c0

View File

@ -820,33 +820,18 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
of shape (batch_size, num_blocks, block_size, block_size) of shape (batch_size, num_blocks, block_size, block_size)
""" """
eps = 1.0e-20 eps = 1.0e-20
if power != 1.0:
U, S, _ = _svd(X)
def mean(Y):
return _mean(Y, exclude_dims=[0], keepdim=True)
def rms(Y):
return _mean(Y**2, exclude_dims=[0], keepdim=True).sqrt()
S = S + min_eig * mean(S) + eps
S = S / rms(S)
S = 1. / (1./S + 1./max_eig)
S = S ** power
S = S / rms(S)
return torch.matmul(U * S.unsqueeze(-2), U.transpose(2, 3))
else:
X = X.clone() X = X.clone()
size = X.shape[1] * X.shape[3] size = X.shape[1] * X.shape[3]
def rms_eig(Y): def rms_eig(Y):
# rms of eigenvalues, or spectral 2-norm. # rms of eigenvalues, or spectral 2-norm.
return (_sum(Y**2, exclude_dims=[0], keepdim=True) / size).sqrt() return (_sum(Y**2, exclude_dims=[0], keepdim=True) / size).sqrt()
diag = _diag(X).unsqueeze(-1) # Aliased with X
diag += (_mean(diag, exclude_dims=[0], keepdim=True) * min_eig + eps)
X /= rms_eig(X)
X /= rms_eig(X)
# eig_ceil is the maximum possible eigenvalue that X could possibly # eig_ceil is the maximum possible eigenvalue that X could possibly
# have at this time, given that the RMS eig is 1, equal to sqrt(num_blocks * block_size) # have at this time, given that the RMS eig is 1, equal to sqrt(num_blocks * block_size)
eig_ceil = size ** 0.5 eig_ceil = size ** 0.5
# the next statement wslightly adjusts the target to be the same as # the next statement slightly adjusts the target to be the same as
# what the baseline function gives, max_eig -> 1./(1./eig_ceil + 1./max_eig) would # what the baseline function gives, max_eig -> 1./(1./eig_ceil + 1./max_eig) would
# give. so "max_eig" is now the target of the function for arg == # give. so "max_eig" is now the target of the function for arg ==
# eig_ceil. # eig_ceil.
@ -873,8 +858,12 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
X = X - coeff * torch.matmul(X, X.transpose(2, 3)) X = X - coeff * torch.matmul(X, X.transpose(2, 3))
# normalize to have rms eig == 1. # normalize to have rms eig == 1.
X /= rms_eig(X)
X = 0.5 * (X + X.transpose(2, 3)) # make sure exactly symmetric. X = 0.5 * (X + X.transpose(2, 3)) # make sure exactly symmetric.
diag = _diag(X).unsqueeze(-1) # Aliased with X
diag += (_mean(diag, exclude_dims=[0], keepdim=True) * min_eig + eps)
X /= rms_eig(X)
return X return X
@ -910,12 +899,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
# make sure eigs of M^{0.5} X M^{0.5} have rms value 1. This imposes limit on the max. # make sure eigs of M^{0.5} X M^{0.5} have rms value 1. This imposes limit on the max.
X /= rms_eig(torch.matmul(X, M)) X /= rms_eig(torch.matmul(X, M))
if min_eig != 0.0:
X = X * (1.0-min_eig) + min_eig * M.inverse()
X = 0.5 * (X + X.transpose(-2, -1)) # make sure exactly symmetric.
X /= rms_eig(torch.matmul(X, M))
# eig_ceil is the maximum possible eigenvalue that X could possibly # eig_ceil is the maximum possible eigenvalue that X could possibly
# have at this time, equal to sqrt(num_blocks * block_size). # have at this time, equal to sqrt(num_blocks * block_size).
eig_ceil = size ** 0.5 eig_ceil = size ** 0.5
@ -945,9 +928,13 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
coeff = (eig_ceil - max_eig) / (eig_ceil*eig_ceil) coeff = (eig_ceil - max_eig) / (eig_ceil*eig_ceil)
X = X - coeff * torch.matmul(X, torch.matmul(M, X.transpose(2, 3))) X = X - coeff * torch.matmul(X, torch.matmul(M, X.transpose(2, 3)))
# Normalize again.
X /= rms_eig(X)
X = 0.5 * (X + X.transpose(-2, -1)) # make sure exactly symmetric. X = 0.5 * (X + X.transpose(-2, -1)) # make sure exactly symmetric.
if min_eig != 0.0:
X /= mean_eig(X)
X = X * (1.0-min_eig) + min_eig * M.inverse()
X = 0.5 * (X + X.transpose(-2, -1)) # make sure exactly symmetric.
return X return X