diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index 945e3ab19..cfbaab161 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -2180,7 +2180,7 @@ def _test_eden(): logging.info(f"state dict = {scheduler.state_dict()}") -def _test_eve_cain(): +def _test_eve_cain(hidden_dim): import timeit from scaling import ScaledLinear E = 100 @@ -2203,7 +2203,6 @@ def _test_eve_cain(): fix_random_seed(42) Linear = torch.nn.Linear if iter == 0 else ScaledLinear - hidden_dim = 400 m = torch.nn.Sequential(Linear(E, hidden_dim), torch.nn.PReLU(), Linear(hidden_dim, hidden_dim), @@ -2293,8 +2292,14 @@ if __name__ == "__main__": torch.set_num_interop_threads(1) logging.getLogger().setLevel(logging.INFO) import subprocess + import sys + if len(sys.argv) > 1: + hidden_dim = int(sys.argv[1]) + else: + hidden_dim = 200 s = subprocess.check_output("git status -uno .; git log -1", shell=True) + logging.info(f"hidden_dim = {hidden_dim}") logging.info(s) #_test_svd() - _test_eve_cain() + _test_eve_cain(hidden_dim) #_test_eden()