mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
1st draft of new method of normalizing covs that uses normalization w.r.t. spectral 2-norm
This commit is contained in:
parent
4919134a94
commit
6ab4cf615d
@ -164,7 +164,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
|||||||
betas=(0.9, 0.98),
|
betas=(0.9, 0.98),
|
||||||
size_lr_scale=0.1,
|
size_lr_scale=0.1,
|
||||||
cov_min=(0.025, 0.0025, 0.02, 0.0001),
|
cov_min=(0.025, 0.0025, 0.02, 0.0001),
|
||||||
cov_max=(10.0, 80.0, 5.0, 400.0),
|
cov_max=(10.0, 10.0, 5.0, 20.0),
|
||||||
cov_pow=(1.0, 1.0, 1.0, 1.0),
|
cov_pow=(1.0, 1.0, 1.0, 1.0),
|
||||||
param_rms_smooth0=0.4,
|
param_rms_smooth0=0.4,
|
||||||
param_rms_smooth1=0.2,
|
param_rms_smooth1=0.2,
|
||||||
@ -671,6 +671,13 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
|||||||
_,S,_ = _svd(P)
|
_,S,_ = _svd(P)
|
||||||
logging.info(f"Eigs of P are {S[0][0]}")
|
logging.info(f"Eigs of P are {S[0][0]}")
|
||||||
|
|
||||||
|
# Re-normalize P so that the mean eig of each block-diagonal matrix
|
||||||
|
# is 1 (as opposed to rms == 1). To make C norm-preserving when
|
||||||
|
# applied to normally distributed input, we want its rms(singular
|
||||||
|
# value) to be 1, and since the singular values (==eigenvalues) of P
|
||||||
|
# are the squares of the singular values of C, we want them to have
|
||||||
|
# mean of 1.
|
||||||
|
P /= _mean(_diag(P), exclude_dims=[0]).unsqueeze(-1)
|
||||||
|
|
||||||
# C will satisfy: P == torch.matmul(C, C.transpose(2, 3))
|
# C will satisfy: P == torch.matmul(C, C.transpose(2, 3))
|
||||||
# C is of shape (batch_size, num_blocks, block_size, block_size).
|
# C is of shape (batch_size, num_blocks, block_size, block_size).
|
||||||
@ -811,30 +818,30 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
|||||||
eps = 1.0e-20
|
eps = 1.0e-20
|
||||||
if power != 1.0:
|
if power != 1.0:
|
||||||
U, S, _ = _svd(X)
|
U, S, _ = _svd(X)
|
||||||
S_mean = _mean(S, exclude_dims=[0], keepdim=True)
|
def rms(Y):
|
||||||
S = S + min_eig * S_mean + eps
|
return _mean(Y**2, exclude_dims=[0], keepdim=True).sqrt()
|
||||||
S_mean = S_mean * (1 + min_eig) + eps
|
S = S + min_eig * rms(S) + eps
|
||||||
S = S / S_mean
|
S = S / rms(S)
|
||||||
S = 1. / (1./S + 1./max_eig)
|
S = 1. / (1./S + 1./max_eig)
|
||||||
S = S ** power
|
S = S ** power
|
||||||
S = S / _mean(S, exclude_dims=[0], keepdim=True)
|
S = S / rms(S)
|
||||||
return torch.matmul(U * S.unsqueeze(-2), U.transpose(2, 3))
|
return torch.matmul(U * S.unsqueeze(-2), U.transpose(2, 3))
|
||||||
else:
|
else:
|
||||||
X = X.clone()
|
X = X.clone()
|
||||||
diag = _diag(X) # Aliased with X
|
size = X.shape[1] * X.shape[3]
|
||||||
mean_eig = _mean(diag, exclude_dims=[0], keepdim=True)
|
def rms_eig(Y):
|
||||||
diag += (mean_eig * min_eig + eps)
|
# rms of eigenvalues, or spectral 2-norm.
|
||||||
cur_diag_mean = mean_eig * (1 + min_eig) + eps
|
return (_sum(Y**2, exclude_dims=[0], keepdim=True) / size).sqrt()
|
||||||
X /= cur_diag_mean.unsqueeze(-1)
|
diag = _diag(X).unsqueeze(-1) # Aliased with X
|
||||||
# OK, now the mean of the diagonal of X is 1 (or less than 1, in
|
diag += (rms_eig(X) * min_eig + eps)
|
||||||
# which case X is extremely tiny).
|
X /= rms_eig(X)
|
||||||
|
|
||||||
# eig_ceil is the maximum possible eigenvalue that X could possibly
|
# eig_ceil is the maximum possible eigenvalue that X could possibly
|
||||||
# have at this time, equal to num_blocks * block_size.
|
# have at this time, given that the RMS eig is 1, equal to sqrt(num_blocks * block_size)
|
||||||
eig_ceil = X.shape[1] * X.shape[3]
|
eig_ceil = size ** 0.5
|
||||||
|
|
||||||
# the next statement wslightly adjusts the target to be the same as
|
# the next statement wslightly adjusts the target to be the same as
|
||||||
# what the baseline function, eig -> 1./(1./eig + 1./max_eig) would
|
# what the baseline function gives, max_eig -> 1./(1./eig_ceil + 1./max_eig) would
|
||||||
# give. so "max_eig" is now the target of the function for arg ==
|
# give. so "max_eig" is now the target of the function for arg ==
|
||||||
# eig_ceil.
|
# eig_ceil.
|
||||||
max_eig = 1. / (1. / max_eig + 1. / eig_ceil)
|
max_eig = 1. / (1. / max_eig + 1. / eig_ceil)
|
||||||
@ -859,8 +866,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
|||||||
coeff = (eig_ceil - max_eig) / (eig_ceil*eig_ceil)
|
coeff = (eig_ceil - max_eig) / (eig_ceil*eig_ceil)
|
||||||
X = X - coeff * torch.matmul(X, X.transpose(2, 3))
|
X = X - coeff * torch.matmul(X, X.transpose(2, 3))
|
||||||
|
|
||||||
# Normalize again.
|
# normalize to have rms eig == 1.
|
||||||
X /= _mean(_diag(X), exclude_dims=[0], keepdim=True).unsqueeze(-1)
|
X /= rms_eig(X)
|
||||||
X = 0.5 * (X + X.transpose(2, 3)) # make sure exactly symmetric.
|
X = 0.5 * (X + X.transpose(2, 3)) # make sure exactly symmetric.
|
||||||
return X
|
return X
|
||||||
|
|
||||||
@ -885,9 +892,12 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
|||||||
# size of the block-diagonal matrix..
|
# size of the block-diagonal matrix..
|
||||||
size = X.shape[1] * X.shape[3]
|
size = X.shape[1] * X.shape[3]
|
||||||
# mean eig of M^{0.5} X M^{0.5} ...
|
# mean eig of M^{0.5} X M^{0.5} ...
|
||||||
mean_eig = _sum(X*M, exclude_dims=[0], keepdim=True) / size
|
def rms_eig(Y):
|
||||||
# make sure eigs of M^{0.5} X M^{0.5} are average 1. this imposes limit on the max.
|
# rms of eigenvalues, or spectral 2-norm.
|
||||||
X /= mean_eig
|
return (_sum(Y**2, exclude_dims=[0], keepdim=True) / size).sqrt()
|
||||||
|
|
||||||
|
# make sure eigs of M^{0.5} X M^{0.5} have rms value 1. This imposes limit on the max.
|
||||||
|
X /= rms_eig(torch.matmul(X, M))
|
||||||
|
|
||||||
if min_eig != 0.0:
|
if min_eig != 0.0:
|
||||||
X = X * (1.0-min_eig) + min_eig * M.inverse()
|
X = X * (1.0-min_eig) + min_eig * M.inverse()
|
||||||
@ -895,7 +905,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
|||||||
|
|
||||||
# eig_ceil is the maximum possible eigenvalue that X could possibly
|
# eig_ceil is the maximum possible eigenvalue that X could possibly
|
||||||
# have at this time, equal to num_blocks * block_size.
|
# have at this time, equal to num_blocks * block_size.
|
||||||
eig_ceil = size
|
eig_ceil = size ** 0.5
|
||||||
|
|
||||||
# the next statement wslightly adjusts the target to be the same as
|
# the next statement wslightly adjusts the target to be the same as
|
||||||
# what the baseline function, eig -> 1./(1./eig + 1./max_eig) would
|
# what the baseline function, eig -> 1./(1./eig + 1./max_eig) would
|
||||||
@ -923,7 +933,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
|||||||
X = X - coeff * torch.matmul(X, torch.matmul(M, X.transpose(2, 3)))
|
X = X - coeff * torch.matmul(X, torch.matmul(M, X.transpose(2, 3)))
|
||||||
|
|
||||||
# Normalize again.
|
# Normalize again.
|
||||||
X /= _mean(_diag(X), exclude_dims=[0], keepdim=True).unsqueeze(-1)
|
X /= rms_eig(X)
|
||||||
X = 0.5 * (X + X.transpose(-2, -1)) # make sure exactly symmetric.
|
X = 0.5 * (X + X.transpose(-2, -1)) # make sure exactly symmetric.
|
||||||
return X
|
return X
|
||||||
|
|
||||||
@ -1842,7 +1852,7 @@ def _test_eden():
|
|||||||
logging.info(f"state dict = {scheduler.state_dict()}")
|
logging.info(f"state dict = {scheduler.state_dict()}")
|
||||||
|
|
||||||
|
|
||||||
def _test_eve_cain(hidden_dim):
|
def _test_eve_cain(hidden_dim: int):
|
||||||
import timeit
|
import timeit
|
||||||
from scaling import ScaledLinear
|
from scaling import ScaledLinear
|
||||||
E = 100
|
E = 100
|
||||||
@ -1953,15 +1963,15 @@ if __name__ == "__main__":
|
|||||||
torch.set_num_interop_threads(1)
|
torch.set_num_interop_threads(1)
|
||||||
logging.getLogger().setLevel(logging.INFO)
|
logging.getLogger().setLevel(logging.INFO)
|
||||||
import subprocess
|
import subprocess
|
||||||
|
s = subprocess.check_output("git status -uno .; git log -1", shell=True)
|
||||||
|
_test_smooth_cov()
|
||||||
|
logging.info(s)
|
||||||
|
#_test_svd()
|
||||||
import sys
|
import sys
|
||||||
if len(sys.argv) > 1:
|
if len(sys.argv) > 1:
|
||||||
hidden_dim = int(sys.argv[1])
|
hidden_dim = int(sys.argv[1])
|
||||||
else:
|
else:
|
||||||
hidden_dim = 200
|
hidden_dim = 200
|
||||||
s = subprocess.check_output("git status -uno .; git log -1", shell=True)
|
|
||||||
_test_smooth_cov()
|
|
||||||
logging.info(f"hidden_dim = {hidden_dim}")
|
|
||||||
logging.info(s)
|
|
||||||
#_test_svd()
|
|
||||||
_test_eve_cain(hidden_dim)
|
_test_eve_cain(hidden_dim)
|
||||||
#_test_eden()
|
#_test_eden()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user