Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-19 05:54:20 +00:00
commit 7993c84cd6 (parent 245d39b1bb)

    Apparently working version, with changed test-code topology
@@ -18,6 +18,7 @@ from collections import defaultdict
 from typing import List, Optional, Union, Tuple, List
 from lhotse.utils import fix_random_seed
 import torch
+from scaling import ActivationBalancer
 import random
 from torch import Tensor
 from torch.optim import Optimizer
@@ -247,6 +248,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         if numel > 1:
             # "param_rms" just periodically records the scalar root-mean-square value of
             # the parameter tensor.
+            # it has a shape like (batch_size, 1, 1, 1, 1)
             param_rms = _mean(p**2, exclude_dims=[0], keepdim=True).sqrt() + eps
             state["param_rms"] = param_rms
 
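The added comment documents the broadcast shape of param_rms. Below is a minimal sketch of what the hunk computes, assuming _mean is this file's helper that averages over every dimension except those listed in exclude_dims (consistent with how the hunk calls it); the _mean body, p, and eps here are illustrative stand-ins, not the file's actual code:

    import torch
    from torch import Tensor
    from typing import List

    def _mean(x: Tensor, exclude_dims: List[int], keepdim: bool = False) -> Tensor:
        # Hypothetical stand-in: reduce over every dim not listed in exclude_dims.
        dims = [d for d in range(x.ndim) if d not in exclude_dims]
        return x.mean(dim=dims, keepdim=keepdim)

    p = torch.randn(4, 3, 5, 7, 2)   # (batch_size, ...) stacked parameter tensor
    eps = 1.0e-08
    param_rms = _mean(p**2, exclude_dims=[0], keepdim=True).sqrt() + eps
    print(param_rms.shape)           # torch.Size([4, 1, 1, 1, 1])

With keepdim=True the mean keeps one singleton axis per reduced dimension, which is why a five-dimensional stacked parameter yields the (batch_size, 1, 1, 1, 1) shape the new comment mentions.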
@@ -433,6 +435,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
 
         delta = state["delta"]
         # the factor of (1-beta1) relates to momentum.
+        if random.random() < 0.01:
+            logging.info(f"scale_step ={scale_step}, shape={p.shape}")
         delta.add_(p * scale_step, alpha=(1-beta1))
 
     def _update_param_cov(self,
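The added lines use sampled logging: scale_step is reported on roughly 1% of optimizer steps, so the diagnostic stays visible without flooding the log at training volume. A standalone sketch of the pattern, with placeholder values standing in for the real scale_step and parameter tensor p:

    import logging
    import random
    import torch

    logging.getLogger().setLevel(logging.INFO)

    p = torch.randn(200, 100)                # placeholder parameter tensor
    for step in range(1000):
        scale_step = torch.randn(()) * 0.01  # placeholder for the real update
        if random.random() < 0.01:           # fires on ~1 in 100 steps
            logging.info(f"scale_step={scale_step}, shape={p.shape}")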
@@ -1658,9 +1662,15 @@ def _test_eve_cain():
     for iter in [3, 2, 1, 0]:
         fix_random_seed(42)
         Linear = torch.nn.Linear if iter == 0 else ScaledLinear
+        # TODO: find out why this is not converging...
         m = torch.nn.Sequential(Linear(E, 200),
-                                torch.nn.ReLU(),
-                                Linear(200, E)).to(device)
+                                ActivationBalancer(-1),
+                                torch.nn.PReLU(),
+                                Linear(200, 200),
+                                ActivationBalancer(-1),
+                                torch.nn.PReLU(),
+                                Linear(200, E),
+                                ).to(device)
 
         train_pairs = [ (100.0 * torch.randn(B, T, E, device=device, dtype=dtype) * input_magnitudes,
                          torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes) for _ in range(20) ]
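This is the topology change named in the commit message: the test network grows from one hidden layer with ReLU to two hidden layers, each followed by an ActivationBalancer and a PReLU. Assembled in one place as a sketch: ActivationBalancer comes from icefall's scaling.py and is constructed with the single argument -1 exactly as the diff shows (presumably the channel dimension); ScaledLinear is assumed to come from the same module (the diff only shows ActivationBalancer being imported), and E is the feature dimension the test defines elsewhere:

    import torch
    from scaling import ActivationBalancer, ScaledLinear  # icefall helpers; ScaledLinear location assumed

    E = 100                        # assumed value; the test defines E elsewhere
    device = torch.device("cpu")

    m = torch.nn.Sequential(ScaledLinear(E, 200),
                            ActivationBalancer(-1),
                            torch.nn.PReLU(),
                            ScaledLinear(200, 200),
                            ActivationBalancer(-1),
                            torch.nn.PReLU(),
                            ScaledLinear(200, E),
                            ).to(device)

PReLU keeps a learnable slope on negative inputs, so unlike ReLU no unit can go permanently dead, and the balancer (in icefall) adjusts gradients to keep activation statistics in a reasonable range; both plausibly bear on the TODO about convergence.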
@@ -1693,16 +1703,16 @@ def _test_eve_cain():
             else:
                 avg_loss = 0.95 * avg_loss + 0.05 * loss.item()
             if n == 0 and epoch % 5 == 0:
-                norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
-                norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
-                norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
-                norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item()
+                #norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
+                #norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
+                #norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
+                #norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item()
                 #scale1 = '%.2e' % (m[0].weight_scale.exp().item())
                 #scale1b = '%.2e' % (m[0].bias_scale.exp().item())
                 #scale2 = '%.2e' % (m[2].weight_scale.exp().item())
                 #scale2b = '%.2e' % (m[2].bias_scale.exp().item())
                 lr = scheduler.get_last_lr()[0]
-                logging.info(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
+                logging.info(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}") #, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
             loss.log().backward()
             optim.step()
             optim.zero_grad()
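For context around this logging change: the test smooths the reported loss with an exponential moving average (decay 0.95) and backpropagates through loss.log() rather than the raw loss. A self-contained sketch of that bookkeeping, with a placeholder model and loss in place of the test's:

    import torch

    torch.manual_seed(0)
    m = torch.nn.Linear(10, 10)               # placeholder model
    avg_loss = None
    for n in range(100):
        x = torch.randn(8, 10)
        loss = (m(x) ** 2).mean() + 1.0e-03   # placeholder positive loss
        if avg_loss is None:
            avg_loss = loss.item()
        else:
            avg_loss = 0.95 * avg_loss + 0.05 * loss.item()
        loss.log().backward()  # grad of log(L) is (1/L) * grad of L
        # (an optimizer step and zero_grad() would follow here, as in the test)

Differentiating log(L) divides the gradient by the current loss value, so the effective step size stays comparable as the loss shrinks over training.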
@@ -1739,9 +1749,9 @@ def _test_svd():
     assert torch.allclose(X2, X, atol=0.001)
 
 if __name__ == "__main__":
-    logging.getLogger().setLevel(logging.INFO)
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
+    logging.getLogger().setLevel(logging.INFO)
     #_test_svd()
     _test_eve_cain()
    #_test_eden()
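The reorder moves the logging setup after the thread pinning; either order works so long as setLevel runs before the first logging.info call inside the tests. A minimal harness mirroring the new __main__ block, with the test calls left commented since they live in the full file:

    import logging
    import torch

    if __name__ == "__main__":
        torch.set_num_threads(1)           # single-threaded for reproducible timing
        torch.set_num_interop_threads(1)
        logging.getLogger().setLevel(logging.INFO)
        # _test_svd()
        # _test_eve_cain()
        # _test_eden()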