Test with CUDA, bug fixes

This commit is contained in:
Daniel Povey 2022-04-25 13:19:09 +08:00
parent f8c7e6ffb3
commit a359bfe504

View File

@@ -186,7 +186,8 @@ def compute_beta(P, K):
remainder_k = Q_part - (B_k * Kk) # shape (*, K)
large_int = (2**32 - 1)
R_part1 = torch.cat((R[...,M-K+1:M], torch.full((*R.shape[:-1], 1), large_int)), dim=-1)
R_part1 = torch.cat((R[...,M-K+1:M], torch.full((*R.shape[:-1], 1), large_int,
device=R.device)), dim=-1)
R_part2 = R[...,M-K:M]
# is_ok corresponds to: "(k==0 or R[M-k] > B_k) and R[M-1-k] <= B_k" in NOTES.md
@@ -276,7 +277,8 @@ def compute_beta_prods(Psum, Ptop):
large_int = (2**63 - 1)
# Ptop_shifted is Ptop shifted right with a large value put first, i.e.
# instead of [top1, top2, top3, top4] we have [inf, top1, top2, top3]
Ptop_shifted = torch.cat((torch.full((*Ptop.shape[:-1], 1), large_int),
Ptop_shifted = torch.cat((torch.full((*Ptop.shape[:-1], 1), large_int,
device=Ptop.device),
Ptop[...,:K-1]), dim=-1)
@@ -636,7 +638,7 @@ def sample_combined_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tens
# will not be sufficiently
# random!! We need to leave some headroom.
# rand_values are random in {0, 1, ..., B-1}
rand = torch.randint((2**63 - 1), B.shape) % B
rand = torch.randint((2**63 - 1), B.shape, device=B.device) % B
# rand, rand + B, rand + 2B, ...., rand + (K-1)B
samples = rand.unsqueeze(-1) + B.unsqueeze(-1) * torch.arange(K, device=B.device)
@@ -786,7 +788,8 @@ def soft_sample_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tensor,
P = (p*two31 + 1).to(dtype=torch.long)
B = compute_beta(P, K)
beta = B / two31
t = torch.randint(M//2, p.shape[:-1] + (1,)) # shape: *, 1
t = torch.randint(M//2, p.shape[:-1] + (1,),
device=P.device) # shape: *, 1
s = t * 2 + 1
#s = torch.ones_like(t)
@@ -801,7 +804,7 @@ def soft_sample_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tensor,
S = torch.cumsum(R, dim=-1)
# Let b be a random integer drawn uniformly from {0, 1, ..., B-1}.
b = torch.randint((2**63 - 1), B.shape) % B
b = torch.randint((2**63 - 1), B.shape, device=B.device) % B
S_prev = torch.cat((torch.zeros(*S.shape[:-1], 1, device=S.device), S[...,:-1]), dim=-1)
@@ -1156,9 +1159,11 @@ def _test_knowledge_base_lookup():
print("knowlege_base.grad norm = ", knowledge_base.grad.norm())
train_pairs = [ (torch.randn(B, T, E), torch.randn(B, T, E)) for _ in range(11) ]
device = torch.device('cuda')
train_pairs = [ (torch.randn(B, T, E, device=device), torch.randn(B, T, E, device=device)) for _ in range(11) ]
from optim import Eve
optimizer = Eve(m.parameters(), lr=0.005)
m = m.to(device)
for epoch in range(100):
for n, (x,y) in enumerate(train_pairs):