mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
Merge making hidden_dim an arg
This commit is contained in:
commit
c64bd5ebcd
@ -1837,7 +1837,7 @@ def _test_eden():
|
||||
logging.info(f"state dict = {scheduler.state_dict()}")
|
||||
|
||||
|
||||
def _test_eve_cain():
|
||||
def _test_eve_cain(hidden_dim):
|
||||
import timeit
|
||||
from scaling import ScaledLinear
|
||||
E = 100
|
||||
@ -1860,7 +1860,6 @@ def _test_eve_cain():
|
||||
fix_random_seed(42)
|
||||
Linear = torch.nn.Linear if iter == 0 else ScaledLinear
|
||||
|
||||
hidden_dim = 300
|
||||
m = torch.nn.Sequential(Linear(E, hidden_dim),
|
||||
torch.nn.PReLU(),
|
||||
Linear(hidden_dim, hidden_dim),
|
||||
@ -1949,9 +1948,15 @@ if __name__ == "__main__":
|
||||
torch.set_num_interop_threads(1)
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
import subprocess
|
||||
import sys
|
||||
if len(sys.argv) > 1:
|
||||
hidden_dim = int(sys.argv[1])
|
||||
else:
|
||||
hidden_dim = 200
|
||||
s = subprocess.check_output("git status -uno .; git log -1", shell=True)
|
||||
_test_smooth_cov()
|
||||
logging.info(f"hidden_dim = {hidden_dim}")
|
||||
logging.info(s)
|
||||
#_test_svd()
|
||||
_test_eve_cain()
|
||||
_test_eve_cain(hidden_dim)
|
||||
#_test_eden()
|
||||
|
Loading…
x
Reference in New Issue
Block a user