Fix devices

Daniel Povey 2022-04-24 22:48:52 +08:00
parent a266922678
commit df39fc6783

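The recurring change below is the same one-line fix: index tensors built with torch.arange() (and the zeros used for the shifted cumsum) were created on the default CPU device, which raises a device-mismatch RuntimeError ("Expected all tensors to be on the same device") once the surrounding tensors live on a GPU; passing device= pins them to the device of the tensors they are combined with. A minimal sketch of the failure mode and the fix (the function and variable names here are illustrative, not taken from the file):

    import torch

    def scaled_arange(x: torch.Tensor, K: int) -> torch.Tensor:
        # torch.arange(K) alone would default to the CPU and fail to combine
        # with a CUDA tensor x; device=x.device keeps everything on one device.
        steps = torch.arange(K, device=x.device, dtype=x.dtype)
        return x.unsqueeze(-1) * steps

    if torch.cuda.is_available():
        x = torch.randn(4, device="cuda")
        print(scaled_arange(x, 3).device)  # cuda:0
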
@@ -259,7 +259,7 @@ def compute_beta_prods(Psum, Ptop):
# Shape is (*, K)
S1 = Psum.unsqueeze(-1) - Ptop_cum_shift
-temp = torch.arange(K, -1, -1) # [K, K-1, ..., 0]
+temp = torch.arange(K, -1, -1, device=Psum.device) # [K, K-1, ..., 0]
# Kk, of shape (K,), contains [K, K-1, ..., 1], representing K-k for k = [0, 1, ..., K-1]
Kk = temp[0:K]
# Kk1 of shape (K,), contains [K-1, K-2, ..., 0], representing K-k-1 for k = [0, 1, ..., K-1]
@@ -637,7 +637,7 @@ def sample_combined_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tens
# rand_values are random in {0, 1, ..., B-1}
rand = torch.randint((2**63 - 1), B.shape) % B
# rand, rand + B, rand + 2B, ...., rand + (K-1)B
-samples = rand.unsqueeze(-1) + B.unsqueeze(-1) * torch.arange(K)
+samples = rand.unsqueeze(-1) + B.unsqueeze(-1) * torch.arange(K, device=B.device)
shifted_samples = compute_shifted_samples(combined_cumsums_mod,
delta_P,
@@ -794,7 +794,7 @@ def soft_sample_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tensor,
assert torch.all((s * inv_s) % M == 1) # if this fails, check that M is a power of 2
# R = pseudo-random re-ordering of p.
-R = torch.minimum(torch.gather(P, dim=-1, index=(s * torch.arange(M)) % M),
+R = torch.minimum(torch.gather(P, dim=-1, index=(s * torch.arange(M, device=P.device)) % M),
B)
# S = inclusive-sum of R
S = torch.cumsum(R, dim=-1)
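
The gather index (s * torch.arange(M)) % M above is a permutation of {0, ..., M-1} whenever s is invertible modulo M, which is exactly what the assert on s * inv_s checks (any odd s works when M is a power of 2). A small standalone check of that property, not code from the file:

    import torch

    M = 8   # a power of 2, as the assert expects
    s = 5   # odd, hence coprime to M and invertible mod M
    idx = (s * torch.arange(M)) % M
    print(idx)  # tensor([0, 5, 2, 7, 4, 1, 6, 3])
    # Every index appears exactly once, so gathering with idx only re-orders p.
    assert torch.equal(idx.sort().values, torch.arange(M))
    p = torch.rand(M)
    assert torch.equal(p[idx][torch.argsort(idx)], p)  # un-permuting recovers p
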
@@ -802,7 +802,7 @@ def soft_sample_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tensor,
# Let b be a random integer drawn uniformly from {0, 1, ..., B-1}.
b = torch.randint((2**63 - 1), B.shape) % B
-S_prev = torch.cat((torch.zeros(*S.shape[:-1], 1), S[...,:-1]), dim=-1)
+S_prev = torch.cat((torch.zeros(*S.shape[:-1], 1, device=S.device), S[...,:-1]), dim=-1)
k_prev = (S_prev + b) // B
k_cur = (S + b) // B
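
S_prev above is the exclusive counterpart of the inclusive cumsum S (a zero column prepended, the last column dropped), so (S + b) // B - (S_prev + b) // B counts how many multiples of B each interval (S_prev + b, S + b] crosses. A tiny standalone illustration with made-up values:

    import torch

    R = torch.tensor([[3, 1, 4, 1]])
    S = torch.cumsum(R, dim=-1)  # [[3, 4, 8, 9]]
    S_prev = torch.cat((torch.zeros(*S.shape[:-1], 1, dtype=S.dtype, device=S.device),
                        S[..., :-1]), dim=-1)  # [[0, 3, 4, 8]]
    assert torch.equal(S - S_prev, R)  # exclusive and inclusive cumsum differ by R
    B, b = torch.tensor(4), torch.tensor(2)
    print((S + b) // B - (S_prev + b) // B)  # [[1, 0, 1, 0]]
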
@@ -1060,7 +1060,7 @@ def _test_combined():
# rand_values are random in {0, 1, ..., B-1}
rand = torch.randint((2**63 - 1), B.shape) % B
# rand, rand + B, rand + 2B, ...., rand + (K-1)B
-samples = rand.unsqueeze(-1) + B.unsqueeze(-1) * torch.arange(K)
+samples = rand.unsqueeze(-1) + B.unsqueeze(-1) * torch.arange(K, device=B.device)
print("rand = ", rand)
print("sampled = ", samples)
@@ -1150,6 +1150,10 @@ def _test_knowledge_base_lookup():
print("x.grad = ", x.grad)
print("knowlege_base.grad norm = ", knowledge_base.grad.norm())
#train_pairs = [ (torch.randn(B, T, E), torch.randn(B, T, E)) for _ in range(100) ]
if __name__ == '__main__':
_test_sample_combined()
_test_sample_combined_mean()