Merge branch 'pradam_exp1l2' into pradam_exp1m2

Daniel Povey 2022-07-29 15:16:10 +08:00
commit ca28f46f75

@@ -894,6 +894,11 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         #
         # G_prime is the diagonalized grad_cov, G' above, of shape (batch_size, num_blocks, block_size).
         U_g, G_prime, _ = _svd(grad_cov)
+        # Recompute G_prime via matrix multiplication, as svd does not seem to produce zeros on
+        # the singular values in places where it should be exactly zero.  This is to keep
+        # dimensions with zero grad separate from those with nonzero grad.
+        G_prime = _diag(torch.matmul(U_g.transpose(2, 3), torch.matmul(grad_cov, U_g)))
         # P_prime is P' above, which represents param_cov in the basis that diagonalizes G_prime.
         # It is not smoothed yet.
         P_prime = torch.matmul(U_g.transpose(2, 3), torch.matmul(param_cov, U_g))
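
Note: as a minimal standalone sketch of the recompute added above (assuming a single (block_size, block_size) PSD matrix rather than the batched layout, with torch.diagonal standing in for the file's _diag helper; variable names here are illustrative): svd may report tiny nonzero singular values even when the input has exactly-zero rows and columns, whereas the diagonal of U_g^T grad_cov U_g is recomputed directly from the exact zeros in grad_cov.

import torch

block_size = 4
a = torch.randn(block_size, 2)
grad_cov = torch.matmul(a, a.t())
grad_cov[2:, :] = 0.0   # dims 2..3 stand in for dimensions with zero grad
grad_cov[:, 2:] = 0.0

U_g, S_svd, _ = grad_cov.svd()
print(S_svd)      # trailing values may be tiny but nonzero
# The recompute used in the commit, for a single matrix:
G_prime = torch.diagonal(torch.matmul(U_g.t(), torch.matmul(grad_cov, U_g)))
print(G_prime)    # recomputed from the exactly-zero rows/cols of grad_cov
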
@@ -913,6 +918,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         # CGC = (C^T G' C), it would be torch.matmul(C.transpose(2, 3), torch.matmul(G, C))
         # if G were formatted as a diagonal matrix, but G is just the diagonal.
         CGC = torch.matmul(C.transpose(2, 3) * G_prime.unsqueeze(-2), C)
+        # make sure it's exactly symmetric, want to make sure SVD is exact w.r.t.
+        # dimensions with zero grad and zero parameters.
+        CGC = 0.5 * (CGC + CGC.transpose(2, 3))
         U, S, _ = _svd(CGC)  # Need SVD to compute CGC^{0.5}
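
Note: the symmetrization line added here is a standard numerical guard. C^T G' C is symmetric in exact arithmetic, but the floating-point matmuls can leave a small skew, and averaging with the transpose removes it exactly (each entry and its mirror become the same two floats added in either order). A minimal sketch, with illustrative names and a single unbatched matrix:

import torch

c = torch.randn(8, 8)
g = torch.rand(8)
m = torch.matmul(c.t() * g, c)           # C^T diag(g) C, symmetric in theory
print((m - m.t()).abs().max())           # typically small but nonzero in float32
m_sym = 0.5 * (m + m.t())                # exactly symmetric by construction
print((m_sym - m_sym.t()).abs().max())   # exactly 0
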
@@ -921,7 +929,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         # we have Z'^{-1} = X S^{0.5} X^T where X = C^{-T} U
         X = torch.triangular_solve(U, C, upper=False, transpose=True)[0]
         Z_prime_inv = torch.matmul(X * S.sqrt().unsqueeze(-2), X.transpose(2, 3))
+        # make sure it's exactly symmetric, want to make sure SVD is exact w.r.t.
+        # dimensions with zero grad and zero parameters.
+        Z_prime_inv = 0.5 * (Z_prime_inv + Z_prime_inv.transpose(2, 3))

         if True:
             def _check_similar(x, y, name):
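
Note: the unchanged X = ... context line above computes X = C^{-T} U by solving a triangular system rather than forming an explicit inverse: with C lower triangular, torch.triangular_solve(U, C, upper=False, transpose=True) returns X satisfying C^T X = U. A quick standalone check (the SPD matrix and names below are illustrative, not from the commit):

import torch

a = torch.randn(6, 6)
spd = torch.matmul(a, a.t()) + 6 * torch.eye(6)  # safely positive definite
C = torch.linalg.cholesky(spd)                   # lower triangular
U = torch.randn(6, 6)
X = torch.triangular_solve(U, C, upper=False, transpose=True)[0]
print(torch.allclose(torch.matmul(C.t(), X), U, atol=1e-5))  # True
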
@@ -1517,31 +1527,42 @@ def _mean(x: Tensor,

-def _svd(x: Tensor):
+def _svd(x: Tensor, recursion_depth: int = 0):
     # Wrapper for torch svd that catches errors and re-tries (to address bugs in
     # some versions of torch SVD for CUDA)
-    randU = None
-    for i in range(4):
-        try:
-            U, S, V = x.svd()
-            s = U.sum()
-            if s - s != 0 or (__name__ == '__main__' and random.random() < 0.01):
-                raise RuntimeError(f"svd failed (or random test), sum={s}")
-            if randU is not None:
-                U = torch.matmul(randU.t(), U)
-            return U, S, V  # success
-        except:
-            logging.warning(f"svd failed: i={i}, x.shape={tuple(x.shape)}, x.sum()={x.sum()}: likely "
-                            f"error in SVD. Will try again after random change.")
-            this_x = x
-            while this_x.ndim > 2:
-                this_x = this_x[0]  # get rid of any batch dims
-            U, S, V = torch.randn_like(this_x).svd()
-            x = torch.matmul(U, x)
-            if randU is None:
-                randU = U
-            else:
-                randU = torch.matmul(U, randU)
+    xsum = x.sum()
+    if not (xsum - xsum == 0):
+        raise RuntimeError("svd on inf or nan input")
+    try:
+        U, S, V = x.svd()
+        s = U.sum()
+        if s - s != 0 or (__name__ == '__main__' and random.random() < 0.01):
+            raise RuntimeError(f"svd failed (or random test), sum={s}")
+        return U, S, V  # success
+    except:
+        if recursion_depth < 2:
+            logging.warning(f"svd failed: recursion_depth={recursion_depth}, "
+                            f"x.shape={tuple(x.shape)}, x.sum()={x.sum()}: likely "
+                            f"error in SVD. Will try again after random permutation of rows")
+            # randomly permute the row indexes of the matrices, and retry, hoping this
+            # fixes the issue
+            n = x.shape[-2]
+            randperm = torch.randperm(n, device=x.device)
+            inv_randperm = torch.zeros(n, device=x.device, dtype=torch.int64)
+            inv_randperm[randperm] = torch.arange(n, device=x.device)
+            x = torch.index_select(x, dim=-2, index=randperm)
+            U, S, V = _svd(x, recursion_depth + 1)
+            return torch.index_select(U, dim=-2, index=inv_randperm), S, V
+        elif recursion_depth < 4:
+            logging.warning(f"svd failed after {recursion_depth} tries: x.shape={tuple(x.shape)}, "
+                            f"x.sum()={x.sum()}. Will try orthogonal transformation")
+            Urand, _, _ = torch.randn(x.shape[-2], x.shape[-1], device=x.device,
+                                      dtype=x.dtype).svd()
+            U, S, V = _svd(torch.matmul(Urand, x),
+                           recursion_depth + 1)
+            return torch.matmul(Urand.t(), U), S, V
+        else:
+            raise RuntimeError(f"svd failed after {recursion_depth} tries")
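
Note: the un-permute step in the retry path relies on a simple identity: if P permutes the rows of x and P x = U S V^T, then x = (P^T U) S V^T, so selecting U's rows with the inverse permutation yields a valid SVD of the original input. A quick round-trip check (not from the commit):

import torch

x = torch.randn(5, 3)
n = x.shape[-2]
randperm = torch.randperm(n)
inv_randperm = torch.zeros(n, dtype=torch.int64)
inv_randperm[randperm] = torch.arange(n)

U, S, V = torch.index_select(x, dim=-2, index=randperm).svd()
U = torch.index_select(U, dim=-2, index=inv_randperm)  # undo the row permutation
print(torch.allclose(torch.matmul(U * S, V.t()), x, atol=1e-5))  # True
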
@@ -2182,7 +2203,7 @@ def _test_eve_cain():
         fix_random_seed(42)
         Linear = torch.nn.Linear if iter == 0 else ScaledLinear
-        hidden_dim = 768
+        hidden_dim = 400
         m = torch.nn.Sequential(Linear(E, hidden_dim),
                                 torch.nn.PReLU(),
                                 Linear(hidden_dim, hidden_dim),