Mirror of https://github.com/k2-fsa/icefall.git

Merge branch 'pradam_exp1l2' into pradam_exp1m2

Commit ca28f46f75
@@ -894,6 +894,11 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         #
         # G_prime is the diagonalized grad_cov, G' above, of shape (batch_size, num_blocks, block_size).
         U_g, G_prime, _ = _svd(grad_cov)
+        # Recompute G_prime via matrix multiplication, as svd does not seem to produce zeros on
+        # the singular values in places where it should be exactly zero. This is to keep
+        # dimensions with zero grad separate from those with nonzero grad.
+        G_prime = _diag(torch.matmul(U_g.transpose(2, 3), torch.matmul(grad_cov, U_g)))
+
         # P_prime is P' above, which represents param_cov in the basis that diagonalizes G_prime.
         # It is not smoothed yet.
         P_prime = torch.matmul(U_g.transpose(2, 3), torch.matmul(param_cov, U_g))
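A minimal sketch (not part of the diff) of what the recomputation buys, for a single non-batched matrix and using torch.diag in place of the commit's batched _diag helper: when grad_cov has an exactly-zero row and column, the matmul-based diagonal can recover exact zeros where the singular values from svd may come back as tiny nonzero numbers on some backends (the commit's comment suggests this happens with CUDA SVD).

```python
import torch

torch.manual_seed(0)
g = torch.randn(4, 2, dtype=torch.float64)
grad_cov = g @ g.t()
grad_cov[3, :] = 0.0  # a dimension with exactly zero grad
grad_cov[:, 3] = 0.0

U_g, S, _ = grad_cov.svd()
# Same recomputation as in the hunk above, non-batched:
G_prime = torch.diag(U_g.t() @ grad_cov @ U_g)
print(S)        # the zero direction may show up as a tiny nonzero singular value
print(G_prime)  # the matmul-based diagonal can keep it exactly zero
```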
@@ -913,6 +918,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         # CGC = (C^T G' C), it would be torch.matmul(C.transpose(2, 3), torch.matmul(G, C))
         # if G were formatted as a diagonal matrix, but G is just the diagonal.
         CGC = torch.matmul(C.transpose(2, 3) * G_prime.unsqueeze(-2), C)
+        # make sure it's exactly symmetric, want to make sure SVD is exact w.r.t.
+        # dimensions with zero grad and zero parameters.
+        CGC = 0.5 * (CGC + CGC.transpose(2, 3))
 
         U, S, _ = _svd(CGC)  # Need SVD to compute CGC^{0.5}
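The CGC line relies on a broadcasting identity: multiplying the columns of C^T elementwise by the diagonal vector G' equals a full matmul with the diagonal matrix built from G'. A small check of that identity (not from the commit; B, K, N are hypothetical batch/block sizes):

```python
import torch

B, K, N = 2, 3, 5  # hypothetical (batch_size, num_blocks, block_size)
C = torch.randn(B, K, N, N, dtype=torch.float64)
G_prime = torch.rand(B, K, N, dtype=torch.float64)

# Cheap form used in the diff: scale the columns of C^T by G'.
CGC_fast = torch.matmul(C.transpose(2, 3) * G_prime.unsqueeze(-2), C)
# Explicit form: materialize diag(G') and do two matmuls.
CGC_full = torch.matmul(C.transpose(2, 3),
                        torch.matmul(torch.diag_embed(G_prime), C))
print(torch.allclose(CGC_fast, CGC_full))  # True
```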
@@ -921,7 +929,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         # we have Z'^{-1} = X S^{0.5} X^T where X = C^{-T} U
         X = torch.triangular_solve(U, C, upper=False, transpose=True)[0]
         Z_prime_inv = torch.matmul(X * S.sqrt().unsqueeze(-2), X.transpose(2, 3))
+        # make sure it's exactly symmetric, want to make sure SVD is exact w.r.t.
+        # dimensions with zero grad and zero parameters.
+        Z_prime_inv = 0.5 * (Z_prime_inv + Z_prime_inv.transpose(2, 3))
 
         if True:
             def _check_similar(x, y, name):
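A minimal sketch (not from the commit; single matrix, hypothetical size N) of the X = C^{-T} U step: with C lower-triangular from a Cholesky factorization, triangular_solve with transpose=True solves C^T X = U, giving X = C^{-T} U without forming an explicit inverse. torch.triangular_solve is deprecated in newer torch; torch.linalg.solve_triangular(C.t(), U, upper=True) is the equivalent call.

```python
import torch

torch.manual_seed(0)
N = 5
A = torch.randn(N, N, dtype=torch.float64)
P = A @ A.t() + N * torch.eye(N, dtype=torch.float64)  # well-conditioned PSD
C = torch.linalg.cholesky(P)                           # lower triangular, P = C C^T
U, _ = torch.linalg.qr(torch.randn(N, N, dtype=torch.float64))  # stand-in orthogonal U

# Solve C^T X = U for X, i.e. X = C^{-T} U, as in the diff.
X = torch.triangular_solve(U, C, upper=False, transpose=True)[0]
print(torch.allclose(C.t() @ X, U))  # True
```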
@@ -1517,31 +1527,42 @@ def _mean(x: Tensor,
 
 
 
-def _svd(x: Tensor):
+def _svd(x: Tensor, recursion_depth: int = 0):
     # Wrapper for torch svd that catches errors and re-tries (to address bugs in
     # some versions of torch SVD for CUDA)
-    randU = None
-    for i in range(4):
-        try:
-            U, S, V = x.svd()
-            s = U.sum()
-            if s - s != 0 or (__name__ == '__main__' and random.random() < 0.01):
-                raise RuntimeError(f"svd failed (or random test), sum={s}")
-            if randU is not None:
-                U = torch.matmul(randU.t(), U)
-            return U, S, V  # success
-        except:
-            logging.warning(f"svd failed: i={i}, x.shape={tuple(x.shape)}, x.sum()={x.sum()}: likely "
-                            f"error in SVD. Will try again after random change.")
-            this_x = x
-            while this_x.ndim > 2:
-                this_x = this_x[0]  # get rid of any batch dims
-            U, S, V = torch.randn_like(this_x).svd()
-            x = torch.matmul(U, x)
-            if randU is None:
-                randU = U
-            else:
-                randU = torch.matmul(U, randU)
+    xsum = x.sum()
+    if not (xsum - xsum == 0):
+        raise RuntimeError("svd on inf or nan input")
+    try:
+        U, S, V = x.svd()
+        s = U.sum()
+        if s - s != 0 or (__name__ == '__main__' and random.random() < 0.01):
+            raise RuntimeError(f"svd failed (or random test), sum={s}")
+        return U, S, V  # success
+    except:
+        if recursion_depth < 2:
+            logging.warning(f"svd failed: recursion_depth={recursion_depth}, "
+                            f"x.shape={tuple(x.shape)}, x.sum()={x.sum()}: likely "
+                            f"error in SVD. Will try again after random permutation of rows")
+            # randomly permute the row indexes of the matrices, and retry, hoping this
+            # fixes the issue
+            n = x.shape[-2]
+            randperm = torch.randperm(n, device=x.device)
+            inv_randperm = torch.zeros(n, device=x.device, dtype=torch.int64)
+            inv_randperm[randperm] = torch.arange(n, device=x.device)
+            x = torch.index_select(x, dim=-2, index=randperm)
+            U, S, V = _svd(x, recursion_depth + 1)
+            return torch.index_select(U, dim=-2, index=inv_randperm), S, V
+        elif recursion_depth < 4:
+            logging.warning(f"svd failed after {recursion_depth} tries: x.shape={tuple(x.shape)}, x.sum()={x.sum()}. Will try orthogonal transformation")
+            Urand, _, _ = torch.randn(x.shape[-2], x.shape[-1], device=x.device,
+                                      dtype=x.dtype).svd()
+            U, S, V = _svd(torch.matmul(Urand, x),
+                           recursion_depth + 1)
+            return torch.matmul(Urand.t(), U), S, V
+        else:
+            raise RuntimeError(f"svd failed after {recursion_depth} tries")
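Two identities make these retries safe, sketched below (not part of the commit; single non-batched matrix, hypothetical sizes): a row permutation of x can be undone on U afterwards, and a left-multiplication by an orthogonal matrix can be undone via its transpose. Incidentally, the `s - s != 0` test above is an inf/NaN check: s - s is NaN (which compares unequal to 0) exactly when s is inf or NaN, and 0 otherwise.

```python
import torch

torch.manual_seed(0)
x = torch.randn(6, 4, dtype=torch.float64)

# Retry 1: permute rows, take the SVD, then un-permute the rows of U.
randperm = torch.randperm(6)
inv_randperm = torch.zeros(6, dtype=torch.int64)
inv_randperm[randperm] = torch.arange(6)
U, S, V = torch.index_select(x, dim=-2, index=randperm).svd()
U = torch.index_select(U, dim=-2, index=inv_randperm)
print(torch.allclose(U @ torch.diag(S) @ V.t(), x))  # True: still an SVD of x

# Retry 2: left-multiply by a random orthogonal Urand, undo with Urand.t().
Urand, _, _ = torch.randn(6, 6, dtype=torch.float64).svd()
U2, S2, V2 = torch.matmul(Urand, x).svd()
print(torch.allclose(Urand.t() @ U2 @ torch.diag(S2) @ V2.t(), x))  # True
```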
@@ -2182,7 +2203,7 @@ def _test_eve_cain():
         fix_random_seed(42)
         Linear = torch.nn.Linear if iter == 0 else ScaledLinear
 
-        hidden_dim = 768
+        hidden_dim = 400
         m = torch.nn.Sequential(Linear(E, hidden_dim),
                                 torch.nn.PReLU(),
                                 Linear(hidden_dim, hidden_dim),