k2-fsa/icefall (mirror of https://github.com/k2-fsa/icefall.git)

commit a359bfe504 (parent f8c7e6ffb3)
Test with CUDA, bug fixes
@@ -186,7 +186,8 @@ def compute_beta(P, K):
     remainder_k = Q_part - (B_k * Kk) # shape (*, K)
 
     large_int = (2**32 - 1)
-    R_part1 = torch.cat((R[...,M-K+1:M], torch.full((*R.shape[:-1], 1), large_int)), dim=-1)
+    R_part1 = torch.cat((R[...,M-K+1:M], torch.full((*R.shape[:-1], 1), large_int,
+                                                    device=R.device)), dim=-1)
     R_part2 = R[...,M-K:M]
 
     # is_ok corresponds to: "(k==0 or R[M-k] > B_k) and R[M-1-k] <= B_k" in NOTES.md
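This hunk and the next fix the same class of bug: torch.full() allocates on the CPU unless device= is passed, so concatenating its output with a CUDA-resident tensor fails. A minimal standalone sketch of the failure and the fix (not part of the commit; names and shapes are illustrative):

import torch

if torch.cuda.is_available():
    R = torch.randint(0, 2**31, (4, 8), device='cuda')
    filler = torch.full((4, 1), 2**32 - 1)                   # CPU by default
    # torch.cat((R, filler), dim=-1)                         # RuntimeError: tensors on different devices
    filler = torch.full((4, 1), 2**32 - 1, device=R.device)  # allocate alongside R
    out = torch.cat((R, filler), dim=-1)                     # OK, shape (4, 9)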
@@ -276,7 +277,8 @@ def compute_beta_prods(Psum, Ptop):
     large_int = (2**63 - 1)
     # Ptop_shifted is Ptop shifted right with a large value put first, i.e.
     # instead of [top1, top2, top3, top4] we have [inf, top1, top2, top3]
-    Ptop_shifted = torch.cat((torch.full((*Ptop.shape[:-1], 1), large_int),
+    Ptop_shifted = torch.cat((torch.full((*Ptop.shape[:-1], 1), large_int,
+                                         device=Ptop.device),
                               Ptop[...,:K-1]), dim=-1)
 
 
@@ -636,7 +638,7 @@ def sample_combined_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tens
     # will not be sufficiently
     # random!! We need to leave some headroom.
     # rand_values are random in {0, 1, ..., B-1}
-    rand = torch.randint((2**63 - 1), B.shape) % B
+    rand = torch.randint((2**63 - 1), B.shape, device=B.device) % B
     # rand, rand + B, rand + 2B, ...., rand + (K-1)B
     samples = rand.unsqueeze(-1) + B.unsqueeze(-1) * torch.arange(K, device=B.device)
 
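The two randint hunks (here and at line 801 below) preserve the headroom comment's logic: draw from the full int64 range, then reduce mod B. A rough sketch of why that works, under my reading of the comment (this is not text from the commit):

import torch

# Drawing r uniformly from {0, ..., 2**63 - 2} and taking r % B is not exactly
# uniform over {0, ..., B-1}, but the bias is on the order of B / 2**63 per
# value, negligible while B stays far below 2**63 -- that is the "headroom".
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
B = torch.tensor([1000, 2**31], device=device)
rand = torch.randint(2**63 - 1, B.shape, device=B.device) % B
assert bool((rand < B).all())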
@@ -786,7 +788,8 @@ def soft_sample_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tensor,
     P = (p*two31 + 1).to(dtype=torch.long)
     B = compute_beta(P, K)
     beta = B / two31
-    t = torch.randint(M//2, p.shape[:-1] + (1,)) # shape: *, 1
+    t = torch.randint(M//2, p.shape[:-1] + (1,),
+                      device=P.device) # shape: *, 1
     s = t * 2 + 1
     #s = torch.ones_like(t)
 
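The diff does not explain why t is doubled into an odd s; my reading (an assumption, not the commit's words) is that an odd s is coprime with a power-of-two M, so multiplying indices by s permutes them pseudo-randomly. A small self-contained check of that number-theoretic fact:

import torch

M = 8                               # stands in for the function's M; assumed a power of two
t = torch.randint(M // 2, (1,))     # CPU here; the commit additionally pins device=P.device
s = t * 2 + 1                       # random odd number in {1, 3, ..., M-1}
i = torch.arange(M)
perm = (i * s) % M                  # odd s is coprime with 2**k, so this is a bijection
assert perm.sort().values.equal(i)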
@@ -801,7 +804,7 @@ def soft_sample_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tensor,
     S = torch.cumsum(R, dim=-1)
 
     # Let b be a random integer drawn uniformly from {0, 1, ..., B-1}.
-    b = torch.randint((2**63 - 1), B.shape) % B
+    b = torch.randint((2**63 - 1), B.shape, device=B.device) % B
 
     S_prev = torch.cat((torch.zeros(*S.shape[:-1], 1, device=S.device), S[...,:-1]), dim=-1)
 
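For context on the surrounding lines, a sketch of the S/S_prev pattern (my reading, not the commit's code): S is the inclusive cumulative sum and S_prev the exclusive one, so [S_prev[i], S[i]) is the interval owned by element i in interval-based sampling. The sketch pins dtype as well as device, which the original S_prev line does only for device:

import torch

R = torch.tensor([[3, 1, 4, 2]])
S = torch.cumsum(R, dim=-1)                       # inclusive sums: [[3, 4, 8, 10]]
zeros = torch.zeros(*S.shape[:-1], 1, dtype=S.dtype, device=S.device)
S_prev = torch.cat((zeros, S[..., :-1]), dim=-1)  # exclusive sums: [[0, 3, 4, 8]]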
@@ -1156,9 +1159,11 @@ def _test_knowledge_base_lookup():
     print("knowlege_base.grad norm = ", knowledge_base.grad.norm())
 
 
-    train_pairs = [ (torch.randn(B, T, E), torch.randn(B, T, E)) for _ in range(11) ]
+    device = torch.device('cuda')
+    train_pairs = [ (torch.randn(B, T, E, device=device), torch.randn(B, T, E, device=device)) for _ in range(11) ]
     from optim import Eve
     optimizer = Eve(m.parameters(), lr=0.005)
+    m = m.to(device)
 
     for epoch in range(100):
         for n, (x,y) in enumerate(train_pairs):
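The test now hardcodes torch.device('cuda'), which will fail on CPU-only machines. A hedged variant one could use instead (illustrative sketch; the B, T, E values here are placeholders, the real ones are defined earlier in the test):

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
B, T, E = 2, 3, 4                   # placeholder sizes for the sketch
train_pairs = [(torch.randn(B, T, E, device=device),
                torch.randn(B, T, E, device=device)) for _ in range(11)]

Note that moving the model after constructing the optimizer (m = m.to(device) below the Eve(...) line) still works, because nn.Module.to() updates parameters in place: the Parameter objects the optimizer holds are reused, only their data moves.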