Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-19 05:54:20 +00:00)

commit ca28f46f75
Merge branch 'pradam_exp1l2' into pradam_exp1m2
```diff
@@ -894,6 +894,11 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         #
         # G_prime is the diagonalized grad_cov, G' above, of shape (batch_size, num_blocks, block_size).
         U_g, G_prime, _ = _svd(grad_cov)
+        # Recompute G_prime via matrix multiplication, as svd does not seem to produce zeros on
+        # the singular values in places where it should be exactly zero. This is to keep
+        # dimensions with zero grad separate from those with nonzero grad.
+        G_prime = _diag(torch.matmul(U_g.transpose(2,3), torch.matmul(grad_cov, U_g)))
+
         # P_prime is P' above, which represents param_cov in the basis that diagonalizes G_prime.
         # It is not smoothed yet.
         P_prime = torch.matmul(U_g.transpose(2, 3), torch.matmul(param_cov, U_g))
```
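The added recomputation is needed because `x.svd()` returns singular values that are merely tiny, not exactly zero, on rank-deficient input; forming `diag(U_g^T grad_cov U_g)` instead multiplies through the exact zeros in grad_cov's zero-grad rows and columns. A minimal unbatched sketch of the effect, with `torch.diagonal` standing in for the file's `_diag` helper (assumed to take the diagonal over the last two dimensions):

```python
import torch

torch.manual_seed(0)

# Toy stand-in for grad_cov: PSD, with dimension 0 having zero gradient,
# i.e. an exactly-zero row and column.
A = torch.randn(6, 6)
grad_cov = A @ A.t()
grad_cov[0, :] = 0.0
grad_cov[:, 0] = 0.0

U_g, S, _ = grad_cov.svd()
print(S.min())  # the "zero" singular value from svd is often tiny but nonzero

# Recompute the diagonal as in the patch (torch.diagonal standing in for
# _diag); multiplying through grad_cov's exact zeros tends to give an
# exact zero for the zero-grad dimension.
G_prime = torch.diagonal(U_g.t() @ grad_cov @ U_g)
print(G_prime.min())
```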
```diff
@@ -913,6 +918,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         # CGC = (C^T G' C), it would be torch.matmul(C.transpose(2, 3), torch.matmul(G, C))
         # if G were formatted as a diagonal matrix, but G is just the diagonal.
         CGC = torch.matmul(C.transpose(2, 3) * G_prime.unsqueeze(-2), C)
+        # make sure it's exactly symmetric, want to make sure SVD is exact w.r.t.
+        # dimensions with zero grad and zero parameters.
+        CGC = 0.5 * (CGC + CGC.transpose(2, 3))

         U, S, _ = _svd(CGC)  # Need SVD to compute CGC^{0.5}

```
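The CGC line avoids materializing diag(G'): with G' kept as a vector, broadcasting `G_prime.unsqueeze(-2)` against C^T scales the columns of C^T, which equals C^T diag(G') C. The new `0.5 * (CGC + CGC.transpose(2, 3))` line then cancels the small asymmetry floating-point matmul leaves, so the subsequent SVD sees an exactly symmetric input. A toy check of the broadcasting identity, with `torch.diag_embed` building the full diagonal matrix only for comparison:

```python
import torch

torch.manual_seed(0)
# Toy shapes standing in for (batch_size, num_blocks, block_size, block_size).
C = torch.randn(2, 3, 4, 4).tril()
G_prime = torch.rand(2, 3, 4)   # diagonal of G', one vector per block

# The patch's form: scale the columns of C^T by G' via broadcasting ...
CGC_fast = torch.matmul(C.transpose(2, 3) * G_prime.unsqueeze(-2), C)

# ... which matches multiplying by the materialized diagonal matrix.
CGC_full = torch.matmul(C.transpose(2, 3),
                        torch.matmul(torch.diag_embed(G_prime), C))
print(torch.allclose(CGC_fast, CGC_full, atol=1e-6))  # True
```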
```diff
@@ -921,7 +929,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         # we have Z'^{-1} = X S^{0.5} X^T where X = C^{-T} U
         X = torch.triangular_solve(U, C, upper=False, transpose=True)[0]
         Z_prime_inv = torch.matmul(X * S.sqrt().unsqueeze(-2), X.transpose(2, 3))
+        # make sure it's exactly symmetric, want to make sure SVD is exact w.r.t.
+        # dimensions with zero grad and zero parameters.
+        Z_prime_inv = 0.5 * (Z_prime_inv + Z_prime_inv.transpose(2, 3))

         if True:
             def _check_similar(x, y, name):
```
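As the comment states, Z'^{-1} = X S^{0.5} X^T with X = C^{-T} U, where X comes from a triangular solve against the factor C rather than an explicit inverse. A minimal unbatched check of that identity; `torch.linalg.solve_triangular` is used here as the non-deprecated equivalent of the patch's `torch.triangular_solve(U, C, upper=False, transpose=True)`, and the explicit inverse appears only to build the reference value:

```python
import torch

torch.manual_seed(0)
n = 5
C = torch.randn(n, n).tril() + 2.0 * torch.eye(n)  # well-conditioned lower-triangular factor
U, _ = torch.linalg.qr(torch.randn(n, n))          # toy orthogonal U
S = torch.rand(n)                                  # toy singular values

# X = C^{-T} U: solve C^T X = U against the triangular factor.
X = torch.linalg.solve_triangular(C.mT, U, upper=True)

# Z'^{-1} = X S^{0.5} X^T; broadcasting S.sqrt() scales the columns of X.
Z_prime_inv = (X * S.sqrt()) @ X.mT

# Reference the long way: C^{-T} U diag(S^{0.5}) U^T C^{-1}.
Cinv = torch.linalg.inv(C)
ref = Cinv.mT @ U @ torch.diag(S.sqrt()) @ U.mT @ Cinv
print(torch.allclose(Z_prime_inv, ref, atol=1e-5))  # True
```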
```diff
@@ -1517,31 +1527,42 @@ def _mean(x: Tensor,



-def _svd(x: Tensor):
+def _svd(x: Tensor, recursion_depth: int = 0):
     # Wrapper for torch svd that catches errors and re-tries (to address bugs in
     # some versions of torch SVD for CUDA)
-    randU = None
-    for i in range(4):
-        try:
-            U, S, V = x.svd()
-            s = U.sum()
-            if s - s != 0 or (__name__ == '__main__' and random.random() < 0.01):
-                raise RuntimeError(f"svd failed (or random test), sum={s}")
-            if randU is not None:
-                U = torch.matmul(randU.t(), U)
-            return U, S, V  # success
-        except:
-            logging.warning(f"svd failed: i={i}, x.shape={tuple(x.shape)}, x.sum()={x.sum()}: likely "
-                            f"error in SVD. Will try again after random change.")
-            this_x = x
-            while this_x.ndim > 2:
-                this_x = this_x[0]  # get rid of any batch dims
-            U, S, V = torch.randn_like(this_x).svd()
-            x = torch.matmul(U, x)
-            if randU is None:
-                randU = U
-            else:
-                randU = torch.matmul(U, randU)
+    xsum = x.sum()
+    if not (xsum - xsum == 0):
+        raise RuntimeError("svd on inf or nan input")
+    try:
+        U, S, V = x.svd()
+        s = U.sum()
+        if s - s != 0 or (__name__ == '__main__' and random.random() < 0.01):
+            raise RuntimeError(f"svd failed (or random test), sum={s}")
+        return U, S, V  # success
+    except:
+        if recursion_depth < 2:
+            logging.warning(f"svd failed: recursion_depth={recursion_depth}, "
+                            f"x.shape={tuple(x.shape)}, x.sum()={x.sum()}: likely "
+                            f"error in SVD. Will try again after random permutation of rows")
+            # randomly permute the row indexes of the matrices, and retry, hoping this
+            # fixes the issue
+            n = x.shape[-2]
+            randperm = torch.randperm(n, device=x.device)
+            inv_randperm = torch.zeros(n, device=x.device, dtype=torch.int64)
+            inv_randperm[randperm] = torch.arange(n, device=x.device)
+            x = torch.index_select(x, dim=-2, index=randperm)
+            U, S, V = _svd(x, recursion_depth + 1)
+            return torch.index_select(U, dim=-2, index=inv_randperm), S, V
+        elif recursion_depth < 4:
+            logging.warning(f"svd failed after {recursion_depth} tries: x.shape={tuple(x.shape)}, x.sum()={x.sum()}. Will try orthogonal transformation")
+            Urand, _, _ = torch.randn(x.shape[-2], x.shape[-1], device=x.device,
+                                      dtype=x.dtype).svd()
+            U, S, V = _svd(torch.matmul(Urand, x),
+                           recursion_depth + 1)
+            return torch.matmul(Urand.t(), U), S, V
+        else:
+            raise RuntimeError(f"svd failed after {recursion_depth} tries")



```
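The rewritten wrapper first rejects inf/nan input (`xsum - xsum` fails to equal zero exactly when the sum is inf or nan), then retries a failed decomposition on a row-permuted copy, and at deeper recursion on an orthogonally transformed copy. Both fallbacks are exact because the original factorization is recoverable: if P x = U S V^T then x = (P^T U) S V^T, and likewise for an orthogonal Q. A small sketch verifying the two recovery identities the patch relies on:

```python
import torch

torch.manual_seed(0)
x = torch.randn(6, 4)
n = x.shape[-2]

# Row-permutation retry: if (P x) = U S V^T, then x = (P^T U) S V^T.
# P^T U is recovered via index_select with the inverse permutation,
# exactly as in the patched _svd.
randperm = torch.randperm(n)
inv_randperm = torch.zeros(n, dtype=torch.int64)
inv_randperm[randperm] = torch.arange(n)

U, S, V = torch.index_select(x, dim=-2, index=randperm).svd()
U = torch.index_select(U, dim=-2, index=inv_randperm)
print(torch.allclose(U @ torch.diag(S) @ V.t(), x, atol=1e-5))  # True

# Orthogonal-transform fallback: if (Q x) = U S V^T, then x = (Q^T U) S V^T.
Q, _, _ = torch.randn(n, n).svd()
U, S, V = (Q @ x).svd()
print(torch.allclose(Q.t() @ U @ torch.diag(S) @ V.t(), x, atol=1e-5))  # True
```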
```diff
@@ -2182,7 +2203,7 @@ def _test_eve_cain():
         fix_random_seed(42)
         Linear = torch.nn.Linear if iter == 0 else ScaledLinear

-        hidden_dim = 768
+        hidden_dim = 400
         m = torch.nn.Sequential(Linear(E, hidden_dim),
                                 torch.nn.PReLU(),
                                 Linear(hidden_dim, hidden_dim),
```