Merge making hidden_dim an arg

This commit is contained in:
Daniel Povey 2022-08-02 09:07:36 +08:00
commit c64bd5ebcd

View File

@ -1837,7 +1837,7 @@ def _test_eden():
logging.info(f"state dict = {scheduler.state_dict()}") logging.info(f"state dict = {scheduler.state_dict()}")
def _test_eve_cain(): def _test_eve_cain(hidden_dim):
import timeit import timeit
from scaling import ScaledLinear from scaling import ScaledLinear
E = 100 E = 100
@ -1860,7 +1860,6 @@ def _test_eve_cain():
fix_random_seed(42) fix_random_seed(42)
Linear = torch.nn.Linear if iter == 0 else ScaledLinear Linear = torch.nn.Linear if iter == 0 else ScaledLinear
hidden_dim = 300
m = torch.nn.Sequential(Linear(E, hidden_dim), m = torch.nn.Sequential(Linear(E, hidden_dim),
torch.nn.PReLU(), torch.nn.PReLU(),
Linear(hidden_dim, hidden_dim), Linear(hidden_dim, hidden_dim),
@ -1949,9 +1948,15 @@ if __name__ == "__main__":
torch.set_num_interop_threads(1) torch.set_num_interop_threads(1)
logging.getLogger().setLevel(logging.INFO) logging.getLogger().setLevel(logging.INFO)
import subprocess import subprocess
import sys
if len(sys.argv) > 1:
hidden_dim = int(sys.argv[1])
else:
hidden_dim = 200
s = subprocess.check_output("git status -uno .; git log -1", shell=True) s = subprocess.check_output("git status -uno .; git log -1", shell=True)
_test_smooth_cov() _test_smooth_cov()
logging.info(f"hidden_dim = {hidden_dim}")
logging.info(s) logging.info(s)
#_test_svd() #_test_svd()
_test_eve_cain() _test_eve_cain(hidden_dim)
#_test_eden() #_test_eden()