Apparently working version, with changed test-code topology

This commit is contained in:
Daniel Povey 2022-07-11 13:17:29 -07:00
parent 245d39b1bb
commit 7993c84cd6

View File

@ -18,6 +18,7 @@ from collections import defaultdict
from typing import List, Optional, Union, Tuple, List
from lhotse.utils import fix_random_seed
import torch
from scaling import ActivationBalancer
import random
from torch import Tensor
from torch.optim import Optimizer
@ -247,6 +248,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
if numel > 1:
# "param_rms" just periodically records the scalar root-mean-square value of
# the parameter tensor.
# it has a shape like (batch_size, 1, 1, 1, 1)
param_rms = _mean(p**2, exclude_dims=[0], keepdim=True).sqrt() + eps
state["param_rms"] = param_rms
@ -433,6 +435,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
delta = state["delta"]
# the factor of (1-beta1) relates to momentum.
if random.random() < 0.01:
logging.info(f"scale_step ={scale_step}, shape={p.shape}")
delta.add_(p * scale_step, alpha=(1-beta1))
def _update_param_cov(self,
@ -1658,9 +1662,15 @@ def _test_eve_cain():
for iter in [3, 2, 1, 0]:
fix_random_seed(42)
Linear = torch.nn.Linear if iter == 0 else ScaledLinear
# TODO: find out why this is not converging...
m = torch.nn.Sequential(Linear(E, 200),
torch.nn.ReLU(),
Linear(200, E)).to(device)
ActivationBalancer(-1),
torch.nn.PReLU(),
Linear(200, 200),
ActivationBalancer(-1),
torch.nn.PReLU(),
Linear(200, E),
).to(device)
train_pairs = [ (100.0 * torch.randn(B, T, E, device=device, dtype=dtype) * input_magnitudes,
torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes) for _ in range(20) ]
@ -1693,16 +1703,16 @@ def _test_eve_cain():
else:
avg_loss = 0.95 * avg_loss + 0.05 * loss.item()
if n == 0 and epoch % 5 == 0:
norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item()
#norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
#norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
#norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
#norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item()
#scale1 = '%.2e' % (m[0].weight_scale.exp().item())
#scale1b = '%.2e' % (m[0].bias_scale.exp().item())
#scale2 = '%.2e' % (m[2].weight_scale.exp().item())
#scale2b = '%.2e' % (m[2].bias_scale.exp().item())
lr = scheduler.get_last_lr()[0]
logging.info(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
logging.info(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}") #, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
loss.log().backward()
optim.step()
optim.zero_grad()
@ -1739,9 +1749,9 @@ def _test_svd():
assert torch.allclose(X2, X, atol=0.001)
if __name__ == "__main__":
logging.getLogger().setLevel(logging.INFO)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
logging.getLogger().setLevel(logging.INFO)
#_test_svd()
_test_eve_cain()
#_test_eden()