mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Merge hidden_dim providing it as arg
This commit is contained in:
parent
3c1fddaf48
commit
804f264ffd
@ -2159,7 +2159,7 @@ def _test_eden():
|
|||||||
logging.info(f"state dict = {scheduler.state_dict()}")
|
logging.info(f"state dict = {scheduler.state_dict()}")
|
||||||
|
|
||||||
|
|
||||||
def _test_eve_cain():
|
def _test_eve_cain(hidden_dim):
|
||||||
import timeit
|
import timeit
|
||||||
from scaling import ScaledLinear
|
from scaling import ScaledLinear
|
||||||
E = 100
|
E = 100
|
||||||
@ -2182,7 +2182,6 @@ def _test_eve_cain():
|
|||||||
fix_random_seed(42)
|
fix_random_seed(42)
|
||||||
Linear = torch.nn.Linear if iter == 0 else ScaledLinear
|
Linear = torch.nn.Linear if iter == 0 else ScaledLinear
|
||||||
|
|
||||||
hidden_dim = 768
|
|
||||||
m = torch.nn.Sequential(Linear(E, hidden_dim),
|
m = torch.nn.Sequential(Linear(E, hidden_dim),
|
||||||
torch.nn.PReLU(),
|
torch.nn.PReLU(),
|
||||||
Linear(hidden_dim, hidden_dim),
|
Linear(hidden_dim, hidden_dim),
|
||||||
@ -2272,8 +2271,14 @@ if __name__ == "__main__":
|
|||||||
torch.set_num_interop_threads(1)
|
torch.set_num_interop_threads(1)
|
||||||
logging.getLogger().setLevel(logging.INFO)
|
logging.getLogger().setLevel(logging.INFO)
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import sys
|
||||||
|
if len(sys.argv) > 1:
|
||||||
|
hidden_dim = int(sys.argv[1])
|
||||||
|
else:
|
||||||
|
hidden_dim = 200
|
||||||
s = subprocess.check_output("git status -uno .; git log -1", shell=True)
|
s = subprocess.check_output("git status -uno .; git log -1", shell=True)
|
||||||
|
logging.info(f"hidden_dim = {hidden_dim}")
|
||||||
logging.info(s)
|
logging.info(s)
|
||||||
#_test_svd()
|
#_test_svd()
|
||||||
_test_eve_cain()
|
_test_eve_cain(hidden_dim)
|
||||||
#_test_eden()
|
#_test_eden()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user