Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00
Update backprop of sampling.py to be slightly more efficient.

parent bbfa484196
commit edaaec09cd
@@ -3,6 +3,7 @@
# This was copied from /ceph-dan/torch-sampling/torch_sampling/sampling_ref.py,
# its git history is there.

import timeit
import torch
from torch import Tensor
from torch import nn
@@ -874,28 +875,53 @@ def weighted_matrix_lookup(weights: Tensor,


class WeightedMatrixLookupFunction(torch.autograd.Function):
    """
    Weighted matrix lookup, memory-efficient version that redoes the computation in the
    backward pass... this is not really optimal but the autograd for this operation is
    complicated.

    See weighted_matrix_lookup() for documentation.
    """
    @staticmethod
    def forward(ctx, weights: Tensor, indexes: Tensor, knowledge_base: Tensor) -> Tensor:
        """
        Weighted combination of specified rows of a matrix.
        weights: Tensor of shape (*, K), can contain any value but probably in [0..1].
        indexes: Tensor of shape (*, K), with elements in [0..C-1]
        knowledge_base: Tensor of shape (C, D), whose rows we'll be looking up
        Returns:
           tensor of shape (*, D), containing weighted sums of rows of
           `knowledge_base`
        """
        ctx.save_for_backward(weights.detach(), indexes.detach(),
                              knowledge_base.detach())
        return weighted_matrix_lookup(weights, indexes, knowledge_base)
        with torch.no_grad():
            lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten())
            D = knowledge_base.shape[-1]
            weights = weights.unsqueeze(-2)  # (*, 1, K)
            lookup = lookup.reshape(*indexes.shape, D)  # (*, K, D)
            ans = torch.matmul(weights, lookup)  # ans: (*, 1, D)
            ans = ans.squeeze(-2)  # (*, D)
        return ans

    @staticmethod
    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, Tensor]:
        # ans_grad: (*, D)
        weights, indexes, knowledge_base = ctx.saved_tensors
        weights.requires_grad = True
        knowledge_base.requires_grad = True
        assert weights.requires_grad == False
        D = knowledge_base.shape[-1]
        with torch.enable_grad():
            ans = weighted_matrix_lookup(weights, indexes, knowledge_base)
            ans.backward(gradient=ans_grad)
        return weights.grad, None, knowledge_base.grad
            # we'll use torch's autograd to differentiate this operation, which
            # is nontrivial [and anyway we need `lookup` to compute the weight grad].
            # We don't save `lookup` because it's large; that is the reason
            # we override Torch autograd.
            lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten())
            lookup = lookup.reshape(*indexes.shape, D)  # (*, K, D)
        weights = weights.unsqueeze(-1)  # (*, K, 1)
        # forward pass was:
        ## ans = torch.matmul(weights, lookup)
        ## ans: (*, 1, D)
        ## ans = ans.squeeze(-2)  # ans, ans_grad: (*, D)
        weights_grad = torch.matmul(lookup,  # (*, K, D)
                                    ans_grad.unsqueeze(-1))  # (*, D, 1)
        weights_grad = weights_grad.squeeze(-1)  # (*, K, 1) -> (*, K)
        lookup_grad = weights * ans_grad.unsqueeze(-2)  # (*, K, 1) * (*, 1, D) = (*, K, D)
        lookup.backward(gradient=lookup_grad)
        return weights_grad, None, knowledge_base.grad


class KnowledgeBaseLookup(nn.Module):
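For reference, here is a small standalone sketch (not part of this commit) that checks the gradient formulas the rewritten backward() relies on against plain autograd applied to a reference forward pass; the shapes B, K, C, D below are illustrative only.

import torch

B, K, C, D = 3, 4, 10, 5
weights = torch.randn(B, K, dtype=torch.double, requires_grad=True)
indexes = torch.randint(0, C, (B, K))
knowledge_base = torch.randn(C, D, dtype=torch.double, requires_grad=True)

# Reference forward: ans[b] = sum_k weights[b, k] * knowledge_base[indexes[b, k]]
ans = (weights.unsqueeze(-1) * knowledge_base[indexes]).sum(dim=-2)  # (B, D)
ans_grad = torch.randn_like(ans)
ans.backward(gradient=ans_grad)

# Manual gradients, mirroring what the updated backward() computes:
#   weights_grad[b, k]      = lookup[b, k] . ans_grad[b]
#   knowledge_base.grad[r] += weights[b, k] * ans_grad[b] summed over indexes[b, k] == r
with torch.no_grad():
    lookup = knowledge_base[indexes]                                         # (B, K, D)
    weights_grad = torch.matmul(lookup, ans_grad.unsqueeze(-1)).squeeze(-1)  # (B, K)
    lookup_grad = weights.unsqueeze(-1) * ans_grad.unsqueeze(-2)             # (B, K, D)
    kb_grad = torch.zeros_like(knowledge_base)
    kb_grad.index_add_(0, indexes.flatten(), lookup_grad.reshape(-1, D))

assert torch.allclose(weights_grad, weights.grad)
assert torch.allclose(kb_grad, knowledge_base.grad)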
@@ -1131,7 +1157,6 @@ def _test_sample_combined_mean():
    # weights: (B, K)
    # indexes: (B, K, N)
    weights, indexes = sample_combined_forward(p, K, True)

    sampled_p = torch.zeros_like(p)
    weights_expanded = weights.unsqueeze(-2).expand(*weights.shape[:-1], N, K)
    sampled_p.scatter_add_(dim=-1, index=indexes.transpose(-2, -1),
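As an aside (not from the diff), the scatter_add_ call above follows the usual pattern of accumulating per-sample weights into a dense tensor at the sampled indexes; a minimal illustration with made-up shapes:

import torch

B, N, K, V = 2, 3, 4, 6                  # made-up sizes; V plays the role of p's last dim
sampled_p = torch.zeros(B, N, V)
indexes = torch.randint(0, V, (B, N, K))
weights = torch.rand(B, N, K)
sampled_p.scatter_add_(dim=-1, index=indexes, src=weights)
# sampled_p[b, n, v] is now the sum of weights[b, n, k] over all k with indexes[b, n, k] == v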
@@ -1145,13 +1170,13 @@ def _test_knowledge_base_lookup():
    N = 2
    M = 128
    D = 256
    E = 384
    E = 255

    knowledge_base: nn.Parameter = create_knowledge_base(M, N, D)
    m = KnowledgeBaseLookup(M, N, D, K, E, knowledge_base)

    B = 30
    T = 4
    T = 40
    x = torch.randn(B, T, E)
    x.requires_grad = True
    y = m(x)
@@ -1163,21 +1188,28 @@ def _test_knowledge_base_lookup():

    device = torch.device('cuda')
    train_pairs = [ (torch.randn(B, T, E, device=device), torch.randn(B, T, E, device=device)) for _ in range(11) ]
    train_pairs = [ (torch.randn(B, T, E, device=device), torch.randn(B, T, E, device=device)) for _ in range(10) ]
    from optim import Eve
    optimizer = Eve(m.parameters(), lr=0.005)
    m = m.to(device)

    for epoch in range(100):

    start = timeit.default_timer()

    for epoch in range(120):
        for n, (x,y) in enumerate(train_pairs):
            y_out = m(x)
            loss = ((y_out - y)**2).mean()
            if n % 10 == 0:
            if n % 10 == 0 and epoch % 10 == 0:
                print(f"Epoch {epoch}, batch {n}, loss {loss.item()}")
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

    stop = timeit.default_timer()
    print('Time taken: ', stop - start)


if __name__ == '__main__':
    _test_sample_combined()
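If one wanted to time just the lookup in isolation, rather than the whole KnowledgeBaseLookup training loop above, a rough standalone micro-benchmark in the same timeit style might look like the following; the shapes are made up, and it uses plain autograd rather than the custom Function.

import timeit
import torch

C, D, B, K = 1000, 256, 512, 16          # made-up sizes
knowledge_base = torch.randn(C, D, requires_grad=True)
weights = torch.randn(B, K, requires_grad=True)
indexes = torch.randint(0, C, (B, K))

def step():
    # one forward+backward pass of the weighted lookup
    lookup = knowledge_base[indexes]                    # (B, K, D)
    ans = (weights.unsqueeze(-1) * lookup).sum(dim=-2)  # (B, D)
    ans.sum().backward()
    weights.grad = None
    knowledge_base.grad = None

print('Time for 100 forward+backward steps: ', timeit.timeit(step, number=100))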
@@ -1186,4 +1218,3 @@ if __name__ == '__main__':
    _test_compute_beta()
    _test_soft_sample()
    _test_knowledge_base_lookup()
    #test_normalizer()