Mirror of https://github.com/k2-fsa/icefall.git
Synced 2025-09-08 00:24:19 +00:00

Add some debugging/diagnostic code

parent 747960677e
commit 67c402a369
@@ -122,6 +122,10 @@ def _update_factorization(x: Tensor, x_factorized: Tensor,
         this_mean = _mean_like(x_norm_var, shape)
         f = ((1.0 - speed) + speed * this_mean)
         factors.append(f)
+    # temp
+    #import random
+    #if random.random() < 0.1:
+    #    print("factor norms: ", list((x-1.0).abs().mean().item() for x in factors))
     x_factorized *= _product(*factors)
     # TEMP
     #import random
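Note: the commented-out diagnostic added above measures how far each factor sits from 1.0. A minimal standalone sketch of the same computation (the helper name factor_norms is hypothetical, not part of the patched file):

    import torch

    def factor_norms(factors):
        # mean absolute deviation of each factor from 1.0; values near zero
        # mean the factorization is barely rescaling the parameter tensor
        return [(f - 1.0).abs().mean().item() for f in factors]

    factors = [torch.full((4,), 1.1), torch.full((4,), 0.9)]
    print("factor norms: ", factor_norms(factors))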
@@ -149,8 +153,6 @@ def _get_factor_grads(x: Tensor, x_grad: Tensor) -> List[Tensor]:
 
 
 
-
-
 def _init_factors(x: Tensor, eps: float = 1.0e-04) -> Tuple[Tensor]:
     """
     Initialize some factors which we will use to normalize the variance of x.
@@ -349,16 +351,15 @@ class Abel(Optimizer):
 
                 if step < 10 or step % 10 == 1:
                     # do this only every 10 steps, to save time.
-                    num_factors = len(factors_exp_avg_sq)
                     _update_factorization(p, factorization,
                                           speed=0.1,
                                           eps=eps)
 
-
                 factors_sum = None
                 for g, e in zip(factor_grads, factors_exp_avg_sq):
                     update_exp_avg_sq(g, e)
-                    this_denom = (e + eps*eps).sqrt()
+                    this_denom = (e/bias_correction2 + eps*eps).sqrt()
+                    assert g.shape == this_denom.shape
                     factor_delta = g / this_denom
                     factors_sum = (factor_delta if factors_sum is None
                                    else factors_sum + factor_delta)
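Note: dividing the running second moment e by bias_correction2 before the square root is the standard Adam-style correction for the zero initialization of the moment buffer. A minimal sketch of the effect, assuming update_exp_avg_sq keeps the usual exponential moving average (beta2, step and the update rule here are illustrative assumptions, not taken from the diff):

    import torch

    beta2, eps, step = 0.98, 1e-4, 10
    g = torch.randn(5)
    exp_avg_sq = torch.zeros(5)
    exp_avg_sq.mul_(beta2).addcmul_(g, g, value=1 - beta2)  # assumed update rule

    # without the correction, the early-step denominator underestimates the
    # gradient scale by roughly a factor of sqrt(1 - beta2**step)
    bias_correction2 = 1 - beta2 ** step
    denom = (exp_avg_sq / bias_correction2 + eps * eps).sqrt()
    print(denom)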
@@ -395,16 +396,12 @@ class Abel(Optimizer):
                 # `p * factors_sum` is the contribution from changes in x_factor1
                 # and x_factor2: again, before taking into account the learning
                 # rate or momentum.
 
-
                 this_delta = ((grad * factorization / denom) + p * factors_sum)
 
-
                 # compute the moving-average change in parameters, and add it to p.
                 delta.mul_(beta1).add_(this_delta, alpha=-lr*(1-beta1))
 
-                if step % 50 == 0 and False:
-                    print("This_delta norm = ", delta.norm())
                 p.add_(delta)
 
 
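Note: the deleted branch was already dead code, since `and False` made the condition unsatisfiable, so removing it does not change behavior. The surviving two lines keep delta as an exponential moving average of scaled parameter changes and apply it directly. A minimal standalone sketch of that update (all values illustrative):

    import torch

    lr, beta1 = 0.003, 0.9
    p = torch.randn(5)
    delta = torch.zeros_like(p)   # persistent momentum buffer
    this_delta = torch.randn(5)   # stand-in for the combined update in the hunk

    # fold the learning rate and momentum into one buffer, then apply it
    delta.mul_(beta1).add_(this_delta, alpha=-lr * (1 - beta1))
    p.add_(delta)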
@@ -745,21 +742,27 @@ def _test_abel():
     B = 4
     T = 2
     print("in test_abel")
-    for iter in [0,1]:
-        device = torch.device('cuda')
-        dtype = torch.float32
+    device = torch.device('cuda')
+    dtype = torch.float32
 
+    # these input_magnitudes and output_magnitudes are to test that
+    # Abel is working as we expect and is able to adjust scales of
+    # different dims differently.
+    input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
+    output_magnitudes = (0.0 * torch.randn(E, dtype=dtype, device=device)).exp()
+
+    for iter in [0,1]:
         Linear = torch.nn.Linear if iter == 0 else ScaledLinear
         m = torch.nn.Sequential(Linear(E, 200),
                                 torch.nn.ReLU(),
                                 Linear(200, E)).to(device)
 
-        train_pairs = [ (100.0 * torch.randn(B, T, E, device=device, dtype=dtype),
-                         torch.randn(B, T, E, device=device, dtype=dtype)) for _ in range(10) ]
+        train_pairs = [ (100.0 * torch.randn(B, T, E, device=device, dtype=dtype) * input_magnitudes,
+                         torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes) for _ in range(20) ]
 
         if iter == 0: optim = Abel(m.parameters(), lr=0.003)
         else: optim = Eve(m.parameters(), lr=0.003)
-        scheduler = Eden(optim, lr_batches=300, lr_epochs=2, verbose=False)
+        scheduler = Eden(optim, lr_batches=300, lr_epochs=20, verbose=False)
 
         start = timeit.default_timer()
         for epoch in range(150):
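Note: the point of the new magnitudes is that (1.0 * torch.randn(E)).exp() is log-normal, so the E input dimensions now differ in scale by orders of magnitude, while (0.0 * torch.randn(E)).exp() is exactly 1.0 everywhere; a scale-adapting optimizer like Abel should learn per-dimension weight scales that compensate. A quick standalone illustration of the spread:

    import torch

    E = 100
    input_magnitudes = (1.0 * torch.randn(E)).exp()   # log-normal, wide spread
    output_magnitudes = (0.0 * torch.randn(E)).exp()  # identically 1.0
    print(input_magnitudes.min().item(), input_magnitudes.max().item())
    print(output_magnitudes.unique())                 # tensor([1.])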
@@ -767,7 +770,7 @@ def _test_abel():
         for n, (x,y) in enumerate(train_pairs):
             y_out = m(x)
             loss = ((y_out - y)**2).mean() * 100.0
-            if n % 10 == 0 and epoch % 10 == 0:
+            if n == 0 and epoch % 10 == 0:
                 norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
                 norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
                 norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
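Note: the norms printed here are root-mean-square values of the parameter tensors; the condition change just reduces logging to the first batch of every tenth epoch. A one-line sketch of the RMS computation on its own (the tensor is illustrative):

    import torch

    w = torch.randn(200, 100)
    print('%.2e' % (w**2).mean().sqrt().item())  # RMS of the entries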
@@ -777,7 +780,7 @@ def _test_abel():
                 #scale2 = '%.2e' % (m[2].weight_scale.exp().item())
                 #scale2b = '%.2e' % (m[2].bias_scale.exp().item())
                 print(f"Iter {iter}, epoch {epoch}, batch {n}, loss {loss.item()}, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
-            loss.backward()
+            loss.log().backward()
             optim.step()
             optim.zero_grad()
             scheduler.step_batch()
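Note: switching to loss.log().backward() backpropagates through log(loss), which scales every gradient by 1/loss; the update direction is unchanged but its size no longer depends on the overall magnitude of the loss. A minimal sketch:

    import torch

    x = torch.randn(8, requires_grad=True)
    loss = (x ** 2).mean() * 100.0
    # d(log loss)/dx = (1 / loss) * d(loss)/dx, so gradients are divided by
    # the current loss value
    loss.log().backward()
    print(x.grad.norm())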
@@ -786,10 +789,22 @@ def _test_abel():
         print(f"Iter={iter}, Time taken: {stop - start}")
 
         print("last lr = ", scheduler.get_last_lr())
-        print("state dict = ", scheduler.state_dict())
+        #print("state dict = ", scheduler.state_dict())
+        #print("optim state_dict = ", optim.state_dict())
+        print("input_magnitudes = ", input_magnitudes)
+        print("output_magnitudes = ", output_magnitudes)
+
+        def stddev(x):
+            return ((x-x.mean())**2).mean().sqrt()
+        print("Un-normalized input col magnitudes log-stddev: ", stddev((m[0].weight**2).sum(dim=0).sqrt().log()))
+        print("Normalized input col magnitudes log-stddev: ", stddev(((m[0].weight**2).sum(dim=0).sqrt() * input_magnitudes).log()))
+
+        print("Un-normalized 0-output row magnitudes log-stddev: ", stddev((m[0].weight**2).sum(dim=1).sqrt().log()))
+        print("Un-normalized 2-output col magnitudes log-stddev: ", stddev((m[0].weight**2).sum(dim=0).sqrt().log()))
+        print("Un-normalized 2-output row magnitudes log-stddev: ", stddev((m[0].weight**2).sum(dim=1).sqrt().log()))
 
 
 
 if __name__ == "__main__":
     _test_abel()
-    _test_eden()
+    #_test_eden()
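Note: stddev here is the population standard deviation; applied to the log of per-column weight norms it measures how much the learned column scales vary. If Abel compensates for the input magnitudes (column norm roughly proportional to 1/magnitude), the "Normalized" quantity should come out much smaller than the "Un-normalized" one. Also worth flagging: the two "2-output" prints index m[0] rather than m[2], which looks like a copy-paste slip in the committed code. A standalone sketch of the diagnostic (the random weights are illustrative):

    import torch

    def stddev(x):
        # population standard deviation of a 1-D tensor
        return ((x - x.mean())**2).mean().sqrt()

    w = torch.randn(200, 100)
    input_magnitudes = torch.randn(100).exp()
    col_norms = (w**2).sum(dim=0).sqrt()
    # if columns learn scales inversely proportional to the input magnitudes,
    # col_norms * input_magnitudes has nearly constant scale, small log-stddev
    print(stddev(col_norms.log()))
    print(stddev((col_norms * input_magnitudes).log()))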