mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
Add some debugging/diagnostic code
This commit is contained in:
parent
747960677e
commit
67c402a369
@ -122,6 +122,10 @@ def _update_factorization(x: Tensor, x_factorized: Tensor,
|
||||
this_mean = _mean_like(x_norm_var, shape)
|
||||
f = ((1.0 - speed) + speed * this_mean)
|
||||
factors.append(f)
|
||||
# temp
|
||||
#import random
|
||||
#if random.random() < 0.1:
|
||||
# print("factor norms: ", list((x-1.0).abs().mean().item() for x in factors))
|
||||
x_factorized *= _product(*factors)
|
||||
# TEMP
|
||||
#import random
|
||||
@ -149,8 +153,6 @@ def _get_factor_grads(x: Tensor, x_grad: Tensor) -> List[Tensor]:
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def _init_factors(x: Tensor, eps: float = 1.0e-04) -> Tuple[Tensor]:
|
||||
"""
|
||||
Initialize some factors which we will use to normalize the variance of x.
|
||||
@ -349,16 +351,15 @@ class Abel(Optimizer):
|
||||
|
||||
if step < 10 or step % 10 == 1:
|
||||
# do this only every 10 steps, to save time.
|
||||
num_factors = len(factors_exp_avg_sq)
|
||||
_update_factorization(p, factorization,
|
||||
speed=0.1,
|
||||
eps=eps)
|
||||
|
||||
|
||||
factors_sum = None
|
||||
for g, e in zip(factor_grads, factors_exp_avg_sq):
|
||||
update_exp_avg_sq(g, e)
|
||||
this_denom = (e + eps*eps).sqrt()
|
||||
this_denom = (e/bias_correction2 + eps*eps).sqrt()
|
||||
assert g.shape == this_denom.shape
|
||||
factor_delta = g / this_denom
|
||||
factors_sum = (factor_delta if factors_sum is None
|
||||
else factors_sum + factor_delta)
|
||||
@ -395,16 +396,12 @@ class Abel(Optimizer):
|
||||
# `p * factors_sum` is the contribution from changes in x_factor1
|
||||
# and x_factor2: again, before taking into account the learning
|
||||
# rate or momentum.
|
||||
|
||||
|
||||
this_delta = ((grad * factorization / denom) + p * factors_sum)
|
||||
|
||||
|
||||
# compute the moving-average change in parameters, and add it to p.
|
||||
delta.mul_(beta1).add_(this_delta, alpha=-lr*(1-beta1))
|
||||
|
||||
if step % 50 == 0 and False:
|
||||
print("This_delta norm = ", delta.norm())
|
||||
p.add_(delta)
|
||||
|
||||
|
||||
@ -745,21 +742,27 @@ def _test_abel():
|
||||
B = 4
|
||||
T = 2
|
||||
print("in test_abel")
|
||||
device = torch.device('cuda')
|
||||
dtype = torch.float32
|
||||
|
||||
# these input_magnitudes and output_magnitudes are to test that
|
||||
# Abel is working as we expect and is able to adjust scales of
|
||||
# different dims differently.
|
||||
input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
|
||||
output_magnitudes = (0.0 * torch.randn(E, dtype=dtype, device=device)).exp()
|
||||
|
||||
for iter in [0,1]:
|
||||
device = torch.device('cuda')
|
||||
dtype = torch.float32
|
||||
Linear = torch.nn.Linear if iter == 0 else ScaledLinear
|
||||
m = torch.nn.Sequential(Linear(E, 200),
|
||||
torch.nn.ReLU(),
|
||||
Linear(200, E)).to(device)
|
||||
|
||||
|
||||
train_pairs = [ (100.0 * torch.randn(B, T, E, device=device, dtype=dtype),
|
||||
torch.randn(B, T, E, device=device, dtype=dtype)) for _ in range(10) ]
|
||||
train_pairs = [ (100.0 * torch.randn(B, T, E, device=device, dtype=dtype) * input_magnitudes,
|
||||
torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes) for _ in range(20) ]
|
||||
|
||||
if iter == 0: optim = Abel(m.parameters(), lr=0.003)
|
||||
else: optim = Eve(m.parameters(), lr=0.003)
|
||||
scheduler = Eden(optim, lr_batches=300, lr_epochs=2, verbose=False)
|
||||
scheduler = Eden(optim, lr_batches=300, lr_epochs=20, verbose=False)
|
||||
|
||||
start = timeit.default_timer()
|
||||
for epoch in range(150):
|
||||
@ -767,7 +770,7 @@ def _test_abel():
|
||||
for n, (x,y) in enumerate(train_pairs):
|
||||
y_out = m(x)
|
||||
loss = ((y_out - y)**2).mean() * 100.0
|
||||
if n % 10 == 0 and epoch % 10 == 0:
|
||||
if n == 0 and epoch % 10 == 0:
|
||||
norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
|
||||
norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
|
||||
norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
|
||||
@ -777,7 +780,7 @@ def _test_abel():
|
||||
#scale2 = '%.2e' % (m[2].weight_scale.exp().item())
|
||||
#scale2b = '%.2e' % (m[2].bias_scale.exp().item())
|
||||
print(f"Iter {iter}, epoch {epoch}, batch {n}, loss {loss.item()}, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
|
||||
loss.backward()
|
||||
loss.log().backward()
|
||||
optim.step()
|
||||
optim.zero_grad()
|
||||
scheduler.step_batch()
|
||||
@ -786,10 +789,22 @@ def _test_abel():
|
||||
print(f"Iter={iter}, Time taken: {stop - start}")
|
||||
|
||||
print("last lr = ", scheduler.get_last_lr())
|
||||
print("state dict = ", scheduler.state_dict())
|
||||
#print("state dict = ", scheduler.state_dict())
|
||||
#print("optim state_dict = ", optim.state_dict())
|
||||
print("input_magnitudes = ", input_magnitudes)
|
||||
print("output_magnitudes = ", output_magnitudes)
|
||||
|
||||
def stddev(x):
|
||||
return ((x-x.mean())**2).mean().sqrt()
|
||||
print("Un-normalized input col magnitudes log-stddev: ", stddev((m[0].weight**2).sum(dim=0).sqrt().log()))
|
||||
print("Normalized input col magnitudes log-stddev: ", stddev(((m[0].weight**2).sum(dim=0).sqrt() * input_magnitudes).log()))
|
||||
|
||||
print("Un-normalized 0-output row magnitudes log-stddev: ", stddev((m[0].weight**2).sum(dim=1).sqrt().log()))
|
||||
print("Un-normalized 2-output col magnitudes log-stddev: ", stddev((m[0].weight**2).sum(dim=0).sqrt().log()))
|
||||
print("Un-normalized 2-output row magnitudes log-stddev: ", stddev((m[0].weight**2).sum(dim=1).sqrt().log()))
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_test_abel()
|
||||
_test_eden()
|
||||
#_test_eden()
|
||||
|
Loading…
x
Reference in New Issue
Block a user