Update backprop of sampling.py to be slightly more efficient.

Daniel Povey 2022-04-25 19:32:11 +08:00
parent bbfa484196
commit edaaec09cd


@@ -3,6 +3,7 @@
 # This was copied from /ceph-dan/torch-sampling/torch_sampling/sampling_ref.py,
 # its git history is there.
+import timeit
 import torch
 from torch import Tensor
 from torch import nn
@@ -874,28 +875,53 @@ def weighted_matrix_lookup(weights: Tensor,

 class WeightedMatrixLookupFunction(torch.autograd.Function):
-    """
-    Weighted matrix lookup, memory efficient version that redoes the computation in the
-    backward pass... this is not really optimal but the autograd for this operation is
-    complicated.
-    See weighted_matrix_lookup() for documentation.
-    """
     @staticmethod
     def forward(ctx, weights: Tensor, indexes: Tensor, knowledge_base: Tensor) -> Tensor:
+        """
+        Weighted combination of specified rows of a matrix.
+          weights: Tensor of shape (*, K), can contain any value but probably in [0..1].
+          indexes: Tensor of shape (*, K), with elements in [0..C-1]
+          knowledge_base: Tensor of shape (C, D), whose rows we'll be looking up
+        Returns:
+          tensor of shape (*, D), containing weighted sums of rows of
+          `knowledge_base`
+        """
         ctx.save_for_backward(weights.detach(), indexes.detach(),
                               knowledge_base.detach())
-        return weighted_matrix_lookup(weights, indexes, knowledge_base)
+        with torch.no_grad():
+            lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten())
+            D = knowledge_base.shape[-1]
+            weights = weights.unsqueeze(-2)             # (*, 1, K)
+            lookup = lookup.reshape(*indexes.shape, D)  # (*, K, D)
+            ans = torch.matmul(weights, lookup)         # (*, 1, D)
+            ans = ans.squeeze(-2)                       # (*, D)
+        return ans

     @staticmethod
     def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, Tensor]:
+        # ans_grad: (*, D)
         weights, indexes, knowledge_base = ctx.saved_tensors
-        weights.requires_grad = True
         knowledge_base.requires_grad = True
+        assert weights.requires_grad == False
+        D = knowledge_base.shape[-1]
         with torch.enable_grad():
-            ans = weighted_matrix_lookup(weights, indexes, knowledge_base)
-            ans.backward(gradient=ans_grad)
-        return weights.grad, None, knowledge_base.grad
+            # We'll use torch's autograd to differentiate this operation, which
+            # is nontrivial [and anyway we need `lookup` to compute the weight grad].
+            # We don't save `lookup` because it's large; that is the reason
+            # we override Torch autograd.
+            lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten())
+            lookup = lookup.reshape(*indexes.shape, D)  # (*, K, D)
+            weights = weights.unsqueeze(-1)             # (*, K, 1)
+            # forward pass was:
+            ##  ans = torch.matmul(weights, lookup)     # (*, 1, D)
+            ##  ans = ans.squeeze(-2)                   # ans, ans_grad: (*, D)
+            weights_grad = torch.matmul(lookup,                  # (*, K, D)
+                                        ans_grad.unsqueeze(-1))  # (*, D, 1)
+            weights_grad = weights_grad.squeeze(-1)  # (*, K, 1) -> (*, K)
+            lookup_grad = weights * ans_grad.unsqueeze(-2)  # (*, K, 1) * (*, 1, D) = (*, K, D)
+            lookup.backward(gradient=lookup_grad)
+        return weights_grad, None, knowledge_base.grad

 class KnowledgeBaseLookup(nn.Module):
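
To sanity-check the hand-written gradients in the new backward() above, the formulas can be compared against plain autograd applied to the same computation. The following is a small illustrative sketch, not part of the commit; the function name and the sizes B, K, C, D are invented for the example.

# Illustrative check, not from the commit: verify the manual gradient
# formulas used in WeightedMatrixLookupFunction.backward() against autograd.
import torch

def check_weighted_lookup_grads():
    B, K, C, D = 3, 4, 10, 5
    weights = torch.randn(B, K, dtype=torch.double, requires_grad=True)
    indexes = torch.randint(0, C, (B, K))
    knowledge_base = torch.randn(C, D, dtype=torch.double, requires_grad=True)
    ans_grad = torch.randn(B, D, dtype=torch.double)

    # Reference: let autograd differentiate the whole forward computation.
    lookup = knowledge_base[indexes]                    # (B, K, D)
    ans = (weights.unsqueeze(-2) @ lookup).squeeze(-2)  # (B, D)
    ans.backward(gradient=ans_grad)

    # The closed-form gradients that backward() computes by hand.
    with torch.no_grad():
        lookup_d = knowledge_base[indexes]                              # (B, K, D)
        weights_grad = (lookup_d @ ans_grad.unsqueeze(-1)).squeeze(-1)  # (B, K)
        lookup_grad = weights.unsqueeze(-1) * ans_grad.unsqueeze(-2)    # (B, K, D)
        kb_grad = torch.zeros_like(knowledge_base)
        kb_grad.index_add_(0, indexes.flatten(), lookup_grad.reshape(-1, D))

    assert torch.allclose(weights_grad, weights.grad)
    assert torch.allclose(kb_grad, knowledge_base.grad)

check_weighted_lookup_grads()
print("manual gradients match autograd")

The index_add_ line reproduces what lookup.backward(gradient=lookup_grad) does in the patch: it scatters lookup_grad back onto the rows of knowledge_base selected by indexes.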
@@ -1131,7 +1157,6 @@ def _test_sample_combined_mean():
     # weights: (B, K)
     # indexes: (B, K, N)
     weights, indexes = sample_combined_forward(p, K, True)
-
     sampled_p = torch.zeros_like(p)
     weights_expanded = weights.unsqueeze(-2).expand(*weights.shape[:-1], N, K)
     sampled_p.scatter_add_(dim=-1, index=indexes.transpose(-2, -1),
@@ -1145,13 +1170,13 @@ def _test_knowledge_base_lookup():
     N = 2
     M = 128
     D = 256
-    E = 384
+    E = 255
     knowledge_base: nn.Parameter = create_knowledge_base(M, N, D)
     m = KnowledgeBaseLookup(M, N, D, K, E, knowledge_base)
     B = 30
-    T = 4
+    T = 40
     x = torch.randn(B, T, E)
     x.requires_grad = True
     y = m(x)
@@ -1163,21 +1188,28 @@ def _test_knowledge_base_lookup():
     device = torch.device('cuda')
-    train_pairs = [ (torch.randn(B, T, E, device=device), torch.randn(B, T, E, device=device)) for _ in range(11) ]
+    train_pairs = [ (torch.randn(B, T, E, device=device), torch.randn(B, T, E, device=device)) for _ in range(10) ]
     from optim import Eve
     optimizer = Eve(m.parameters(), lr=0.005)
     m = m.to(device)

-    for epoch in range(100):
+    start = timeit.default_timer()
+    for epoch in range(120):
         for n, (x,y) in enumerate(train_pairs):
             y_out = m(x)
             loss = ((y_out - y)**2).mean()
-            if n % 10 == 0:
+            if n % 10 == 0 and epoch % 10 == 0:
                 print(f"Epoch {epoch}, batch {n}, loss {loss.item()}")
             loss.backward()
             optimizer.step()
             optimizer.zero_grad()
+    stop = timeit.default_timer()
+    print('Time taken: ', stop - start)

 if __name__ == '__main__':
     _test_sample_combined()
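
A note on the timing added to _test_knowledge_base_lookup() above: CUDA kernels launch asynchronously, so wall-clock timing of a GPU region is only fully meaningful once the device work has finished. A minimal, illustrative pattern (the helper time_region is hypothetical and not part of sampling.py):

# Illustrative only: time a callable that may launch asynchronous CUDA kernels,
# synchronizing before and after so the elapsed time covers the device work.
import timeit
import torch

def time_region(fn, *args, **kwargs):
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = timeit.default_timer()
    result = fn(*args, **kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return result, timeit.default_timer() - start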
@@ -1186,4 +1218,3 @@ if __name__ == '__main__':
     _test_compute_beta()
     _test_soft_sample()
     _test_knowledge_base_lookup()
-    #test_normalizer()