diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 8627c1f9a..5b960e30e 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -18,6 +18,7 @@ from collections import defaultdict
 from typing import List, Optional, Union, Tuple, List
 from lhotse.utils import fix_random_seed
 import torch
+from scaling import ActivationBalancer
 import random
 from torch import Tensor
 from torch.optim import Optimizer
@@ -247,6 +248,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         if numel > 1:
             # "param_rms" just periodically records the scalar root-mean-square value of
             # the parameter tensor.
+            # it has a shape like (batch_size, 1, 1, 1, 1)
             param_rms = _mean(p**2, exclude_dims=[0], keepdim=True).sqrt() + eps
             state["param_rms"] = param_rms
 
@@ -433,6 +435,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
 
         delta = state["delta"]
         # the factor of (1-beta1) relates to momentum.
+        if random.random() < 0.01:
+            logging.info(f"scale_step ={scale_step}, shape={p.shape}")
         delta.add_(p * scale_step, alpha=(1-beta1))
 
     def _update_param_cov(self,
@@ -1658,9 +1662,15 @@ def _test_eve_cain():
     for iter in [3, 2, 1, 0]:
         fix_random_seed(42)
         Linear = torch.nn.Linear if iter == 0 else ScaledLinear
+        # TODO: find out why this is not converging...
         m = torch.nn.Sequential(Linear(E, 200),
-                                torch.nn.ReLU(),
-                                Linear(200, E)).to(device)
+                                ActivationBalancer(-1),
+                                torch.nn.PReLU(),
+                                Linear(200, 200),
+                                ActivationBalancer(-1),
+                                torch.nn.PReLU(),
+                                Linear(200, E),
+                                ).to(device)
 
         train_pairs = [ (100.0 * torch.randn(B, T, E, device=device, dtype=dtype) * input_magnitudes,
                          torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes) for _ in range(20) ]
@@ -1693,16 +1703,16 @@ def _test_eve_cain():
             else:
                 avg_loss = 0.95 * avg_loss + 0.05 * loss.item()
             if n == 0 and epoch % 5 == 0:
-                norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
-                norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
-                norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
-                norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item()
+                #norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
+                #norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
+                #norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
+                #norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item()
                 #scale1 = '%.2e' % (m[0].weight_scale.exp().item())
                 #scale1b = '%.2e' % (m[0].bias_scale.exp().item())
                 #scale2 = '%.2e' % (m[2].weight_scale.exp().item())
                 #scale2b = '%.2e' % (m[2].bias_scale.exp().item())
                 lr = scheduler.get_last_lr()[0]
-                logging.info(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
+                logging.info(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}") #, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
             loss.log().backward()
             optim.step()
             optim.zero_grad()
@@ -1739,9 +1749,9 @@ def _test_svd():
     assert torch.allclose(X2, X, atol=0.001)
 
 if __name__ == "__main__":
-    logging.getLogger().setLevel(logging.INFO)
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
+    logging.getLogger().setLevel(logging.INFO)
     #_test_svd()
     _test_eve_cain()
     #_test_eden()
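
A minimal sketch of what the new shape comment on "param_rms" describes, assuming _mean(x, exclude_dims=[0], keepdim=True) averages over every dimension except dim 0 (the batch dimension of the batched parameter tensor); the tensor shape and eps value below are made up for illustration:

    import torch

    # Hypothetical batched parameter tensor of shape (batch_size, *param_shape).
    p = torch.randn(4, 8, 3, 3, 5)
    eps = 1.0e-05
    # Plain-torch stand-in for _mean(p**2, exclude_dims=[0], keepdim=True).sqrt() + eps:
    # reduce over every dim except dim 0 and keep dims, so the result has one value
    # per batch element and still broadcasts against p.
    param_rms = (p ** 2).mean(dim=tuple(range(1, p.ndim)), keepdim=True).sqrt() + eps
    print(param_rms.shape)  # torch.Size([4, 1, 1, 1, 1])

The kept singleton dims are what allow param_rms to be used directly in elementwise expressions with the parameter tensor.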
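The scale_step logging added to the delta update is sampled rather than unconditional, so only about 1% of parameter updates emit a log line. A standalone sketch of that pattern with placeholder values (scale_step here is a dummy tensor, not the optimizer's real state):

    import logging
    import random

    import torch

    logging.getLogger().setLevel(logging.INFO)

    for step in range(1000):
        scale_step = torch.tensor(-0.01)  # placeholder; the real value comes from the optimizer state
        if random.random() < 0.01:  # expect roughly 10 messages out of 1000 steps
            logging.info(f"scale_step={scale_step}, step={step}")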