mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-19 05:54:20 +00:00)
Reworked how inverse is done; fixed a bug in _apply_min_max_with_metric regarding how M is normalized.
This commit is contained in:
parent 766bf69a98
commit dc9133227f
@@ -124,12 +124,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of

                  (4) is for smoothing the grad covariance used for (2)

    cov_pow: This was mainly added for development and experimentation purposes;
                  it allows you to smooth the parameter covariance matrices at the
                  stages (1), (2), (3) of smoothing mentioned above, and also
                  the gradient covariance matrix used in stage (2) of smoothing.  If the
                  1st 3 values are not 1.0 it will cause a measurable speed penalty because it
                  requires SVD.  Recommend leaving all of these at 1.0.
    eps: A general-purpose epsilon to prevent division by zero
    denom_rel_eps: An epsilon value used to keep the elements of the denominator (exp_avg_grad_sqrt.sqrt())
                  from becoming too small relative to the mean value; this will limit the speed of update
@@ -163,9 +157,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                 lr=3e-02,
                 betas=(0.9, 0.98),
                 size_lr_scale=0.1,
                 cov_min=(0.025, 0.0025, 0.1, 0.0001),
                 cov_min=(0.025, 0.05, 0.1, 0.0001),
                 cov_max=(5.0, 20.0, 3.5, 40.0),
                 cov_pow=(1.0, 1.0, 1.0, 1.0),
                 param_rms_smooth0=0.4,
                 param_rms_smooth1=0.2,
                 eps=1.0e-08,
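[Editor's note] For orientation, a minimal usage sketch with the defaults shown in this hunk. It assumes the PrAdam class defined in this file is in scope and that, like other torch optimizers, it takes the parameters as its first argument; the model and loss are made up for illustration, and only the keyword values come from the diff above.

    import torch

    model = torch.nn.Linear(256, 256)          # any model; assumed for illustration
    optimizer = PrAdam(model.parameters(),     # PrAdam as defined in this file (assumed signature)
                       lr=3e-02,
                       betas=(0.9, 0.98),
                       size_lr_scale=0.1,
                       cov_min=(0.025, 0.05, 0.1, 0.0001),
                       cov_max=(5.0, 20.0, 3.5, 40.0),
                       param_rms_smooth0=0.4,
                       param_rms_smooth1=0.2,
                       eps=1.0e-08)
    loss = (model(torch.randn(4, 256)) ** 2).sum()
    loss.backward()
    optimizer.step()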
@@ -189,7 +182,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
            size_lr_scale=size_lr_scale,
            cov_min=cov_min,
            cov_max=cov_max,
            cov_pow=cov_pow,
            param_rms_smooth0=param_rms_smooth0,
            param_rms_smooth1=param_rms_smooth1,
            betas=betas,
@@ -763,8 +755,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of

        P = self._smooth_cov(P,
                             max(smooth, group["cov_min"][0]),
                             group["cov_max"][0],
                             group["cov_pow"][0])
                             group["cov_max"][0])
        #P = P.clone()
        #P_diag = _diag(P)
        #P_diag_mean = _mean(P_diag, exclude_dims=[0], keepdim=True)
@@ -775,11 +766,11 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
        #G_diag *= 1.005  # ensure invertible.
        G = self._smooth_cov(G,
                             group["cov_min"][3],
                             group["cov_max"][3],
                             group["cov_pow"][3])
                             group["cov_max"][3])

        G_inv = self._safe_inverse(G, group["cov_min"][3])

        P = self._apply_min_max_with_metric(P, G,
        P = self._apply_min_max_with_metric(P, G, G_inv,
                                            group["cov_min"][1],
                                            group["cov_max"][1])

@@ -787,36 +778,30 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
            # Apply a 3rd round of smoothing in the canonical basis.
            P = self._smooth_cov(P,
                                 group["cov_min"][2],
                                 group["cov_max"][2],
                                 group["cov_pow"][2])
                                 group["cov_max"][2])
        return P

    def _smooth_cov(self,
                    X: Tensor,
                    min_eig: float,
                    max_eig: float,
                    power: float = 1.0) -> Tensor:
                    max_eig: float) -> Tensor:
        """
        Returns a `smoothed` version of a symmetric positive definite covariance matrix
        [with block-diagonal structure, in a batch].  This is done without SVD (which
        can be very slow).
        The eigenvalues L will be transformed as:
        Returns a `smoothed` version of a nonzero symmetric positive semidefinite covariance matrix
        (actually a batch of such covariance matrices, each one with block-diagonal structure).
        This is done without SVD or eigenvalue decomposition or inversion, since those things
        can be very slow.

              L = L + min_eig * L.mean() + eps
              L /= L.mean()
              L = 1 / (1/L + 1/max_eig)   # soft-min between L and max_eig
              L = L ** power    # need SVD for this, will get rid of the requirement later.
              L /= L.mean()
        The eigenvalues L will be transformed in a way roughly approximating the following:

              # Note on approximation functions like x^0.75 for smallish x: on wolframalpha, type:
              #   plot x^0.75 and 0.05 + (1.1x - 0.18 x^2 + 0.02 x^3) for x from 0 to 10
              # [this starts to diverge after 5 or so]
              L /= (L**2).mean().sqrt()        # normalize so RMS eigenvalue is 1.0
              L = 1 / (1/L + 1/max_eig)        # apply soft-min with max_eig as the maximum
              L *= (1.0 - min_eig) / L.mean()
              L += min_eig

        Args:
           min_eig: minimum allowed eigenvalue of returned X
           max_eig: maximum allowed eigenvalue of returned X
           power: power to take eigenvalues to
           X: the batch of symmetric positive definite tensors we are smoothing;
           X: the batch of symmetric positive semidefinite tensors we are smoothing;
              of shape (batch_size, num_blocks, block_size, block_size)
        """
        eps = 1.0e-20
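[Editor's note] The new docstring above describes the target mapping on the eigenvalues; the method itself approximates it on whole matrices without forming 1/L or doing any decomposition. A minimal sketch (mine, not the optimizer's code) applying the described mapping directly to an explicit vector of eigenvalues:

    import torch

    def smooth_eigs(L: torch.Tensor, min_eig: float, max_eig: float) -> torch.Tensor:
        L = L / (L ** 2).mean().sqrt()        # normalize so the RMS eigenvalue is 1.0
        L = 1.0 / (1.0 / L + 1.0 / max_eig)   # soft-min between L and max_eig
        L = L * (1.0 - min_eig) / L.mean()    # rescale so the mean eigenvalue is (1 - min_eig)
        return L + min_eig                    # floor: every value >= min_eig, mean == 1.0

    eigs = (torch.arange(10) * 0.5).exp()     # widely spread eigenvalues, as in _test_smooth_cov
    print(smooth_eigs(eigs, min_eig=0.1, max_eig=10.0))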
@@ -824,7 +809,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
        size = X.shape[1] * X.shape[3]
        def rms_eig(Y):
            # rms of eigenvalues, or spectral 2-norm.
            return (_sum(Y**2, exclude_dims=[0], keepdim=True) / size).sqrt()
            return (_sum(Y**2, exclude_dims=[0], keepdim=True) / size + eps).sqrt()

        X /= rms_eig(X)
        # eig_ceil is the maximum possible eigenvalue that X could possibly
@@ -860,16 +845,18 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
        # normalize to have rms eig == 1.
        X = 0.5 * (X + X.transpose(2, 3))  # make sure exactly symmetric.

        diag = _diag(X).unsqueeze(-1)  # Aliased with X
        diag += (_mean(diag, exclude_dims=[0], keepdim=True) * min_eig + eps)
        X /= rms_eig(X)
        def mean_eig(Y):
            return _mean(_diag(X), exclude_dims=[0], keepdim=True).unsqueeze(-1)

        X *= ((1.0 - min_eig) / (mean_eig(X) + eps))
        _diag(X).add_(min_eig)
        return X


    def _apply_min_max_with_metric(self,
                                   X: Tensor,
                                   M: Tensor,
                                   M_inverse: Tensor,
                                   min_eig: float,
                                   max_eig: float) -> Tensor:
        """
@@ -882,6 +869,16 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
        Args:
           X: Batch of block-diagonal nonzero positive definite matrices to which the
              max-eigenvalue constraint is to be applied
           M: Batch of positive definite block-diagonal matrices to use as metrics w.r.t. X.
           M_inverse: If min_eig != 0.0, the inverse of M should be supplied by the user;
              this is expected to have about the same value that M.inverse() would have,
              but is supplied separately as the calling code inverts in a safer-than-normal
              way based on knowledge of the minimum eigenvalue of M.
           min_eig: The minimum eigenvalue of the returned Tensor, which will have average
              eigenvalue equal to 1.0 for each block-diagonal matrix.
           max_eig: The maximum allowed eigenvalue of the returned Tensor; this is expressed
              as a multiple of the root-mean-square (rms) eigenvalue of X, and we
              later normalize and apply min_eig, so the returned tensor may have
              eigenvalues larger than this.
        """
        X = X.clone()
        # size of the block-diagonal matrix.
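[Editor's note] The min_eig floor here is applied "in the metric of M": the requirement is on the eigenvalues of M^{0.5} X M^{0.5}, and M_inverse lets the code interpolate toward M^{-1} without inverting M locally. A compact sketch of that step (mine; single matrices rather than the batched block-diagonal layout, with a numerical check via a Cholesky change of basis):

    import torch

    def min_eig_with_metric(X, M, M_inv, min_eig):
        size = X.shape[-1]
        # mean eigenvalue of X @ M  ==  trace(X @ M) / size  ==  sum(X * M.T) / size
        mean_XM_eig = (X * M.transpose(-2, -1)).sum(dim=(-2, -1), keepdim=True) / size
        X = X / mean_XM_eig
        # Each eigenvalue lam of M^{0.5} X M^{0.5} becomes (1 - min_eig) * lam + min_eig,
        # so all are >= min_eig and their mean stays at 1.0.
        return X * (1.0 - min_eig) + min_eig * M_inv

    X = torch.randn(6, 6); X = X @ X.T + 0.1 * torch.eye(6)
    M = torch.randn(6, 6); M = M @ M.T + 0.1 * torch.eye(6)
    Y = min_eig_with_metric(X, M, M.inverse(), min_eig=0.1)
    L = torch.linalg.cholesky(M)               # eigs of L.T @ Y @ L equal those of M^{0.5} Y M^{0.5}
    print(torch.linalg.eigvalsh(L.T @ Y @ L))  # all >= 0.1, mean == 1.0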
@@ -894,7 +891,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of

        def rms_eig(Y):
            # rms of eigenvalues, or spectral 2-norm.
            return (_sum(Y**2, exclude_dims=[0], keepdim=True) / size).sqrt()
            return (_sum(Y*Y.transpose(2,3), exclude_dims=[0], keepdim=True) / size).sqrt()

        # make sure eigs of M^{0.5} X M^{0.5} have rms value 1.  This imposes a limit on the max.
        X /= rms_eig(torch.matmul(X, M))
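[Editor's note] The rms_eig change matters here because it is applied to X @ M, which is not symmetric: sum(Y * Y.T) equals trace(Y @ Y), the sum of squared eigenvalues, while sum(Y**2) equals trace(Y @ Y.T), the sum of squared singular values, and the two differ once Y is not symmetric. A small numerical check (a sketch, mine):

    import torch

    torch.manual_seed(0)
    X = torch.randn(8, 8); X = X @ X.T
    M = torch.randn(8, 8); M = M @ M.T
    Y = X @ M                                   # not symmetric, but its eigenvalues are real and positive
    eigs = torch.linalg.eigvals(Y).real
    print((Y * Y.T).sum().item(), (eigs ** 2).sum().item())  # these two agree
    print((Y ** 2).sum().item())                             # squared singular values: generally larger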
@@ -931,12 +928,79 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
        X = 0.5 * (X + X.transpose(-2, -1))  # make sure exactly symmetric.

        if min_eig != 0.0:
            X /= mean_eig(X)
            X = X * (1.0-min_eig) + min_eig * M.inverse()
            mean_XM_eig = _sum(X * M.transpose(2, 3), exclude_dims=[0], keepdim=True) / size
            X /= mean_XM_eig
            X = X * (1.0-min_eig) + min_eig * M_inverse
            X = 0.5 * (X + X.transpose(-2, -1))  # make sure exactly symmetric.

        return X

    def _safe_inverse(self,
                      X: Tensor,
                      min_eig: float,
                      tolerance: float = 1.2):
        """
        Returns X.inverse(), but in a safe way that limits the possible amount by which
        numerical roundoff could make the result inaccurate.

        Args:
           X: batch of positive definite tensors to be inverted, of shape (*, size, size).
              Mathematically, X must be symmetric and positive definite (SPD), but this
              routine allows it to have negative eigenvalues and just assumes that
              any eigenvalues < min_eig should actually equal min_eig.
           min_eig: the minimum possible eigenvalue that X is known mathematically to have
              (e.g. 0.1).
           tolerance: A tolerance that should be > 1.0; it controls the tradeoff between
              redundant computation and accuracy.
        """
        try:
            size = X.shape[-1]
            I = torch.eye(size, dtype=X.dtype, device=X.device)
            if True:
                X_inv = X.inverse()
            else:
                C = X.cholesky()  # X = C C^T
                I = I.reshape(*([1]*(C.ndim-2)), size, size).expand(*X.shape)

                # note, X^{-1} == C^{-T} C^{-1}
                # next line does: X_inv = C^{-1} X_inv (== C^{-1})
                X_inv = torch.triangular_solve(I, C, upper=False, transpose=False)[0]
                # next line does: X_inv = C^{-T} X_inv (== C^{-T} C^{-1})
                X_inv = torch.triangular_solve(X_inv, C, upper=False, transpose=True)[0]

            # Make sure X_inv is exactly symmetric (the input is assumed symmetric)
            X_inv = 0.5 * (X_inv + X_inv.transpose(-2, -1))

            # OK, X_inv is positive definite and we require that its largest eigenvalue
            # should be < (1/min_eig) * tolerance.  That requires that
            # (tolerance / min_eig) * I - X_inv should be positive definite.
            X_inv_check = ((tolerance / min_eig) * I - X_inv)
            # now X_inv_check == (tolerance / min_eig) * I - X_inv.
            # We don't need the result of X_inv_check.cholesky() below, we just execute the
            # statement as a way of checking that X_inv_check is positive definite.
            # we also check that X_inv has no negative eigenvalues, because if any of the
            # original X's eigenvalues had been negative, they would show up as negative
            # eigenvalues in the answer.
            torch.stack([X_inv, X_inv_check]).cholesky()
            return X_inv
        except RuntimeError as e:
            # _svd(X) is an alternative to X.svd() that retries in response to
            # failures caused by SVD implementation bugs.
            U, S, V = _svd(X)
            S_sign = (2*((U * V).sum(dim=-2) > 0) - 1)
            # the method with S_with_sign below is a quick but not 100%-guaranteed
            # way to find the eigenvalues of an SPD matrix (it might not work
            # if there are tied eigenvalues); this does not matter as we are only
            # using it to print diagnostics.
            S_with_sign = S * S_sign
            logging.warning(f"Caught error in _safe_inverse() (this is normal). Shape={X.shape}, minimum eigenvalue "
                            f"is {S_with_sign.min().item()} vs. {min_eig} with tolerance {tolerance}; exception was: {e}")
            S_inv = 1.0 / S.clamp(min=min_eig)
            X_inv = torch.matmul(U * S_inv.unsqueeze(-2), U.transpose(-2, -1))
            # Make sure X_inv is exactly symmetric (the input is assumed symmetric)
            X_inv = 0.5 * (X_inv + X_inv.transpose(-2, -1))
            return X_inv


    def _update_grad_cov(self,
                         group: dict,
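[Editor's note] In condensed form, the pattern _safe_inverse implements is: invert, use a batched Cholesky factorization purely as a positive-definiteness check on both the inverse and (tolerance / min_eig) * I minus the inverse, and fall back to a decomposition with clamped eigenvalues only if that check throws. A self-contained sketch (mine; the fallback here uses torch.linalg.eigh rather than the SVD-plus-sign trick in the code above):

    import torch

    def safe_inverse_sketch(X: torch.Tensor, min_eig: float, tolerance: float = 1.2) -> torch.Tensor:
        size = X.shape[-1]
        I = torch.eye(size, dtype=X.dtype, device=X.device)
        try:
            X_inv = X.inverse()
            X_inv = 0.5 * (X_inv + X_inv.transpose(-2, -1))      # exactly symmetric
            check = (tolerance / min_eig) * I - X_inv            # PD iff all eigs of X_inv < tolerance/min_eig
            torch.linalg.cholesky(torch.stack([X_inv, check]))   # raises if either matrix is not PD
            return X_inv
        except RuntimeError:
            evals, evecs = torch.linalg.eigh(X)                  # simplified fallback
            inv_evals = 1.0 / evals.clamp(min=min_eig)
            X_inv = (evecs * inv_evals.unsqueeze(-2)) @ evecs.transpose(-2, -1)
            return 0.5 * (X_inv + X_inv.transpose(-2, -1))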
@@ -1949,14 +2013,38 @@ def _test_svd():
    X2 = torch.matmul(U*S, V.t())
    assert torch.allclose(X2, X, atol=0.001)

def _test_safe_inverse():
    for dim in [10, 20, 100]:
        a = torch.randn(2, 3, dim, dim//2)
        a = torch.matmul(a, a.transpose(-2, -1))
        min_eig = 0.001
        _diag(a).add_(min_eig)
        for tolerance in [0.9, 0.99, 1.0, 1.01, 1.1]:
            self = None
            inv = PrAdam._safe_inverse(self, a, min_eig, tolerance)
            prod = (torch.matmul(inv, a))
            err = prod - torch.eye(dim)
            err = torch.sum((err**2), dim=(2,3)).sqrt()
            logging.info(f"_test_safe_inverse(): dim={dim}, tolerance={tolerance}, err={err}")

def _test_smooth_cov():
    b = (torch.arange(10)*0.5).exp().diag()
    b = b.unsqueeze(0).unsqueeze(0)
    c = PrAdam._smooth_cov(None, b, 0.0, 10.0)
    logging.info(f"c[noinv] = {_diag(c)}")
    c = PrAdam._smooth_cov(None, b, 0.0, 10.0, 1.00001)
    logging.info(f"c[svd] = {_diag(c)}")
def _test_apply_min_max():
    dim = 50
    X = torch.randn(1, 1, dim, dim*2)
    X = torch.matmul(X, X.transpose(2, 3))
    M = torch.randn(1, 1, dim, dim*2)
    M = torch.matmul(M, M.transpose(2, 3))

    for min_eig in [0.0, 0.1]:
        for max_eig in [5.0, 10.0]:
            for Xscale in [1.0, 2.0]:
                for Mscale in [1.0, 2.0]:
                    self = None
                    thisM = M * Mscale
                    thisMinv = thisM.inverse()
                    Y = PrAdam._apply_min_max_with_metric(self, X*Xscale, thisM, thisMinv,
                                                          min_eig, max_eig)
                    Y /= _mean(_diag(Y), exclude_dims=[0], keepdim=True)
                    logging.info(f"min_eig={min_eig},max_eig={max_eig},Xscale={Xscale},Mscale={Mscale}, rms of Y = {(Y**2).mean().item()}")

if __name__ == "__main__":
    torch.set_num_threads(1)
@@ -1964,8 +2052,9 @@ if __name__ == "__main__":
    logging.getLogger().setLevel(logging.INFO)
    import subprocess
    s = subprocess.check_output("git status -uno .; git log -1", shell=True)
    _test_smooth_cov()
    logging.info(s)
    _test_safe_inverse()
    _test_apply_min_max()
    #_test_svd()
    import sys
    if len(sys.argv) > 1: