From edaaec09cd6b255583eff386b6070bb144340b39 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Mon, 25 Apr 2022 19:32:11 +0800
Subject: [PATCH] Update backprop of sampling.py to be slightly more efficient.

---
 .../ASR/pruned2_knowledge/sampling.py         | 69 ++++++++++++++-----
 1 file changed, 50 insertions(+), 19 deletions(-)

diff --git a/egs/librispeech/ASR/pruned2_knowledge/sampling.py b/egs/librispeech/ASR/pruned2_knowledge/sampling.py
index fa9502d20..ea662ff2d 100644
--- a/egs/librispeech/ASR/pruned2_knowledge/sampling.py
+++ b/egs/librispeech/ASR/pruned2_knowledge/sampling.py
@@ -3,6 +3,7 @@
 # This was copied from /ceph-dan/torch-sampling/torch_sampling/sampling_ref.py,
 # its git history is there.
 
+import timeit
 import torch
 from torch import Tensor
 from torch import nn
@@ -874,28 +875,53 @@ def weighted_matrix_lookup(weights: Tensor,
 
 class WeightedMatrixLookupFunction(torch.autograd.Function):
-    """
-    Weighted matrix lookup, memory efficient version that redoes the computation in the
-    backward pass... this is not really optimal but the autograd for this operation is
-    complicated.
-
-    See weighted_matrix_lookup() for documentation.
-    """
     @staticmethod
     def forward(ctx, weights: Tensor, indexes: Tensor,
                 knowledge_base: Tensor) -> Tensor:
+        """
+        Weighted combination of specified rows of a matrix.
+          weights: Tensor of shape (*, K); can contain any value but will probably be in [0..1].
+          indexes: Tensor of shape (*, K), with elements in [0..C-1]
+          knowledge_base: Tensor of shape (C, D), whose rows we'll be looking up
+        Returns:
+          tensor of shape (*, D), containing weighted sums of rows of
+          `knowledge_base`
+        """
         ctx.save_for_backward(weights.detach(), indexes.detach(),
                               knowledge_base.detach())
-        return weighted_matrix_lookup(weights, indexes, knowledge_base)
+        with torch.no_grad():
+            lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten())
+            D = knowledge_base.shape[-1]
+            weights = weights.unsqueeze(-2)  # (*, 1, K)
+            lookup = lookup.reshape(*indexes.shape, D)  # (*, K, D)
+            ans = torch.matmul(weights, lookup)  # ans: (*, 1, D)
+            ans = ans.squeeze(-2)  # (*, D)
+        return ans
 
     @staticmethod
     def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, Tensor]:
+        # ans_grad: (*, D)
         weights, indexes, knowledge_base = ctx.saved_tensors
-        weights.requires_grad = True
         knowledge_base.requires_grad = True
+        assert weights.requires_grad == False
+        D = knowledge_base.shape[-1]
         with torch.enable_grad():
-            ans = weighted_matrix_lookup(weights, indexes, knowledge_base)
-            ans.backward(gradient=ans_grad)
-        return weights.grad, None, knowledge_base.grad
+            # We use torch's autograd to differentiate the index_select, which
+            # is nontrivial to do by hand (and anyway we need `lookup` to
+            # compute the weights grad).  We don't save `lookup` from the
+            # forward pass because it is large; that is the reason we override
+            # torch's autograd with this Function.
+            lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten())
+            lookup = lookup.reshape(*indexes.shape, D)  # (*, K, D)
+        weights = weights.unsqueeze(-1)  # (*, K, 1)
+        # The forward pass was:
+        ##  ans = torch.matmul(weights, lookup)  # (*, 1, D)
+        ##  ans = ans.squeeze(-2)  # ans, ans_grad: (*, D)
+        weights_grad = torch.matmul(lookup,  # (*, K, D)
+                                    ans_grad.unsqueeze(-1))  # (*, D, 1)
+        weights_grad = weights_grad.squeeze(-1)  # (*, K, 1) -> (*, K)
+        lookup_grad = weights * ans_grad.unsqueeze(-2)  # (*, K, 1) * (*, 1, D) = (*, K, D)
+        lookup.backward(gradient=lookup_grad)
+        return weights_grad, None, knowledge_base.grad
 
 
 class KnowledgeBaseLookup(nn.Module):
@@ -1131,7 +1157,6 @@ def _test_sample_combined_mean():
     # weights: (B, K)
     # indexes: (B, K, N)
     weights, indexes = sample_combined_forward(p, K, True)
-
    sampled_p = torch.zeros_like(p)
     weights_expanded = weights.unsqueeze(-2).expand(*weights.shape[:-1], N, K)
     sampled_p.scatter_add_(dim=-1, index=indexes.transpose(-2, -1),
@@ -1145,13 +1170,13 @@ def _test_knowledge_base_lookup():
     N = 2
     M = 128
     D = 256
-    E = 384
+    E = 255
 
     knowledge_base: nn.Parameter = create_knowledge_base(M, N, D)
     m = KnowledgeBaseLookup(M, N, D, K, E, knowledge_base)
 
     B = 30
-    T = 4
+    T = 40
     x = torch.randn(B, T, E)
     x.requires_grad = True
     y = m(x)
@@ -1163,21 +1188,28 @@ def _test_knowledge_base_lookup():
 
     device = torch.device('cuda')
 
-    train_pairs = [ (torch.randn(B, T, E, device=device), torch.randn(B, T, E, device=device)) for _ in range(11) ]
+    train_pairs = [ (torch.randn(B, T, E, device=device), torch.randn(B, T, E, device=device)) for _ in range(10) ]
     from optim import Eve
     optimizer = Eve(m.parameters(), lr=0.005)
     m = m.to(device)
-    for epoch in range(100):
+
+    start = timeit.default_timer()
+
+    for epoch in range(120):
         for n, (x,y) in enumerate(train_pairs):
             y_out = m(x)
             loss = ((y_out - y)**2).mean()
-            if n % 10 == 0:
+            if n % 10 == 0 and epoch % 10 == 0:
                 print(f"Epoch {epoch}, batch {n}, loss {loss.item()}")
             loss.backward()
             optimizer.step()
             optimizer.zero_grad()
 
+    stop = timeit.default_timer()
+    print('Time taken: ', stop - start)
+
+
 
 if __name__ == '__main__':
     _test_sample_combined()
@@ -1186,4 +1218,3 @@ if __name__ == '__main__':
     _test_compute_beta()
     _test_soft_sample()
     _test_knowledge_base_lookup()
-    #test_normalizer()
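Note (not part of the patch): the hand-written gradients in backward() follow from
ans[.., d] = sum_k weights[.., k] * lookup[.., k, d].  A small self-contained check of the
two matmul/broadcast forms used above, with einsum as the ground truth:

    import torch

    B, K, D = 3, 5, 7
    weights = torch.randn(B, K, dtype=torch.double)
    lookup = torch.randn(B, K, D, dtype=torch.double)
    ans_grad = torch.randn(B, D, dtype=torch.double)

    # d loss / d weights[b, k] = sum_d lookup[b, k, d] * ans_grad[b, d]
    weights_grad = torch.einsum('bkd,bd->bk', lookup, ans_grad)
    # d loss / d lookup[b, k, d] = weights[b, k] * ans_grad[b, d]
    lookup_grad = torch.einsum('bk,bd->bkd', weights, ans_grad)

    # These equal the forms used in the patch:
    assert torch.allclose(weights_grad,
                          torch.matmul(lookup, ans_grad.unsqueeze(-1)).squeeze(-1))
    assert torch.allclose(lookup_grad,
                          weights.unsqueeze(-1) * ans_grad.unsqueeze(-2))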
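And a sketch of an end-to-end check of the custom Function against plain autograd.  It
assumes the patched sampling.py is importable from the current directory;
weighted_matrix_lookup_ref is a hypothetical local stand-in for the weighted_matrix_lookup()
that the old code called:

    import torch
    from torch import Tensor
    from sampling import WeightedMatrixLookupFunction  # hypothetical import path

    def weighted_matrix_lookup_ref(weights: Tensor, indexes: Tensor,
                                   knowledge_base: Tensor) -> Tensor:
        # Plain-autograd reference: same computation as the Function's forward.
        lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten())
        lookup = lookup.reshape(*indexes.shape, knowledge_base.shape[-1])  # (*, K, D)
        return torch.matmul(weights.unsqueeze(-2), lookup).squeeze(-2)  # (*, D)

    B, K, C, D = 4, 8, 32, 16
    weights = torch.randn(B, K, dtype=torch.double, requires_grad=True)
    indexes = torch.randint(0, C, (B, K))
    knowledge_base = torch.randn(C, D, dtype=torch.double, requires_grad=True)
    ans_grad = torch.randn(B, D, dtype=torch.double)

    # Backward through the custom Function...
    WeightedMatrixLookupFunction.apply(weights, indexes, knowledge_base).backward(ans_grad)

    # ...and through the reference implementation, from fresh leaves.
    w_ref = weights.detach().clone().requires_grad_(True)
    kb_ref = knowledge_base.detach().clone().requires_grad_(True)
    weighted_matrix_lookup_ref(w_ref, indexes, kb_ref).backward(ans_grad)

    assert torch.allclose(weights.grad, w_ref.grad)
    assert torch.allclose(knowledge_base.grad, kb_ref.grad)
    print('custom backward matches autograd')

The old backward re-ran all of weighted_matrix_lookup() under autograd; the new one builds
a graph only for the index_select and differentiates the matmul by hand, which is where the
"slightly more efficient" of the commit subject comes from.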