k2-fsa/icefall (mirror of https://github.com/k2-fsa/icefall.git)

Update backprop of sampling.py to be slightly more efficient.

commit edaaec09cd
parent bbfa484196
@@ -3,6 +3,7 @@
 # This was copied from /ceph-dan/torch-sampling/torch_sampling/sampling_ref.py,
 # its git history is there.
 
+import timeit
 import torch
 from torch import Tensor
 from torch import nn
@@ -874,28 +875,53 @@ def weighted_matrix_lookup(weights: Tensor,
 
 
 class WeightedMatrixLookupFunction(torch.autograd.Function):
     """
     Weighted matrix lookup, memory efficient version that redoes the computation in the
     backward pass... this is not really optimal but the autograd for this operation is
     complicated.
 
     See weighted_matrix_lookup() for documentation.
     """
     @staticmethod
     def forward(ctx, weights: Tensor, indexes: Tensor, knowledge_base: Tensor) -> Tensor:
+        """
+        Weighted combination of specified rows of a matrix.
+          weights: Tensor of shape (*, K), can contain any value but probably in [0..1].
+          indexes: Tensor of shape (*, K), with elements in [0..C-1]
+          knowledge_base: Tensor of shape (C, D), whose rows we'll be looking up
+        Returns:
+          tensor of shape (*, D), containing weighted sums of rows of
+          `knowledge_base`
+        """
         ctx.save_for_backward(weights.detach(), indexes.detach(),
                               knowledge_base.detach())
-        return weighted_matrix_lookup(weights, indexes, knowledge_base)
+        with torch.no_grad():
+            lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten())
+            D = knowledge_base.shape[-1]
+            weights = weights.unsqueeze(-2)  # (*, 1, K)
+            lookup = lookup.reshape(*indexes.shape, D)  # (*, K, D)
+            ans = torch.matmul(weights, lookup)  # ans: (*, 1, D)
+            ans = ans.squeeze(-2)  # (*, D)
+        return ans
 
     @staticmethod
     def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, Tensor]:
+        # ans_grad: (*, D)
         weights, indexes, knowledge_base = ctx.saved_tensors
-        weights.requires_grad = True
         knowledge_base.requires_grad = True
+        assert weights.requires_grad == False
+        D = knowledge_base.shape[-1]
         with torch.enable_grad():
-            ans = weighted_matrix_lookup(weights, indexes, knowledge_base)
-            ans.backward(gradient=ans_grad)
-        return weights.grad, None, knowledge_base.grad
+            # we'll use torch's autograd to differentiate this operation, which
+            # is nontrivial [and anyway we need `lookup` to compute weight grad.
+            # We don't save `lookup` because it's large, that is the reason
+            # we override Torch autograd.
+            lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten())
+            lookup = lookup.reshape(*indexes.shape, D)  # (*, K, D)
+        weights = weights.unsqueeze(-1)  # (*, K, 1)
+        # forward pass: was:
+        ## ans = torch.matmul(weights, lookup)
+        ## ans: (*, 1, D)
+        ## ans = ans.squeeze(-2)  # ans, ans_grad: (*, D)
+        weights_grad = torch.matmul(lookup,  # (*, K, D)
+                                    ans_grad.unsqueeze(-1))  # (*, D, 1)
+        weights_grad = weights_grad.squeeze(-1)  # (*, K, 1) -> (*, K)
+        lookup_grad = weights * ans_grad.unsqueeze(-2)  # (*, K, 1) * (*, 1, D) = (*, K, D)
+        lookup.backward(gradient=lookup_grad)
+        return weights_grad, None, knowledge_base.grad
 
 
 class KnowledgeBaseLookup(nn.Module):
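As a quick sanity check on the hand-written backward above, the sketch below compares the gradients from WeightedMatrixLookupFunction against a plain-autograd reference. It is not part of the commit: the helper names (_reference_lookup, _check_weighted_matrix_lookup_grads) and the small shapes are invented for illustration, and it assumes the class as defined in this diff.

import torch

def _reference_lookup(weights, indexes, knowledge_base):
    # Plain-autograd reference: gather the rows, then take the weighted sum over K.
    D = knowledge_base.shape[-1]
    lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten())
    lookup = lookup.reshape(*indexes.shape, D)            # (*, K, D)
    return (weights.unsqueeze(-1) * lookup).sum(dim=-2)   # (*, D)

def _check_weighted_matrix_lookup_grads():
    B, K, C, D = 3, 4, 20, 5
    knowledge_base = torch.randn(C, D, dtype=torch.double, requires_grad=True)
    weights = torch.rand(B, K, dtype=torch.double, requires_grad=True)
    indexes = torch.randint(0, C, (B, K))
    ans_grad = torch.randn(B, D, dtype=torch.double)

    # Custom Function: recomputes the lookup in its backward.
    ans = WeightedMatrixLookupFunction.apply(weights, indexes, knowledge_base)
    ans.backward(gradient=ans_grad)

    # Same computation through ordinary autograd, on fresh leaves.
    w2 = weights.detach().clone().requires_grad_(True)
    kb2 = knowledge_base.detach().clone().requires_grad_(True)
    ref = _reference_lookup(w2, indexes, kb2)
    ref.backward(gradient=ans_grad)

    assert torch.allclose(weights.grad, w2.grad)
    assert torch.allclose(knowledge_base.grad, kb2.grad)

torch.autograd.gradcheck would serve the same purpose on double-precision inputs.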
@@ -1131,7 +1157,6 @@ def _test_sample_combined_mean():
     # weights: (B, K)
     # indexes: (B, K, N)
     weights, indexes = sample_combined_forward(p, K, True)
-
     sampled_p = torch.zeros_like(p)
     weights_expanded = weights.unsqueeze(-2).expand(*weights.shape[:-1], N, K)
     sampled_p.scatter_add_(dim=-1, index=indexes.transpose(-2, -1),
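For context on the test above: sampled_p is rebuilt by scatter-adding the sample weights into a zero tensor at the sampled indexes (the indexes are transposed first so the scatter runs over the last dimension). A tiny standalone illustration of that scatter_add_ pattern, with shapes and values invented for the example:

import torch

p = torch.zeros(2, 6)                              # (B, num_classes)
indexes = torch.tensor([[0, 3], [5, 5]])           # (B, K)
weights = torch.tensor([[0.7, 0.3], [0.4, 0.6]])   # (B, K)

# p[b, indexes[b, k]] += weights[b, k]; duplicate indexes accumulate.
p.scatter_add_(dim=-1, index=indexes, src=weights)
print(p)
# tensor([[0.7000, 0.0000, 0.0000, 0.3000, 0.0000, 0.0000],
#         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000]])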
@@ -1145,13 +1170,13 @@ def _test_knowledge_base_lookup():
     N = 2
     M = 128
     D = 256
-    E = 384
+    E = 255
 
     knowledge_base: nn.Parameter = create_knowledge_base(M, N, D)
     m = KnowledgeBaseLookup(M, N, D, K, E, knowledge_base)
 
     B = 30
-    T = 4
+    T = 40
     x = torch.randn(B, T, E)
     x.requires_grad = True
     y = m(x)
@@ -1163,21 +1188,28 @@ def _test_knowledge_base_lookup():
 
 
     device = torch.device('cuda')
-    train_pairs = [ (torch.randn(B, T, E, device=device), torch.randn(B, T, E, device=device)) for _ in range(11) ]
+    train_pairs = [ (torch.randn(B, T, E, device=device), torch.randn(B, T, E, device=device)) for _ in range(10) ]
     from optim import Eve
     optimizer = Eve(m.parameters(), lr=0.005)
     m = m.to(device)
 
-    for epoch in range(100):
+    start = timeit.default_timer()
+
+    for epoch in range(120):
         for n, (x,y) in enumerate(train_pairs):
             y_out = m(x)
             loss = ((y_out - y)**2).mean()
-            if n % 10 == 0:
+            if n % 10 == 0 and epoch % 10 == 0:
                 print(f"Epoch {epoch}, batch {n}, loss {loss.item()}")
             loss.backward()
             optimizer.step()
             optimizer.zero_grad()
 
+    stop = timeit.default_timer()
+    print('Time taken: ', stop - start)
+
 
 if __name__ == '__main__':
     _test_sample_combined()
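A caveat on the timing added above, which wraps the whole training loop in timeit.default_timer() calls: CUDA kernels are launched asynchronously, so a wall-clock measurement only reflects GPU work that has actually finished. Nothing in the commit does this, but a hedged sketch of a synchronized timing helper (the function name is invented) would look like:

import timeit
import torch

def time_section(fn, device):
    # Drain any pending kernels so earlier work doesn't leak into the measurement.
    if device.type == 'cuda':
        torch.cuda.synchronize(device)
    start = timeit.default_timer()
    fn()
    # Wait for the kernels launched inside fn() before reading the clock.
    if device.type == 'cuda':
        torch.cuda.synchronize(device)
    return timeit.default_timer() - start

Without a final synchronize, any kernels still queued when the stop time is read are not counted.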
@@ -1186,4 +1218,3 @@ if __name__ == '__main__':
     _test_compute_beta()
     _test_soft_sample()
     _test_knowledge_base_lookup()
-    #test_normalizer()