mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-18 21:44:18 +00:00
Another bug fix, regarding Q being transposed.
This commit is contained in:
parent
ad2e698fc3
commit
fb36712e6b
@ -60,7 +60,7 @@ class LearnedGradient(Optimizer):
|
||||
params,
|
||||
lr=3e-02,
|
||||
size_lr_scale=0.1,
|
||||
meta_lr_scale=0.1,
|
||||
meta_lr_scale=0.2,
|
||||
betas=(0.9, 0.98),
|
||||
eps=1.0e-08,
|
||||
size_update_period=1,
|
||||
@ -183,7 +183,7 @@ class LearnedGradient(Optimizer):
|
||||
# re-estimate when we "rediagonalize" (this diagonalizes
|
||||
# the gradient covariance).
|
||||
# If our parameter matrix M is of shape (..., size),
|
||||
# we'll view M as being M == torch.matmul(N, q)
|
||||
# we'll view M as being M == torch.matmul(N, Q)
|
||||
# so p can be interpreted as having shape
|
||||
# (size, size), interpreted as being indexed [diagonalized_coordinate, canonical_coordinate]
|
||||
# by "diagonalize" we mean that we diagonalize the gradient covariance.
|
||||
@ -255,6 +255,7 @@ class LearnedGradient(Optimizer):
|
||||
if step % lr_est_period == 0:
|
||||
self._update_lrs(group, p, state)
|
||||
if step % (lr_est_period * diagonalize_period) == 0:
|
||||
logging.info("Diagonalizing")
|
||||
self._diagonalize_lrs(group, p, state)
|
||||
self._zero_exp_avg_sq(state)
|
||||
if True:
|
||||
@ -555,8 +556,8 @@ class LearnedGradient(Optimizer):
|
||||
S_new *= 1.0 / S_new.mean() # normalize so mean is 1.0
|
||||
S_new.clamp_(lr_mat_min, lr_mat_max) # apply limits once more.
|
||||
# Reconstruct Q with the modified S.
|
||||
Q[:] = torch.matmul(U * S, V.t())
|
||||
if random.random() < 0.1:
|
||||
Q[:] = torch.matmul(U * S_new, V.t())
|
||||
if random.random() < 0.03:
|
||||
subsample = max(1, S.numel() // 20)
|
||||
logging.info(f"shape={tuple(p.shape)}, dim={dim}, modifed S from {S[::subsample]} to {S_new[::subsample]}")
|
||||
|
||||
@ -580,7 +581,7 @@ class LearnedGradient(Optimizer):
|
||||
# index).
|
||||
|
||||
grad_cov = state[f"grad_cov_{dim}"]
|
||||
N_grad_cov = torch.matmul(Q, torch.matmul(grad_cov, Q))
|
||||
N_grad_cov = torch.matmul(Q, torch.matmul(grad_cov, Q.t()))
|
||||
N_grad_cov = N_grad_cov + N_grad_cov.t() # ensure symmetric
|
||||
U, S, V = N_grad_cov.svd()
|
||||
|
||||
@ -590,7 +591,7 @@ class LearnedGradient(Optimizer):
|
||||
# Now, we can diagonalize N_grad_cov with:
|
||||
# U^T N_grad_cov U == S.
|
||||
# N_grad_cov is a sum of N_grad^T N_grad.
|
||||
# So U^T N_grad^T N_grad U is diagonal.
|
||||
# We know U^T N_grad_cov U is diagonal, so U^T N_grad^T N_grad U is diagonal.
|
||||
# The linearized pseudo-loss can be written as tr(N_grad^T N_grad).
|
||||
# This can be written as tr(U U^T N_grad^T N_grad), since U U^T == I,
|
||||
#
|
||||
@ -601,7 +602,8 @@ class LearnedGradient(Optimizer):
|
||||
# (hat_N means \hat{N}, or N with a hat on it).
|
||||
# So if we interpret hat_N = N U, the gradient covariance w.r.t.
|
||||
# hat_N will be diagonalized. We also modify Q to hat_Q when
|
||||
# we modify hat_N, to keep the product M = N Q = N U U^T Q = hat_N hat_Q
|
||||
# we modify hat_N, to keep the product M unchanged:
|
||||
# M = N Q = N U U^T Q = hat_N hat_Q
|
||||
# This can be done by setting
|
||||
# hat_Q = U^T Q (eq.10)
|
||||
#
|
||||
@ -814,8 +816,10 @@ class LearnedGradient(Optimizer):
|
||||
if x.shape[dim] == 1:
|
||||
continue
|
||||
Q = state[f"Q_{dim}"]
|
||||
if not forward:
|
||||
# Q is indexed [canonical_index, diagonalized_index]
|
||||
if forward:
|
||||
# Q is indexed [diagonalized_index, canonical_index]; in the forward
|
||||
# direction we want to change canonical to diagonalized index, so have
|
||||
# to transpose.
|
||||
Q = Q.t()
|
||||
# TODO: could possibly somehow force the output format to be unchanged.
|
||||
x = x.transpose(-1, dim)
|
||||
@ -1156,7 +1160,7 @@ class LearnedGradient(Optimizer):
|
||||
try:
|
||||
P = 0.5 * (P + P.t())
|
||||
_,s,_ = P.svd()
|
||||
print(f"Min,max eig of P: {s.min()},{s.max()}")
|
||||
logging.info(f"Min,max eig of P: {s.min()},{s.max()}")
|
||||
except:
|
||||
pass
|
||||
# testing... note, this is only true modulo "eps"
|
||||
@ -1167,7 +1171,7 @@ class LearnedGradient(Optimizer):
|
||||
# Roundoff can cause significant differences, so use a fairly large
|
||||
# threshold of 0.001. We may increase this later or even remove the check.
|
||||
if not C_diff.abs().mean() < 0.01 * C_smoothed.diag().mean():
|
||||
print(f"Warning: large C_diff: {C_diff.abs().mean()}, C diag mean: {C_smoothed.diag().mean()}")
|
||||
logging.info(f"Warning: large C_diff: {C_diff.abs().mean()}, C diag mean: {C_smoothed.diag().mean()}")
|
||||
|
||||
return P
|
||||
|
||||
@ -1545,7 +1549,7 @@ class Cain(Optimizer):
|
||||
var_factor = var_factor.mean(dim=dims, keepdim=True)
|
||||
|
||||
#if random.random() < 0.01:
|
||||
# print(f"shape={p.shape}, var_factor={var_factor}")
|
||||
# logging.info(f"shape={p.shape}, var_factor={var_factor}")
|
||||
param_rms.mul_(var_factor.sqrt())
|
||||
|
||||
|
||||
@ -1648,7 +1652,7 @@ class LRScheduler(object):
|
||||
def print_lr(self, is_verbose, group, lr):
|
||||
"""Display the current learning rate."""
|
||||
if is_verbose:
|
||||
print(
|
||||
logging.info(
|
||||
f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
|
||||
f" of group {group} to {lr:.4e}."
|
||||
)
|
||||
@ -1797,7 +1801,7 @@ class Eve(Optimizer):
|
||||
|
||||
if random.random() < 0.0005:
|
||||
step = (exp_avg/denom) * step_size
|
||||
print(f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}")
|
||||
logging.info(f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}")
|
||||
|
||||
|
||||
return loss
|
||||
@ -1863,8 +1867,8 @@ def _test_eden():
|
||||
scheduler.step_batch()
|
||||
optim.zero_grad()
|
||||
|
||||
print("last lr = ", scheduler.get_last_lr())
|
||||
print("state dict = ", scheduler.state_dict())
|
||||
logging.info(f"last lr = {scheduler.get_last_lr()}")
|
||||
logging.info(f"state dict = {scheduler.state_dict()}")
|
||||
|
||||
|
||||
def _test_eve_cain():
|
||||
@ -1873,7 +1877,7 @@ def _test_eve_cain():
|
||||
E = 100
|
||||
B = 4
|
||||
T = 2
|
||||
print("in test_eve_cain")
|
||||
logging.info("in test_eve_cain")
|
||||
device = torch.device('cuda')
|
||||
dtype = torch.float32
|
||||
|
||||
@ -1921,7 +1925,8 @@ def _test_eve_cain():
|
||||
avg_loss = loss.item()
|
||||
else:
|
||||
avg_loss = 0.95 * avg_loss + 0.05 * loss.item()
|
||||
if n == 0 and epoch % 5 == 0:
|
||||
#if n == 0 and epoch % 5 == 0:
|
||||
if True:
|
||||
norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
|
||||
norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
|
||||
norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
|
||||
@ -1931,7 +1936,7 @@ def _test_eve_cain():
|
||||
#scale2 = '%.2e' % (m[2].weight_scale.exp().item())
|
||||
#scale2b = '%.2e' % (m[2].bias_scale.exp().item())
|
||||
lr = scheduler.get_last_lr()[0]
|
||||
print(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
|
||||
logging.info(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
|
||||
loss.log().backward()
|
||||
optim.step()
|
||||
optim.zero_grad()
|
||||
@ -1940,24 +1945,24 @@ def _test_eve_cain():
|
||||
#diagnostic.print_diagnostics()
|
||||
|
||||
stop = timeit.default_timer()
|
||||
print(f"Iter={iter}, Time taken: {stop - start}")
|
||||
logging.info(f"Iter={iter}, Time taken: {stop - start}")
|
||||
|
||||
print("last lr = ", scheduler.get_last_lr())
|
||||
#print("state dict = ", scheduler.state_dict())
|
||||
#print("optim state_dict = ", optim.state_dict())
|
||||
print("input_magnitudes = ", input_magnitudes)
|
||||
print("output_magnitudes = ", output_magnitudes)
|
||||
logging.info(f"last lr = {scheduler.get_last_lr()}")
|
||||
#logging.info("state dict = ", scheduler.state_dict())
|
||||
#logging.info("optim state_dict = ", optim.state_dict())
|
||||
logging.info(f"input_magnitudes = {input_magnitudes}")
|
||||
logging.info(f"output_magnitudes = {output_magnitudes}")
|
||||
|
||||
def stddev(x):
|
||||
return ((x-x.mean())**2).mean().sqrt()
|
||||
print("Un-normalized input col magnitudes log-stddev: ", stddev((m[0].weight**2).sum(dim=0).sqrt().log()))
|
||||
print("Normalized input col magnitudes log-stddev: ", stddev(((m[0].weight**2).sum(dim=0).sqrt() * input_magnitudes).log()))
|
||||
logging.info(f"Un-normalized input col magnitudes log-stddev: {stddev((m[0].weight**2).sum(dim=0).sqrt().log())}")
|
||||
logging.info(f"Normalized input col magnitudes log-stddev: {stddev(((m[0].weight**2).sum(dim=0).sqrt() * input_magnitudes).log())}")
|
||||
|
||||
print("Un-normalized 0-output row magnitudes log-stddev: ", stddev((m[0].weight**2).sum(dim=1).sqrt().log()))
|
||||
print("Un-normalized 2-input col magnitudes log-stddev: ", stddev((m[2].weight**2).sum(dim=0).sqrt().log()))
|
||||
print("Un-normalized 2-output row magnitudes log-stddev: ", stddev((m[2].weight**2).sum(dim=1).sqrt().log()))
|
||||
logging.info(f"Un-normalized 0-output row magnitudes log-stddev: {stddev((m[0].weight**2).sum(dim=1).sqrt().log())}")
|
||||
logging.info("Un-normalized 2-input col magnitudes log-stddev: {stddev((m[2].weight**2).sum(dim=0).sqrt().log())}")
|
||||
logging.info("Un-normalized 2-output row magnitudes log-stddev: {stddev((m[2].weight**2).sum(dim=1).sqrt().log())}")
|
||||
|
||||
print("Normalized output row magnitudes log-stddev: ", stddev(((m[2].weight**2).sum(dim=1).sqrt() / output_magnitudes).log()))
|
||||
logging.info("Normalized output row magnitudes log-stddev: {stddev(((m[2].weight**2).sum(dim=1).sqrt() / output_magnitudes).log())}")
|
||||
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user