mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
Apparently working version, with changed test-code topology
This commit is contained in:
parent
245d39b1bb
commit
7993c84cd6
@ -18,6 +18,7 @@ from collections import defaultdict
|
||||
from typing import List, Optional, Union, Tuple, List
|
||||
from lhotse.utils import fix_random_seed
|
||||
import torch
|
||||
from scaling import ActivationBalancer
|
||||
import random
|
||||
from torch import Tensor
|
||||
from torch.optim import Optimizer
|
||||
@ -247,6 +248,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
||||
if numel > 1:
|
||||
# "param_rms" just periodically records the scalar root-mean-square value of
|
||||
# the parameter tensor.
|
||||
# it has a shape like (batch_size, 1, 1, 1, 1)
|
||||
param_rms = _mean(p**2, exclude_dims=[0], keepdim=True).sqrt() + eps
|
||||
state["param_rms"] = param_rms
|
||||
|
||||
@ -433,6 +435,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
||||
|
||||
delta = state["delta"]
|
||||
# the factor of (1-beta1) relates to momentum.
|
||||
if random.random() < 0.01:
|
||||
logging.info(f"scale_step ={scale_step}, shape={p.shape}")
|
||||
delta.add_(p * scale_step, alpha=(1-beta1))
|
||||
|
||||
def _update_param_cov(self,
|
||||
@ -1658,9 +1662,15 @@ def _test_eve_cain():
|
||||
for iter in [3, 2, 1, 0]:
|
||||
fix_random_seed(42)
|
||||
Linear = torch.nn.Linear if iter == 0 else ScaledLinear
|
||||
# TODO: find out why this is not converging...
|
||||
m = torch.nn.Sequential(Linear(E, 200),
|
||||
torch.nn.ReLU(),
|
||||
Linear(200, E)).to(device)
|
||||
ActivationBalancer(-1),
|
||||
torch.nn.PReLU(),
|
||||
Linear(200, 200),
|
||||
ActivationBalancer(-1),
|
||||
torch.nn.PReLU(),
|
||||
Linear(200, E),
|
||||
).to(device)
|
||||
|
||||
train_pairs = [ (100.0 * torch.randn(B, T, E, device=device, dtype=dtype) * input_magnitudes,
|
||||
torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes) for _ in range(20) ]
|
||||
@ -1693,16 +1703,16 @@ def _test_eve_cain():
|
||||
else:
|
||||
avg_loss = 0.95 * avg_loss + 0.05 * loss.item()
|
||||
if n == 0 and epoch % 5 == 0:
|
||||
norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
|
||||
norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
|
||||
norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
|
||||
norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item()
|
||||
#norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
|
||||
#norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
|
||||
#norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
|
||||
#norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item()
|
||||
#scale1 = '%.2e' % (m[0].weight_scale.exp().item())
|
||||
#scale1b = '%.2e' % (m[0].bias_scale.exp().item())
|
||||
#scale2 = '%.2e' % (m[2].weight_scale.exp().item())
|
||||
#scale2b = '%.2e' % (m[2].bias_scale.exp().item())
|
||||
lr = scheduler.get_last_lr()[0]
|
||||
logging.info(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
|
||||
logging.info(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}") #, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
|
||||
loss.log().backward()
|
||||
optim.step()
|
||||
optim.zero_grad()
|
||||
@ -1739,9 +1749,9 @@ def _test_svd():
|
||||
assert torch.allclose(X2, X, atol=0.001)
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
#_test_svd()
|
||||
_test_eve_cain()
|
||||
#_test_eden()
|
||||
|
Loading…
x
Reference in New Issue
Block a user