mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-18 21:44:18 +00:00
Merge branch 'pradam_exp1m4_nophase1_noinv' into pradam_exp1m4_nophase1_rework_noinv
This commit is contained in:
commit
2042c9862c
@ -856,13 +856,34 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
||||
mean_eig = _mean(diag, exclude_dims=[0], keepdim=True)
|
||||
diag += (mean_eig * min_eig + eps)
|
||||
cur_diag_mean = mean_eig * (1 + min_eig) + eps
|
||||
# The following 2 statements will be equivalent to:
|
||||
# L /= L.mean()
|
||||
# L = 1 / (1/L + 1/max_eig) # soft-min between L and max_eig
|
||||
# if L is the eigenvalues
|
||||
X_inv = X.inverse()
|
||||
_diag(X_inv).add_(1. / (max_eig * cur_diag_mean))
|
||||
X = X_inv.inverse()
|
||||
X /= cur_diag_mean.unsqueeze(-1)
|
||||
# OK, now the mean of the diagonal of X is 1 (or less than 1, in
|
||||
# which case X is extremely tiny).
|
||||
|
||||
# eig_ceil is the maximum possible eigenvalue that X could possibly
|
||||
# have at this time.
|
||||
eig_ceil = X.shape[-1]
|
||||
|
||||
while max_eig <= eig_ceil * 0.5:
|
||||
# this is equivalent to the operation: l -> l - 0.5/eig_ceil*l*l
|
||||
# on eigenvalues, which maps eig_ceil to 0.5*eig_ceil and is monotonically
|
||||
# increasing from 0..eig_ceil.
|
||||
logging.info(f"X={X}, eig_ceil={eig_ceil}")
|
||||
X = X - 0.5/eig_ceil * torch.matmul(X, X)
|
||||
eig_ceil = 0.5 * eig_ceil
|
||||
|
||||
# max_eig > eig_ceil * 0.5
|
||||
if max_eig < eig_ceil:
|
||||
# map l -> l - coeff*l*l, if l==eig_ceil, this
|
||||
# takes us to:
|
||||
# eig_ceil - ((eig_ceil-max_eig)/(eig_ceil*eig_ceil))*eig_ceil*eig_ceil
|
||||
# == max_eig
|
||||
# .. and the fact that coeff <= 0.5/eig_ceil [since max_eig>eig_ceil*0.5]
|
||||
# means that the function is monotonic on inputs from 0 to eig_ceil.
|
||||
coeff = (eig_ceil - max_eig) / (eig_ceil*eig_ceil)
|
||||
X = X - coeff * torch.matmul(X, X)
|
||||
|
||||
# Normalize again.
|
||||
X /= _mean(_diag(X), exclude_dims=[0], keepdim=True).unsqueeze(-1)
|
||||
X = 0.5 * (X + X.transpose(-2, -1)) # make sure exactly symmetric.
|
||||
return X
|
||||
@ -1120,7 +1141,7 @@ def _mean(x: Tensor,
|
||||
# the interface changes in future.
|
||||
return x.mean(dim=tuple(range(x.ndim)), keepdim=keepdim)
|
||||
elif x.ndim == 1:
|
||||
assert exclude_dim == [0] or exclude_dim == [-1]
|
||||
assert exclude_dims == [0] or exclude_dims == [-1]
|
||||
return x # if one dim is excluded, there are no dims to mean, and
|
||||
# x.mean(dim=[]) means all dims so we should not call mean().
|
||||
exclude_dims_norm = [i if i >= 0 else i + x.ndim for i in exclude_dims]
|
||||
@ -1806,7 +1827,7 @@ def _test_eve_cain():
|
||||
fix_random_seed(42)
|
||||
Linear = torch.nn.Linear if iter == 0 else ScaledLinear
|
||||
|
||||
hidden_dim = 200
|
||||
hidden_dim = 300
|
||||
m = torch.nn.Sequential(Linear(E, hidden_dim),
|
||||
torch.nn.PReLU(),
|
||||
Linear(hidden_dim, hidden_dim),
|
||||
@ -1823,16 +1844,6 @@ def _test_eve_cain():
|
||||
elif iter == 3: optim = PrAdam(m.parameters(), lr=0.03, max_block_size=256)
|
||||
scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)
|
||||
|
||||
#TEMP
|
||||
if iter == 3:
|
||||
a = torch.randn(5, 10, 5, 5)
|
||||
a = torch.matmul(a, a.transpose(2, 3))
|
||||
b1 = optim._smooth_cov(a, 0.1, 4.0, 0.999999999)
|
||||
b2 = optim._smooth_cov(a, 0.1, 4.0, 1.0)
|
||||
diff = (b1 - b2)
|
||||
ratio = (diff**2).sqrt().mean() / (b1**2).sqrt().mean()
|
||||
logging.info(f"ratio = {ratio}")
|
||||
assert ratio < 0.01
|
||||
|
||||
start = timeit.default_timer()
|
||||
avg_loss = 0.0
|
||||
@ -1891,12 +1902,22 @@ def _test_svd():
|
||||
X2 = torch.matmul(U*S, V.t())
|
||||
assert torch.allclose(X2, X, atol=0.001)
|
||||
|
||||
|
||||
def _test_smooth_cov():
    """Log the diagonals returned by `PrAdam._smooth_cov` for a known input.

    The input is a diagonal matrix whose entries are exp(0.5 * i) for
    i = 0..9, with two leading singleton dimensions added.  `_smooth_cov`
    is called twice with the same 0.0 / 10.0 arguments: once without a
    fifth positional argument and once with 1.00001 passed as the fifth.
    Nothing is asserted; the two resulting diagonals are only logged so
    that they can be compared by eye.
    """
    # Geometrically growing values exp(0), exp(0.5), ..., exp(4.5).
    spectrum = torch.exp(0.5 * torch.arange(10))
    # Diagonal matrix with two leading singleton dims: shape (1, 1, 10, 10).
    b = torch.diag(spectrum)[None, None]

    # `_smooth_cov` is invoked through the class, with None standing in
    # for `self`.
    # NOTE(review): 0.0 and 10.0 are presumably the min/max eigenvalue
    # settings -- confirm against the `_smooth_cov` signature.
    c = PrAdam._smooth_cov(None, b, 0.0, 10.0)
    logging.info(f"c[noinv] = {_diag(c)}")

    # Same input again, this time with the extra fifth argument supplied.
    c = PrAdam._smooth_cov(None, b, 0.0, 10.0, 1.00001)
    logging.info(f"c[svd] = {_diag(c)}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
import subprocess
|
||||
s = subprocess.check_output("git status -uno .; git log -1", shell=True)
|
||||
_test_smooth_cov()
|
||||
logging.info(s)
|
||||
#_test_svd()
|
||||
_test_eve_cain()
|
||||
|
Loading…
x
Reference in New Issue
Block a user