Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Reworked how inverse is done, fixed bug in _apply_min_max_with_metric, regarding how M is normalized.
This commit is contained in:
parent 766bf69a98
commit dc9133227f
@@ -124,12 +124,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
 
               (4) is for smoothing the grad covariance used for (2)
 
-    cov_pow: This was mainly added for development and experimentation purposes;
-              it allows you to smooth the parameter covariance matrices at the
-              stages (1), (2), (3) of smoothing mentioned above, and also
-              the gradient covariance matrix used in stage (2) of smoothing.  If the
-              1st 3 values are not 1.0 it will cause a measurable speed penalty because it
-              requires SVD.  Recommend to leave all these at 1.0.
     eps:  A general-purpose epsilon to prevent division by zero
     denom_rel_eps: An epsilon value used to keep the elements of denominator (exp_avg_grad_sqrt.sqrt())
               not too small relative to the mean value; this will limit the speed of update
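To make the denom_rel_eps description above concrete, here is a minimal sketch of one way a denominator can be kept from becoming too small relative to its mean. This is an assumption-level illustration, not code from this commit; the tensor names and the denom_rel_eps value are placeholders.

import torch

# Hypothetical illustration: floor each element of the Adam-style denominator at
# denom_rel_eps times its mean value, so that a single tiny element cannot make
# the parameter update arbitrarily fast.
denom_rel_eps = 1.0e-05  # placeholder value, not a default taken from this diff
exp_avg_grad_sq = torch.tensor([1.0e-12, 4.0e-02, 9.0e-02])
denom = exp_avg_grad_sq.sqrt()
denom = denom.clamp(min=denom_rel_eps * denom.mean())
print(denom)  # the tiny first element has been raised to the relative floor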
@@ -163,9 +157,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                  lr=3e-02,
                  betas=(0.9, 0.98),
                  size_lr_scale=0.1,
-                 cov_min=(0.025, 0.0025, 0.1, 0.0001),
+                 cov_min=(0.025, 0.05, 0.1, 0.0001),
                  cov_max=(5.0, 20.0, 3.5, 40.0),
-                 cov_pow=(1.0, 1.0, 1.0, 1.0),
                  param_rms_smooth0=0.4,
                  param_rms_smooth1=0.2,
                  eps=1.0e-08,
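For orientation, here is a hedged sketch of constructing the optimizer with the defaults shown above. The keyword names and values come from this diff (note that cov_pow is gone and the second cov_min entry changed from 0.0025 to 0.05), while the import path and the model are placeholders.

import torch
from optim import PrAdam  # hypothetical import path; use wherever PrAdam is defined

model = torch.nn.Linear(256, 256)  # placeholder model

optimizer = PrAdam(
    model.parameters(),
    lr=3e-02,
    betas=(0.9, 0.98),
    size_lr_scale=0.1,
    cov_min=(0.025, 0.05, 0.1, 0.0001),
    cov_max=(5.0, 20.0, 3.5, 40.0),
    param_rms_smooth0=0.4,
    param_rms_smooth1=0.2,
    eps=1.0e-08,
)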
@@ -189,7 +182,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                         size_lr_scale=size_lr_scale,
                         cov_min=cov_min,
                         cov_max=cov_max,
-                        cov_pow=cov_pow,
                         param_rms_smooth0=param_rms_smooth0,
                         param_rms_smooth1=param_rms_smooth1,
                         betas=betas,
@@ -763,8 +755,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
 
         P = self._smooth_cov(P,
                              max(smooth, group["cov_min"][0]),
-                             group["cov_max"][0],
-                             group["cov_pow"][0])
+                             group["cov_max"][0])
         #P = P.clone()
         #P_diag = _diag(P)
         #P_diag_mean = _mean(P_diag, exclude_dims=[0], keepdim=True)
@@ -775,11 +766,11 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         #G_diag *= 1.005 # ensure invertible.
         G = self._smooth_cov(G,
                              group["cov_min"][3],
-                             group["cov_max"][3],
-                             group["cov_pow"][3])
+                             group["cov_max"][3])
 
+        G_inv = self._safe_inverse(G, group["cov_min"][3])
 
-        P = self._apply_min_max_with_metric(P, G,
+        P = self._apply_min_max_with_metric(P, G, G_inv,
                                             group["cov_min"][1],
                                             group["cov_max"][1])
 
@@ -787,36 +778,30 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         # Apply a 3rd round of smoothing in the canonical basis.
         P = self._smooth_cov(P,
                              group["cov_min"][2],
-                             group["cov_max"][2],
-                             group["cov_pow"][2])
+                             group["cov_max"][2])
         return P
 
     def _smooth_cov(self,
                     X: Tensor,
                     min_eig: float,
-                    max_eig: float,
-                    power: float = 1.0) -> Tensor:
+                    max_eig: float) -> Tensor:
         """
-        Returns a `smoothed` version of a symmetric positive definite covariance matrix
-        [with block-diagonal structure, in a batch].  This is done without SVD (which
-        can be very slow).
-        The eigenvalues L will be transformed as:
+        Returns a `smoothed` version of a nonzero symmetric positive semidefinite covariance matrix
+        (actually a batch of such covariance matrices, each one with block-diagonal structure).
+        This is done without SVD or eigenvalue decomposition or inversion, since those things
+        can be very slow.
 
-           L = L + min_eig * L.mean() + eps
-           L /= L.mean()
-           L = 1 / (1/L + 1/max_eig)  # soft-min between L and max_eig
-           L = L ** power   # need SVD for this, will get rid of the requirement later.
-           L /= L.mean()
+        The eigenvalues L will be transformed in a way roughly approximating the following:
 
-        # Note on approximation functions like x^0.75 for smallish x: on wolframalpha, type:
-        # plot x^0.75 and 0.05 + (1.1x - 0.18 x^2 + 0.02 x^3)  for x from 0 to 10
-        # [this starts to diverge after 5 or so]
+           L /= (L**2).mean().sqrt()   # normalize so RMS eigenvalue is 1.0
+           L = 1 / (1/L + 1/max_eig)   # apply soft-min with max_eig as the maximum
+           L /= (1.0 - min_eig) * L.mean()
+           L += min_eig
 
         Args:
            min_eig: minimum allowed eigenvalue of returned X
           max_eig: maximum allowed eigenvalue of returned X
-           power: power to take eigenvalues to
-           X: the batch of symmetric positive definite tensors we are smoothing;
+           X: the batch of symmetric positive semidefinite tensors we are smoothing;
              of shape (batch_size, num_blocks, block_size, block_size)
         """
         eps = 1.0e-20
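The eigenvalue recipe in the new docstring can be sanity-checked directly on a vector of eigenvalues. The sketch below is a transcription of that recipe (the method itself only approximates it at the matrix level, without computing eigenvalues), with example values chosen arbitrarily.

import torch

def smooth_eigs(L: torch.Tensor, min_eig: float, max_eig: float) -> torch.Tensor:
    # Direct transcription of the docstring's pseudo-code.
    L = L / (L ** 2).mean().sqrt()        # normalize so RMS eigenvalue is 1.0
    L = 1.0 / (1.0 / L + 1.0 / max_eig)   # soft-min between L and max_eig
    L = L / ((1.0 - min_eig) * L.mean())
    return L + min_eig

eigs = torch.tensor([1.0e-4, 0.1, 1.0, 10.0, 100.0])
print(smooth_eigs(eigs, min_eig=0.025, max_eig=5.0))
# every output is >= min_eig, and the largest values are softly compressed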
@@ -824,7 +809,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         size = X.shape[1] * X.shape[3]
         def rms_eig(Y):
             # rms of eigenvalues, or spectral 2-norm.
-            return (_sum(Y**2, exclude_dims=[0], keepdim=True) / size).sqrt()
+            return (_sum(Y**2, exclude_dims=[0], keepdim=True) / size + eps).sqrt()
 
         X /= rms_eig(X)
         # eig_ceil is the maximum possible eigenvalue that X could possibly
@@ -860,16 +845,18 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         # normalize to have rms eig == 1.
         X = 0.5 * (X + X.transpose(2, 3))  # make sure exactly symmetric.
 
-        diag = _diag(X).unsqueeze(-1)  # Aliased with X
-        diag += (_mean(diag, exclude_dims=[0], keepdim=True) * min_eig + eps)
-        X /= rms_eig(X)
+        def mean_eig(Y):
+            return _mean(_diag(X), exclude_dims=[0], keepdim=True).unsqueeze(-1)
 
+        X *= ((1.0 - min_eig) / (mean_eig(X) + eps))
+        _diag(X).add_(min_eig)
         return X
 
 
     def _apply_min_max_with_metric(self,
                                    X: Tensor,
                                    M: Tensor,
+                                   M_inverse: Tensor,
                                    min_eig: float,
                                    max_eig: float) -> Tensor:
         """
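A quick standalone check of the two new final steps (scale X so its mean eigenvalue is 1 - min_eig, then add min_eig to the diagonal): adding min_eig * I shifts every eigenvalue up by exactly min_eig, so the result has mean eigenvalue 1.0 and minimum eigenvalue at least min_eig. This sketch uses torch.linalg directly and is not the optimizer's code path.

import torch

torch.manual_seed(0)
min_eig = 0.05
A = torch.randn(6, 6)
X = A @ A.t()                                       # symmetric positive semidefinite
X = X * (1.0 - min_eig) / torch.diagonal(X).mean()  # mean eigenvalue = trace/n = 1 - min_eig
X = X + min_eig * torch.eye(6)                      # every eigenvalue rises by min_eig
eigs = torch.linalg.eigvalsh(X)
print(eigs.min() >= min_eig - 1e-6, abs(eigs.mean().item() - 1.0) < 1e-6)  # True True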
@@ -882,6 +869,16 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         Args:
            X: Batch of block-diagonal nonzero positive definite matrices to have max eigenvalue applied
            M: Batch of positive definite block-diagonal matrices to use as metrics w.r.t. X.
+           M_inverse: If min_eig != 0.0, the inverse of M should be supplied by the user;
+              this is expected to have about the same value that M.inverse() would have,
+              but is supplied separately as the calling code inverts in a safer-than-normal
+              way based on knowledge of the minimum eigenvalue of M.
+           min_eig: The minimum eigenvalue of the returned Tensor, which will have average
+              eigenvalue equal to 1.0 for each block-diagonal matrix.
+           max_eig: The maximum allowed eigenvalue of the returned Tensor; this is expressed
+              as a multiple of the root-mean-square (rms) eigenvalue of X, and we
+              later normalize and apply min_eig, so the returned tensor may have
+              eigenvalues larger than this.
         """
         X = X.clone()
         # size of the block-diagonal matrix..
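The min_eig blend described above can be illustrated numerically: blending X with M^{-1} floors the eigenvalues of M^{1/2} X M^{1/2} at min_eig, which is what a minimum eigenvalue "with respect to the metric M" means here. This sketch omits the normalization the method performs and is only an assumption-level illustration.

import torch

torch.manual_seed(0)
n, min_eig = 8, 0.1
A, B = torch.randn(n, n), torch.randn(n, n)
X = A @ A.t()                          # positive semidefinite
M = B @ B.t() + 0.1 * torch.eye(n)     # positive definite metric
Y = (1.0 - min_eig) * X + min_eig * torch.linalg.inv(M)

# eigenvalues of M^{1/2} Y M^{1/2} = (1 - min_eig) * eig(M^{1/2} X M^{1/2}) + min_eig >= min_eig
L, Q = torch.linalg.eigh(M)
M_half = Q @ torch.diag(L.sqrt()) @ Q.t()
print(torch.linalg.eigvalsh(M_half @ Y @ M_half).min())  # >= min_eig up to roundoff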
@@ -894,7 +891,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
 
         def rms_eig(Y):
             # rms of eigenvalues, or spectral 2-norm.
-            return (_sum(Y**2, exclude_dims=[0], keepdim=True) / size).sqrt()
+            return (_sum(Y*Y.transpose(2,3), exclude_dims=[0], keepdim=True) / size).sqrt()
 
         # make sure eigs of M^{0.5} X M^{0.5} have rms value 1.  This imposes limit on the max.
         X /= rms_eig(torch.matmul(X, M))
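The change to rms_eig above matters because X @ M is generally not symmetric: for a product of two SPD matrices, sum(Y * Y.transpose(2,3)) equals the sum of squared eigenvalues (the trace of Y @ Y), whereas sum(Y**2) equals the sum of squared singular values, which is generally larger. A small numeric check, not part of the optimizer:

import torch

torch.manual_seed(0)
A, B = torch.randn(6, 6), torch.randn(6, 6)
X, M = A @ A.t(), B @ B.t()
Y = X @ M
eig_sq_sum = (torch.linalg.eigvals(Y).real ** 2).sum()   # eigenvalues of X @ M are real
print(eig_sq_sum.item(), (Y * Y.t()).sum().item(), (Y ** 2).sum().item())
# the first two agree up to roundoff; the third is generally larger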
@@ -931,12 +928,79 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         X = 0.5 * (X + X.transpose(-2, -1))  # make sure exactly symmetric.
 
         if min_eig != 0.0:
-            X /= mean_eig(X)
-            X = X * (1.0-min_eig) + min_eig * M.inverse()
+            mean_XM_eig = _sum(X * M.transpose(2, 3), exclude_dims=[0], keepdim=True) / size
+            X /= mean_XM_eig
+            X = X * (1.0-min_eig) + min_eig * M_inverse
             X = 0.5 * (X + X.transpose(-2, -1))  # make sure exactly symmetric.
 
         return X
 
+    def _safe_inverse(self,
+                      X: Tensor,
+                      min_eig: float,
+                      tolerance: float = 1.2):
+        """
+        Returns X.inverse(), but in a safe way that limits the possible amount that numerical
+        roundoff could make the result inaccurate.
+
+        Args:
+           X: batch of positive definite tensors to be inverted, of shape (*, size, size).
+              Mathematically, X must be symmetric and positive definite (SPD), but this
+              routine allows it to have negative eigenvalues and just assumes that
+              any eigenvalues <min_eig should actually equal min_eig.
+           min_eig: the minimum possible eigenvalue that X is known mathematically to have
+              (e.g. 0.1).
+           tolerance: A tolerance that should be >1.0; it controls the tradeoff between
+              redundant computation and accuracy.
+        """
+        try:
+            size = X.shape[-1]
+            I = torch.eye(size, dtype=X.dtype, device=X.device)
+            if True:
+                X_inv = X.inverse()
+            else:
+                C = X.cholesky()  # X = C C^T
+                I = I.reshape(*([1]*(C.ndim-2)), size, size).expand(*X.shape)
+
+                # note, X^{-1} == C^{-T} C^{-1}
+                # next line does: X_inv = C^{-1} X_inv (== C^{-1})
+                X_inv = torch.triangular_solve(I, C, upper=False, transpose=False)[0]
+                # next line does: X_inv = C^{-T} X_inv (== C^{-T} C^{-1})
+                X_inv = torch.triangular_solve(X_inv, C, upper=False, transpose=True)[0]
+
+            # Make sure X_inv is exactly symmetric (the input is assumed symmetric)
+            X_inv = 0.5 * (X_inv + X_inv.transpose(-2, -1))
+
+            # OK, X_inv is positive definite and we require that its largest eigenvalue
+            # should be < (1/min_eig) * tolerance.  That requires that
+            # I * tolerance - X_inv should be positive definite.
+            X_inv_check = ((tolerance / min_eig) * I - X_inv)
+            # now X_inv_check == I * tolerance - X_inv.
+            # We don't need the result of X_inv_check.cholesky() below, we just execute the
+            # statement as a way of checking that X_inv_check is positive definite.
+            # we also check that X_inv has no negative eigenvalues, because if any of the
+            # original X's eigenvalues had been negative, they would show up as negative
+            # eigenvalues in the answer.
+            torch.stack([X_inv, X_inv_check]).cholesky()
+            return X_inv
+        except RuntimeError as e:
+            # _svd(X) is an alternative to X.svd() that retries in response to
+            # failures caused by SVD implementation bugs.
+            U, S, V = _svd(X)
+            S_sign = (2*((U * V).sum(dim=-2) > 0) - 1)
+            # the method with S_with_sign below is a quick, but not 100% guaranteed, way
+            # to find the eigenvalues of a SPD matrix (might not work
+            # if there are tied eigenvalues); this does not matter as we are only
+            # using it to print diagnostics.
+            S_with_sign = S * S_sign
+            logging.warning(f"Caught error in _safe_inverse() (this is normal). Shape={X.shape}, minimum eigenvalue "
+                            f"is {S_with_sign.min().item()} vs. {min_eig} with tolerance {tolerance}; exception was: {e}")
+            S_inv = 1.0 / S.clamp(min=min_eig)
+            X_inv = torch.matmul(U * S_inv.unsqueeze(-2), U.transpose(-2, -1))
+            # Make sure X_inv is exactly symmetric (the input is assumed symmetric)
+            X_inv = 0.5 * (X_inv + X_inv.transpose(-2, -1))
+            return X_inv
+
 
     def _update_grad_cov(self,
                          group: dict,
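The core of _safe_inverse's fast path is that a Cholesky factorization succeeds only for a positive definite matrix, so attempting it on (tolerance / min_eig) * I - X_inv is a cheap check that the computed inverse's largest eigenvalue stayed below tolerance / min_eig (and, stacked with X_inv itself, that no eigenvalue went negative). A small standalone demonstration of that check, using torch.linalg calls rather than the older Tensor.cholesky / torch.triangular_solve used in the diff:

import torch

torch.manual_seed(0)
min_eig, tolerance = 0.1, 1.2
A = torch.randn(5, 5)
X = A @ A.t() + min_eig * torch.eye(5)       # eigenvalues known to be >= min_eig
X_inv = torch.linalg.inv(X)
check = (tolerance / min_eig) * torch.eye(5) - X_inv
torch.linalg.cholesky(check)                 # raises if the eigenvalue bound was violated
torch.linalg.cholesky(X_inv)                 # raises if X_inv has a non-positive eigenvalue
print("inverse accepted")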
@@ -1949,14 +2013,38 @@ def _test_svd():
     X2 = torch.matmul(U*S, V.t())
     assert torch.allclose(X2, X, atol=0.001)
 
+def _test_safe_inverse():
+    for dim in [10, 20, 100]:
+        a = torch.randn(2, 3, dim, dim//2)
+        a = torch.matmul(a, a.transpose(-2, -1))
+        min_eig = 0.001
+        _diag(a).add_(min_eig)
+        for tolerance in [0.9, 0.99, 1.0, 1.01, 1.1]:
+            self = None
+            inv = PrAdam._safe_inverse(self, a, min_eig, tolerance)
+            prod = (torch.matmul(inv, a))
+            err = prod - torch.eye(dim)
+            err = torch.sum((err**2), dim=(2,3)).sqrt()
+            logging.info(f"_test_safe_inverse(): dim={dim}, tolerance={tolerance}, err={err}")
 
-def _test_smooth_cov():
-    b = (torch.arange(10)*0.5).exp().diag()
-    b = b.unsqueeze(0).unsqueeze(0)
-    c = PrAdam._smooth_cov(None, b, 0.0, 10.0)
-    logging.info(f"c[noinv] = {_diag(c)}")
-    c = PrAdam._smooth_cov(None, b, 0.0, 10.0, 1.00001)
-    logging.info(f"c[svd] = {_diag(c)}")
+def _test_apply_min_max():
+    dim = 50
+    X = torch.randn(1, 1, dim, dim*2)
+    X = torch.matmul(X, X.transpose(2, 3))
+    M = torch.randn(1, 1, dim, dim*2)
+    M = torch.matmul(M, M.transpose(2, 3))
+
+    for min_eig in [0.0, 0.1]:
+        for max_eig in [5.0, 10.0]:
+            for Xscale in [1.0, 2.0]:
+                for Mscale in [1.0, 2.0]:
+                    self = None
+                    thisM = M * Mscale
+                    thisMinv = thisM.inverse()
+                    Y = PrAdam._apply_min_max_with_metric(self, X*Xscale, thisM, thisMinv,
+                                                          min_eig, max_eig)
+                    Y /= _mean(_diag(Y), exclude_dims=[0], keepdim=True)
+                    logging.info(f"min_eig={min_eig},max_eig={max_eig},Xscale={Xscale},Mscale={Mscale}, rms of Y = {(Y**2).mean().item()}")
 
 if __name__ == "__main__":
     torch.set_num_threads(1)
@@ -1964,8 +2052,9 @@ if __name__ == "__main__":
     logging.getLogger().setLevel(logging.INFO)
     import subprocess
     s = subprocess.check_output("git status -uno .; git log -1", shell=True)
-    _test_smooth_cov()
     logging.info(s)
+    _test_safe_inverse()
+    _test_apply_min_max()
     #_test_svd()
     import sys
     if len(sys.argv) > 1: