From 804f264ffd26885b31a16fbcc8d01ccbfb00a415 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Tue, 2 Aug 2022 06:33:31 +0800
Subject: [PATCH] Merge hidden_dim providing it as arg

---
 .../ASR/pruned_transducer_stateless7/optim.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 1464abf31..4bf53fde7 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -2159,7 +2159,7 @@ def _test_eden():
     logging.info(f"state dict = {scheduler.state_dict()}")
 
 
-def _test_eve_cain():
+def _test_eve_cain(hidden_dim):
     import timeit
     from scaling import ScaledLinear
     E = 100
@@ -2182,7 +2182,6 @@ def _test_eve_cain():
 
         fix_random_seed(42)
         Linear = torch.nn.Linear if iter == 0 else ScaledLinear
-        hidden_dim = 768
         m = torch.nn.Sequential(Linear(E, hidden_dim),
                                 torch.nn.PReLU(),
                                 Linear(hidden_dim, hidden_dim),
@@ -2272,8 +2271,14 @@ if __name__ == "__main__":
     torch.set_num_interop_threads(1)
     logging.getLogger().setLevel(logging.INFO)
     import subprocess
+    import sys
+    if len(sys.argv) > 1:
+        hidden_dim = int(sys.argv[1])
+    else:
+        hidden_dim = 200
     s = subprocess.check_output("git status -uno .; git log -1", shell=True)
+    logging.info(f"hidden_dim = {hidden_dim}")
     logging.info(s)
     #_test_svd()
-    _test_eve_cain()
+    _test_eve_cain(hidden_dim)
     #_test_eden()