diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index 494894545..823005ab9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -1837,7 +1837,7 @@ def _test_eden(): logging.info(f"state dict = {scheduler.state_dict()}") -def _test_eve_cain(): +def _test_eve_cain(hidden_dim): import timeit from scaling import ScaledLinear E = 100 @@ -1860,7 +1860,6 @@ def _test_eve_cain(): fix_random_seed(42) Linear = torch.nn.Linear if iter == 0 else ScaledLinear - hidden_dim = 300 m = torch.nn.Sequential(Linear(E, hidden_dim), torch.nn.PReLU(), Linear(hidden_dim, hidden_dim), @@ -1949,9 +1948,15 @@ if __name__ == "__main__": torch.set_num_interop_threads(1) logging.getLogger().setLevel(logging.INFO) import subprocess + import sys + if len(sys.argv) > 1: + hidden_dim = int(sys.argv[1]) + else: + hidden_dim = 200 s = subprocess.check_output("git status -uno .; git log -1", shell=True) _test_smooth_cov() + logging.info(f"hidden_dim = {hidden_dim}") logging.info(s) #_test_svd() - _test_eve_cain() + _test_eve_cain(hidden_dim) #_test_eden()