From df39fc6783b7b62543da3d54826af8b3a10b2646 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sun, 24 Apr 2022 22:48:52 +0800
Subject: [PATCH] Fix devices

---
 egs/librispeech/ASR/pruned2_knowledge/sampling.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/egs/librispeech/ASR/pruned2_knowledge/sampling.py b/egs/librispeech/ASR/pruned2_knowledge/sampling.py
index 85df5a89a..bfa3d0768 100644
--- a/egs/librispeech/ASR/pruned2_knowledge/sampling.py
+++ b/egs/librispeech/ASR/pruned2_knowledge/sampling.py
@@ -259,7 +259,7 @@ def compute_beta_prods(Psum, Ptop):
 
     # Shape is (*, K)
     S1 = Psum.unsqueeze(-1) - Ptop_cum_shift
-    temp = torch.arange(K, -1, -1)  # [K, K-1, ..., 0]
+    temp = torch.arange(K, -1, -1, device=Psum.device)  # [K, K-1, ..., 0]
     # Kk, of shape (K,), contains [K, K-1, ..., 1], representing K-k for k = [0, 1, ..., K-1]
     Kk = temp[0:K]
     # Kk1 of shape (K,), contains [K-1, K-2, ..., 0], representing K-k-1 for k = [0, 1, ..., K-1]
@@ -637,7 +637,7 @@ def sample_combined_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tens
     # rand_values are random in {0, 1, ..., B-1}
     rand = torch.randint((2**63 - 1), B.shape) % B
     # rand, rand + B, rand + 2B, ...., rand + (K-1)B
-    samples = rand.unsqueeze(-1) + B.unsqueeze(-1) * torch.arange(K)
+    samples = rand.unsqueeze(-1) + B.unsqueeze(-1) * torch.arange(K, device=B.device)
 
     shifted_samples = compute_shifted_samples(combined_cumsums_mod,
                                               delta_P,
@@ -794,7 +794,7 @@ def soft_sample_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tensor,
     assert torch.all((s * inv_s) % M == 1)  # if this fails, check that M is a power of 2
 
     # R = pseudo-random re-ordering of p.
-    R = torch.minimum(torch.gather(P, dim=-1, index=(s * torch.arange(M)) % M),
+    R = torch.minimum(torch.gather(P, dim=-1, index=(s * torch.arange(M, device=P.device)) % M),
                       B)
     # S = inclusive-sum of R
     S = torch.cumsum(R, dim=-1)
@@ -802,7 +802,7 @@
     # Let b be a random integer drawn uniformly from {0, 1, ..., B-1}.
     b = torch.randint((2**63 - 1), B.shape) % B
 
-    S_prev = torch.cat((torch.zeros(*S.shape[:-1], 1), S[...,:-1]), dim=-1)
+    S_prev = torch.cat((torch.zeros(*S.shape[:-1], 1, device=S.device), S[...,:-1]), dim=-1)
 
     k_prev = (S_prev + b) // B
     k_cur = (S + b) // B
@@ -1060,7 +1060,7 @@ def _test_combined():
     # rand_values are random in {0, 1, ..., B-1}
     rand = torch.randint((2**63 - 1), B.shape) % B
     # rand, rand + B, rand + 2B, ...., rand + (K-1)B
-    samples = rand.unsqueeze(-1) + B.unsqueeze(-1) * torch.arange(K)
+    samples = rand.unsqueeze(-1) + B.unsqueeze(-1) * torch.arange(K, device=B.device)
 
     print("rand = ", rand)
     print("sampled = ", samples)
@@ -1150,6 +1150,10 @@ def _test_knowledge_base_lookup():
     print("x.grad = ", x.grad)
     print("knowlege_base.grad norm = ", knowledge_base.grad.norm())
 
+
+    #train_pairs = [ (torch.randn(B, T, E), torch.randn(B, T, E)) for _ in range(100) ]
+
+
 if __name__ == '__main__':
     _test_sample_combined()
     _test_sample_combined_mean()
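
Note (editorial sketch, not part of the patch): PyTorch factory functions such as torch.arange(), torch.zeros() and torch.randint() allocate on the CPU unless a device argument is given, so index tensors built from them fail with a device-mismatch error once the surrounding tensors live on a GPU. The snippet below is a minimal, hypothetical illustration of the pattern the patch applies; the names B, K and rand mirror sample_combined_forward(), but the shapes and values are made up.

import torch

# Hypothetical setup: 8 rows, K candidate samples per row, B holds per-row moduli.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
K = 4
B = torch.randint(1, 100, (8,), device=device)
rand = torch.randint(2**63 - 1, B.shape, device=device) % B

# Without device=..., torch.arange(K) lives on the CPU; combining it with the
# (possibly CUDA) tensors above raises "Expected all tensors to be on the same
# device".  Creating it on B.device, as the patch does, avoids that.
samples = rand.unsqueeze(-1) + B.unsqueeze(-1) * torch.arange(K, device=B.device)
print(samples.shape)  # torch.Size([8, 4])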