From a359bfe5047d5f7abde79d853d4f960f57fd7154 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Mon, 25 Apr 2022 13:19:09 +0800
Subject: [PATCH] Test with CUDA, bug fixes

---
 .../ASR/pruned2_knowledge/sampling.py | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)

diff --git a/egs/librispeech/ASR/pruned2_knowledge/sampling.py b/egs/librispeech/ASR/pruned2_knowledge/sampling.py
index e68ea153b..b6aec23d7 100644
--- a/egs/librispeech/ASR/pruned2_knowledge/sampling.py
+++ b/egs/librispeech/ASR/pruned2_knowledge/sampling.py
@@ -186,7 +186,8 @@ def compute_beta(P, K):
     remainder_k = Q_part - (B_k * Kk)  # shape (*, K)
 
     large_int = (2**32 - 1)
-    R_part1 = torch.cat((R[...,M-K+1:M], torch.full((*R.shape[:-1], 1), large_int)), dim=-1)
+    R_part1 = torch.cat((R[...,M-K+1:M], torch.full((*R.shape[:-1], 1), large_int,
+                                                    device=R.device)), dim=-1)
     R_part2 = R[...,M-K:M]
 
     # is_ok corresponds to: "(k==0 or R[M-k] > B_k) and R[M-1-k] <= B_k" in NOTES.md
@@ -276,7 +277,8 @@ def compute_beta_prods(Psum, Ptop):
     large_int = (2**63 - 1)
     # Ptop_shifted is Ptop shifted right with a large value put first, i.e.
     # instead of [top1, top2, top3, top4] we have [inf, top1, top2, top3]
-    Ptop_shifted = torch.cat((torch.full((*Ptop.shape[:-1], 1), large_int),
+    Ptop_shifted = torch.cat((torch.full((*Ptop.shape[:-1], 1), large_int,
+                                         device=Ptop.device),
                               Ptop[...,:K-1]),
                              dim=-1)
 
@@ -636,7 +638,7 @@ def sample_combined_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tens
     # will not be sufficiently
     # random!! We need to leave some headroom.
     # rand_values are random in {0, 1, ..., B-1}
-    rand = torch.randint((2**63 - 1), B.shape) % B
+    rand = torch.randint((2**63 - 1), B.shape, device=B.device) % B
     # rand, rand + B, rand + 2B, ...., rand + (K-1)B
     samples = rand.unsqueeze(-1) + B.unsqueeze(-1) * torch.arange(K, device=B.device)
 
@@ -786,7 +788,8 @@ def soft_sample_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tensor,
     P = (p*two31 + 1).to(dtype=torch.long)
     B = compute_beta(P, K)
     beta = B / two31
-    t = torch.randint(M//2, p.shape[:-1] + (1,))  # shape: *, 1
+    t = torch.randint(M//2, p.shape[:-1] + (1,),
+                      device=P.device)  # shape: *, 1
     s = t * 2 + 1
     #s = torch.ones_like(t)
 
@@ -801,7 +804,7 @@ def soft_sample_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tensor,
     S = torch.cumsum(R, dim=-1)
 
     # Let b be a random integer drawn uniformly from {0, 1, ..., B-1}.
-    b = torch.randint((2**63 - 1), B.shape) % B
+    b = torch.randint((2**63 - 1), B.shape, device=B.device) % B
 
     S_prev = torch.cat((torch.zeros(*S.shape[:-1], 1, device=S.device), S[...,:-1]), dim=-1)
 
@@ -1156,9 +1159,11 @@ def _test_knowledge_base_lookup():
         print("knowlege_base.grad norm = ", knowledge_base.grad.norm())
 
 
-    train_pairs = [ (torch.randn(B, T, E), torch.randn(B, T, E)) for _ in range(11) ]
+    device = torch.device('cuda')
+    train_pairs = [ (torch.randn(B, T, E, device=device), torch.randn(B, T, E, device=device)) for _ in range(11) ]
     from optim import Eve
     optimizer = Eve(m.parameters(), lr=0.005)
+    m = m.to(device)
 
     for epoch in range(100):
         for n, (x,y) in enumerate(train_pairs):
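
Note (illustration, not part of the patch): every hunk above fixes the same class of bug: torch.full(), torch.randint() and torch.randn() allocate on the CPU (the default device) unless device= is passed, so their output cannot be combined with tensors that already live on the GPU. Below is a minimal standalone sketch of the failure mode and the fix; the tensor names and shapes are made up for illustration and do not come from sampling.py.

    import torch

    if torch.cuda.is_available():
        device = torch.device("cuda")
        R = torch.randint(0, 2**31, (4, 8), device=device)  # int64 data already on the GPU

        # Without device=R.device, torch.full() would return a CPU tensor and
        # torch.cat() below would raise a device-mismatch RuntimeError.
        pad = torch.full((*R.shape[:-1], 1), 2**32 - 1, device=R.device)
        R_padded = torch.cat((R, pad), dim=-1)

        # Same pattern for random integers: keep them on the same device as B.
        B = R_padded.sum(dim=-1, keepdim=True)
        rand = torch.randint(2**63 - 1, B.shape, device=B.device) % B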