Add some debugging/diagnostic code

This commit is contained in:
Daniel Povey 2022-05-15 13:28:00 +08:00
parent 747960677e
commit 67c402a369

View File

@ -122,6 +122,10 @@ def _update_factorization(x: Tensor, x_factorized: Tensor,
this_mean = _mean_like(x_norm_var, shape)
f = ((1.0 - speed) + speed * this_mean)
factors.append(f)
# temp
#import random
#if random.random() < 0.1:
# print("factor norms: ", list((x-1.0).abs().mean().item() for x in factors))
x_factorized *= _product(*factors)
# TEMP
#import random
@ -149,8 +153,6 @@ def _get_factor_grads(x: Tensor, x_grad: Tensor) -> List[Tensor]:
def _init_factors(x: Tensor, eps: float = 1.0e-04) -> Tuple[Tensor]:
"""
Initialize some factors which we will use to normalize the variance of x.
@ -349,16 +351,15 @@ class Abel(Optimizer):
if step < 10 or step % 10 == 1:
# do this only every 10 steps, to save time.
num_factors = len(factors_exp_avg_sq)
_update_factorization(p, factorization,
speed=0.1,
eps=eps)
factors_sum = None
for g, e in zip(factor_grads, factors_exp_avg_sq):
update_exp_avg_sq(g, e)
this_denom = (e + eps*eps).sqrt()
this_denom = (e/bias_correction2 + eps*eps).sqrt()
assert g.shape == this_denom.shape
factor_delta = g / this_denom
factors_sum = (factor_delta if factors_sum is None
else factors_sum + factor_delta)
@ -395,16 +396,12 @@ class Abel(Optimizer):
# `p * factors_sum` is the contribution from changes in x_factor1
# and x_factor2: again, before taking into account the learning
# rate or momentum.
this_delta = ((grad * factorization / denom) + p * factors_sum)
# compute the moving-average change in parameters, and add it to p.
delta.mul_(beta1).add_(this_delta, alpha=-lr*(1-beta1))
if step % 50 == 0 and False:
print("This_delta norm = ", delta.norm())
p.add_(delta)
@ -745,21 +742,27 @@ def _test_abel():
B = 4
T = 2
print("in test_abel")
device = torch.device('cuda')
dtype = torch.float32
# these input_magnitudes and output_magnitudes are to test that
# Abel is working as we expect and is able to adjust scales of
# different dims differently.
input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
output_magnitudes = (0.0 * torch.randn(E, dtype=dtype, device=device)).exp()
for iter in [0,1]:
device = torch.device('cuda')
dtype = torch.float32
Linear = torch.nn.Linear if iter == 0 else ScaledLinear
m = torch.nn.Sequential(Linear(E, 200),
torch.nn.ReLU(),
Linear(200, E)).to(device)
train_pairs = [ (100.0 * torch.randn(B, T, E, device=device, dtype=dtype),
torch.randn(B, T, E, device=device, dtype=dtype)) for _ in range(10) ]
train_pairs = [ (100.0 * torch.randn(B, T, E, device=device, dtype=dtype) * input_magnitudes,
torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes) for _ in range(20) ]
if iter == 0: optim = Abel(m.parameters(), lr=0.003)
else: optim = Eve(m.parameters(), lr=0.003)
scheduler = Eden(optim, lr_batches=300, lr_epochs=2, verbose=False)
scheduler = Eden(optim, lr_batches=300, lr_epochs=20, verbose=False)
start = timeit.default_timer()
for epoch in range(150):
@ -767,7 +770,7 @@ def _test_abel():
for n, (x,y) in enumerate(train_pairs):
y_out = m(x)
loss = ((y_out - y)**2).mean() * 100.0
if n % 10 == 0 and epoch % 10 == 0:
if n == 0 and epoch % 10 == 0:
norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
@ -777,7 +780,7 @@ def _test_abel():
#scale2 = '%.2e' % (m[2].weight_scale.exp().item())
#scale2b = '%.2e' % (m[2].bias_scale.exp().item())
print(f"Iter {iter}, epoch {epoch}, batch {n}, loss {loss.item()}, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
loss.backward()
loss.log().backward()
optim.step()
optim.zero_grad()
scheduler.step_batch()
@ -786,10 +789,22 @@ def _test_abel():
print(f"Iter={iter}, Time taken: {stop - start}")
print("last lr = ", scheduler.get_last_lr())
print("state dict = ", scheduler.state_dict())
#print("state dict = ", scheduler.state_dict())
#print("optim state_dict = ", optim.state_dict())
print("input_magnitudes = ", input_magnitudes)
print("output_magnitudes = ", output_magnitudes)
def stddev(x):
return ((x-x.mean())**2).mean().sqrt()
print("Un-normalized input col magnitudes log-stddev: ", stddev((m[0].weight**2).sum(dim=0).sqrt().log()))
print("Normalized input col magnitudes log-stddev: ", stddev(((m[0].weight**2).sum(dim=0).sqrt() * input_magnitudes).log()))
print("Un-normalized 0-output row magnitudes log-stddev: ", stddev((m[0].weight**2).sum(dim=1).sqrt().log()))
print("Un-normalized 2-output col magnitudes log-stddev: ", stddev((m[0].weight**2).sum(dim=0).sqrt().log()))
print("Un-normalized 2-output row magnitudes log-stddev: ", stddev((m[0].weight**2).sum(dim=1).sqrt().log()))
if __name__ == "__main__":
_test_abel()
_test_eden()
#_test_eden()