Add more custom_fwd, custom_bwd

Daniel Povey 2022-04-25 23:58:34 +08:00
parent 2c4478b6d1
commit 3ba081e6d9
2 changed files with 68 additions and 4 deletions
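
Background for the change: torch.cuda.amp.custom_fwd and torch.cuda.amp.custom_bwd mark the forward() and backward() of a custom torch.autograd.Function so that they interact correctly with autocast; custom_bwd in particular runs the backward under the same autocast state as the forward. A minimal sketch of the pattern this commit applies, using a hypothetical class name (the body mirrors DoubleSwishFunction from the second file), not code taken verbatim from the repository:

import torch
from torch import Tensor
from torch.cuda.amp import custom_fwd, custom_bwd

class MyDoubleSwish(torch.autograd.Function):
    # Hypothetical illustration of the decorator pattern used in this commit.
    @staticmethod
    @custom_fwd   # could also be @custom_fwd(cast_inputs=torch.float32) to force fp32 inside
    def forward(ctx, x: Tensor) -> Tensor:
        s = torch.sigmoid(x - 1.0)
        y = x * s
        ctx.save_for_backward(s, y)
        return y

    @staticmethod
    @custom_bwd   # backward runs under the same autocast state as forward
    def backward(ctx, y_grad: Tensor) -> Tensor:
        s, y = ctx.saved_tensors
        # d/dx of x * sigmoid(x - 1) is s + x * s * (1 - s) = s + y * (1 - s)
        return (y * (1 - s) + s) * y_grad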

View File

@@ -7,7 +7,7 @@ import timeit
 import torch
 from torch import Tensor
 from torch import nn
-from torch.cuda.amp import GradScaler
+from torch.cuda.amp import GradScaler, custom_fwd, custom_bwd
 from typing import Tuple, Optional
 from scaling import ScaledLinear
 import random
@@ -651,7 +651,7 @@ def sample_combined_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tensor, Tensor]:
                           samples)
     # TODO: could remove the next call
-    if random.random() < 0.01:
+    if random.random() < 0.0005:
         check_shifted_samples(combined_cumsums, delta_P,
                               shifted_samples, P_sum_product)
@@ -727,7 +727,10 @@ class SampleCombinedFunction(torch.autograd.Function):
     # please see sample_combined() or sample_combined_forward() or
     # sample_combined_backward() for documentation
     @staticmethod
+    @custom_fwd
     def forward(ctx, p: Tensor, K: int, input_is_log: bool) -> Tuple[Tensor, Tensor]:
+        if random.random() < 0.0005:
+            print("dtype[1] = ", p.dtype)
         with torch.no_grad():
             weights, indexes = sample_combined_forward(p, K, input_is_log)
         ctx.save_for_backward(p, indexes, weights)
@@ -735,6 +738,7 @@ class SampleCombinedFunction(torch.autograd.Function):
         return weights, indexes

     @staticmethod
+    @custom_bwd
     def backward(ctx, weights_grad: Optional[Tensor], indexes_grad: Optional[Tensor]) -> Tuple[Tensor, None, None]:
         p, indexes, weights = ctx.saved_tensors
         p_grad = sample_combined_backward(p, ctx.input_is_log, indexes,
@@ -877,6 +881,7 @@ def weighted_matrix_lookup(weights: Tensor,

 class WeightedMatrixLookupFunction(torch.autograd.Function):
     @staticmethod
+    @custom_fwd
    def forward(ctx, weights: Tensor, indexes: Tensor, knowledge_base: Tensor) -> Tensor:
         """
         Weighted combination of specified rows of a matrix.
@@ -887,6 +892,8 @@ class WeightedMatrixLookupFunction(torch.autograd.Function):
            tensor of shape (*, D), containing weighted sums of rows of
            `knowledge_base`
         """
+        if random.random() < 0.0005:
+            print("dtype[1] = ", weights.dtype)
         ctx.save_for_backward(weights.detach(), indexes.detach(),
                               knowledge_base.detach())
         with torch.no_grad():
@@ -899,10 +906,13 @@ class WeightedMatrixLookupFunction(torch.autograd.Function):
         return ans

     @staticmethod
+    @custom_bwd
     def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, Tensor]:
         # ans_grad: (*, D)
         weights, indexes, knowledge_base = ctx.saved_tensors
         knowledge_base.requires_grad = True
+        dtype = ans_grad.dtype
+        ans_grad = ans_grad.to(weights.dtype)
         assert weights.requires_grad == False
         D = knowledge_base.shape[-1]
         with torch.enable_grad():
@@ -922,7 +932,7 @@ class WeightedMatrixLookupFunction(torch.autograd.Function):
         weights_grad = weights_grad.squeeze(-1)  # (*, K, 1) -> (*, K)
         lookup_grad = weights * ans_grad.unsqueeze(-2)  # (*, K, 1) * (*, 1, D) = (*, K, D)
         lookup.backward(gradient=lookup_grad)
-        return weights_grad, None, knowledge_base.grad
+        return weights_grad.to(dtype), None, knowledge_base.grad.to(dtype)


 class KnowledgeBaseLookup(nn.Module):
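
The dtype bookkeeping above (remember ans_grad.dtype, do the arithmetic in the saved tensors' dtype, cast the returned gradients back at the end) is a common way to keep a hand-written backward numerically stable under fp16 autocast while still handing the caller gradients in the dtype it expects. A small self-contained sketch of the same idea, with a hypothetical ScaleRows function rather than the repository's code:

import torch
from torch import Tensor
from torch.cuda.amp import custom_fwd, custom_bwd

class ScaleRows(torch.autograd.Function):
    # Hypothetical example: the backward computes in the saved (fp32) dtype
    # and casts the result back to the dtype of the incoming gradient.
    @staticmethod
    @custom_fwd
    def forward(ctx, x: Tensor, scale: Tensor) -> Tensor:
        ctx.save_for_backward(scale)
        return x * scale

    @staticmethod
    @custom_bwd
    def backward(ctx, out_grad: Tensor):
        (scale,) = ctx.saved_tensors
        dtype = out_grad.dtype                 # may be torch.float16 under autocast
        out_grad = out_grad.to(scale.dtype)    # do the arithmetic in the saved dtype
        x_grad = out_grad * scale
        return x_grad.to(dtype), None          # no gradient for `scale` in this sketch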
@@ -968,8 +978,9 @@ class KnowledgeBaseLookup(nn.Module):
         x = x.reshape(*x.shape[:-1], self.N, self.M)  # now (*, N, M)
         x = x.log_softmax(dim=-1)  # now normalized logprobs, dim= (*, N, M)
         assert torch.all(x - x == 0)
-        if random.random() < 0.001 or x.dtype == torch.float16:
+        if random.random() < 0.001:
             entropy = (x * x.exp()).sum(dim=-1).mean()
+            print("Entropy = ", entropy)
         weights, indexes, = sample_combined(x, self.K, input_is_log=True)
         indexes = join_indexes(indexes, self.M)
         x = WeightedMatrixLookupFunction.apply(weights, indexes, self.knowledge_base)  # now (*, D)
@@ -1225,6 +1236,53 @@ def _test_knowledge_base_lookup():
     stop = timeit.default_timer()
     print('Time taken: ', stop - start)


+def _test_knowledge_base_lookup_autocast():
+    K = 16
+    N = 2
+    M = 128
+    D = 256
+    E = 255
+    knowledge_base: nn.Parameter = create_knowledge_base(M, N, D)
+    m = KnowledgeBaseLookup(M, N, D, K, E, knowledge_base)
+    B = 30
+    T = 40
+    x = torch.randn(B, T, E)
+    x.requires_grad = True
+    y = m(x)
+    assert y.shape == x.shape
+    y.sum().backward()  # make sure backward doesn't crash..
+    print("y = ", y)
+    print("x.grad = ", x.grad)
+    print("knowledge_base.grad norm = ", knowledge_base.grad.norm())
+
+    device = torch.device('cuda')
+    train_pairs = [ (torch.randn(B, T, E, device=device), torch.randn(B, T, E, device=device)) for _ in range(10) ]
+    from optim import Eve
+    optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04)
+    m = m.to(device)
+
+    scaler = GradScaler(enabled=True)
+
+    start = timeit.default_timer()
+
+    for epoch in range(120):
+        for n, (x, y) in enumerate(train_pairs):
+            # run the forward under autocast so the lookup actually sees fp16
+            with torch.cuda.amp.autocast(enabled=True):
+                y_out = m(x)
+                loss = ((y_out - y)**2).mean() * 100.0
+            if n % 10 == 0 and epoch % 10 == 0:
+                print(f"Epoch {epoch}, batch {n}, loss {loss.item()}")
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad()
+
+    stop = timeit.default_timer()
+    print('Time taken: ', stop - start)
+
+
 if __name__ == '__main__':
@@ -1233,4 +1291,5 @@ if __name__ == '__main__':
     _test_combined()
     _test_compute_beta()
     _test_soft_sample()
+    _test_knowledge_base_lookup_autocast()
     _test_knowledge_base_lookup()
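
The new _test_knowledge_base_lookup_autocast follows the standard GradScaler recipe: forward and loss under torch.cuda.amp.autocast, scaler.scale(loss).backward(), then scaler.step() and scaler.update(). For reference, a stripped-down version of that loop with a plain nn.Linear stand-in model (hypothetical, not part of the commit):

import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast

device = torch.device('cuda')
m = nn.Linear(256, 256).to(device)                  # hypothetical stand-in model
optimizer = torch.optim.Adam(m.parameters(), lr=0.005)
scaler = GradScaler(enabled=True)

for step in range(100):
    x = torch.randn(30, 256, device=device)
    y = torch.randn(30, 256, device=device)
    with autocast(enabled=True):                    # mixed-precision forward + loss
        loss = ((m(x) - y) ** 2).mean()
    scaler.scale(loss).backward()                   # scale to avoid fp16 gradient underflow
    scaler.step(optimizer)                          # unscales grads; skips step on inf/nan
    scaler.update()
    optimizer.zero_grad()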

View File

@@ -18,6 +18,7 @@
 import collections
 from itertools import repeat
 from typing import Optional, Tuple
+from torch.cuda.amp import custom_fwd, custom_bwd

 import torch
 import torch.nn as nn
@@ -39,6 +40,7 @@ _pair = _ntuple(2)

 class ActivationBalancerFunction(torch.autograd.Function):
     @staticmethod
+    @custom_fwd
     def forward(
         ctx,
         x: Tensor,
@@ -85,6 +87,7 @@ class ActivationBalancerFunction(torch.autograd.Function):
         return x

     @staticmethod
+    @custom_bwd
     def backward(
         ctx, x_grad: Tensor
     ) -> Tuple[Tensor, None, None, None, None, None, None]:
@@ -426,6 +429,7 @@ class DoubleSwishFunction(torch.autograd.Function):
     """

     @staticmethod
+    @custom_fwd
     def forward(ctx, x: Tensor) -> Tensor:
         x = x.detach()
         s = torch.sigmoid(x - 1.0)
@@ -434,6 +438,7 @@ class DoubleSwishFunction(torch.autograd.Function):
         return y

     @staticmethod
+    @custom_bwd
     def backward(ctx, y_grad: Tensor) -> Tensor:
         s, y = ctx.saved_tensors
         return (y * (1 - s) + s) * y_grad
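
With ActivationBalancerFunction and DoubleSwishFunction decorated, a quick sanity check is to feed one of them an fp16 activation produced inside an autocast region and confirm that forward and backward run without dtype errors. A hedged sketch, assuming DoubleSwishFunction is importable from this file:

import torch
from torch.cuda.amp import autocast

x = torch.randn(8, 16, device='cuda', requires_grad=True)
w = torch.randn(16, 16, device='cuda')
with autocast(enabled=True):
    h = torch.nn.functional.linear(x, w)     # matmul runs in float16 under autocast
    y = DoubleSwishFunction.apply(h)         # custom_fwd sees the fp16 activation
y.sum().backward()                           # custom_bwd: backward completes without dtype errors
print(x.grad.dtype)                          # leaf grad comes back in the leaf's dtype (float32)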