Add more custom_fwd, custom_bwd

Daniel Povey 2022-04-25 23:58:34 +08:00
parent 2c4478b6d1
commit 3ba081e6d9
2 changed files with 68 additions and 4 deletions
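
The decorators this commit adds come from torch.cuda.amp: custom_fwd records the autocast state that was active while a custom torch.autograd.Function's forward ran, and custom_bwd makes the matching backward execute under that same autocast state, so the Function behaves consistently under mixed-precision training. A minimal sketch of the pattern follows; the MulBy2 Function is illustrative only and is not part of this repository.

# Sketch only: shows the decorator pattern the diff applies, not code from this commit.
import torch
from torch import Tensor
from torch.cuda.amp import custom_fwd, custom_bwd

class MulBy2(torch.autograd.Function):
    @staticmethod
    @custom_fwd        # records whether autocast was enabled during forward
    def forward(ctx, x: Tensor) -> Tensor:
        return x * 2.0

    @staticmethod
    @custom_bwd        # runs backward under the same autocast state as forward
    def backward(ctx, grad_out: Tensor) -> Tensor:
        return grad_out * 2.0

Usage is unchanged: y = MulBy2.apply(x); under torch.cuda.amp.autocast() the backward then sees the same precision policy the forward saw.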

View File

@@ -7,7 +7,7 @@ import timeit
import torch
from torch import Tensor
from torch import nn
from torch.cuda.amp import GradScaler
from torch.cuda.amp import GradScaler, custom_fwd, custom_bwd
from typing import Tuple, Optional
from scaling import ScaledLinear
import random
@@ -651,7 +651,7 @@ def sample_combined_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tens
samples)
# TODO: could remove the next call
if random.random() < 0.01:
if random.random() < 0.0005:
check_shifted_samples(combined_cumsums, delta_P,
shifted_samples, P_sum_product)
@@ -727,7 +727,10 @@ class SampleCombinedFunction(torch.autograd.Function):
# please see sample_combined() or sample_combined_forward() or
# sample_combined_backward() for documentation
@staticmethod
@custom_fwd
def forward(ctx, p: Tensor, K: int, input_is_log: bool) -> Tuple[Tensor, Tensor]:
if random.random() < 0.0005:
print("dtype[1] = ", p.dtype)
with torch.no_grad():
weights, indexes = sample_combined_forward(p, K, input_is_log)
ctx.save_for_backward(p, indexes, weights)
@@ -735,6 +738,7 @@ class SampleCombinedFunction(torch.autograd.Function):
return weights, indexes
@staticmethod
@custom_bwd
def backward(ctx, weights_grad: Optional[Tensor], indexes_grad: Optional[Tensor]) -> Tuple[Tensor, None, None]:
p, indexes, weights = ctx.saved_tensors
p_grad = sample_combined_backward(p, ctx.input_is_log, indexes,
@@ -877,6 +881,7 @@ def weighted_matrix_lookup(weights: Tensor,
class WeightedMatrixLookupFunction(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(ctx, weights: Tensor, indexes: Tensor, knowledge_base: Tensor) -> Tensor:
"""
Weighted combination of specified rows of a matrix.
@@ -887,6 +892,8 @@ class WeightedMatrixLookupFunction(torch.autograd.Function):
tensor of shape (*, D), containing weighted sums of rows of
`knowledge_base`
"""
if random.random() < 0.0005:
print("dtype[1] = ", weights.dtype)
ctx.save_for_backward(weights.detach(), indexes.detach(),
knowledge_base.detach())
with torch.no_grad():
@@ -899,10 +906,13 @@ class WeightedMatrixLookupFunction(torch.autograd.Function):
return ans
@staticmethod
@custom_bwd
def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, Tensor]:
# ans_grad: (*, D)
weights, indexes, knowledge_base = ctx.saved_tensors
knowledge_base.requires_grad = True
dtype = ans_grad.dtype
ans_grad = ans_grad.to(weights.dtype)
assert weights.requires_grad == False
D = knowledge_base.shape[-1]
with torch.enable_grad():
@@ -922,7 +932,7 @@ class WeightedMatrixLookupFunction(torch.autograd.Function):
weights_grad = weights_grad.squeeze(-1) # (*, K, 1) -> (*, K)
lookup_grad = weights * ans_grad.unsqueeze(-2) # (*, K, 1) * (*, 1, D) = (*, K, D)
lookup.backward(gradient=lookup_grad)
return weights_grad, None, knowledge_base.grad
return weights_grad.to(dtype), None, knowledge_base.grad.to(dtype)
class KnowledgeBaseLookup(nn.Module):
@@ -968,8 +978,9 @@ class KnowledgeBaseLookup(nn.Module):
x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M)
x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M)
assert torch.all(x - x == 0)
if random.random() < 0.001 or x.dtype == torch.float16:
if random.random() < 0.001:
entropy = (x * x.exp()).sum(dim=-1).mean()
print("Entropy = ", entropy)
weights, indexes, = sample_combined(x, self.K, input_is_log=True)
indexes = join_indexes(indexes, self.M)
x = WeightedMatrixLookupFunction.apply(weights, indexes, self.knowledge_base) # now (*, D)
@@ -1225,6 +1236,53 @@ def _test_knowledge_base_lookup():
stop = timeit.default_timer()
print('Time taken: ', stop - start)
def _test_knowledge_base_lookup_autocast():
K = 16
N = 2
M = 128
D = 256
E = 255
knowledge_base: nn.Parameter = create_knowledge_base(M, N, D)
m = KnowledgeBaseLookup(M, N, D, K, E, knowledge_base)
B = 30
T = 40
x = torch.randn(B, T, E)
x.requires_grad = True
y = m(x)
assert y.shape == x.shape
y.sum().backward() # make sure backward doesn't crash..
print("y = ", y)
print("x.grad = ", x.grad)
print("knowlege_base.grad norm = ", knowledge_base.grad.norm())
device = torch.device('cuda')
train_pairs = [ (torch.randn(B, T, E, device=device), torch.randn(B, T, E, device=device)) for _ in range(10) ]
from optim import Eve
optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04)
m = m.to(device)
scaler = GradScaler(enabled=True)
start = timeit.default_timer()
for epoch in range(120):
for n, (x,y) in enumerate(train_pairs):
y_out = m(x)
with torch.cuda.amp.autocast(enabled=True):
loss = ((y_out - y)**2).mean() * 100.0
if n % 10 == 0 and epoch % 10 == 0:
print(f"Epoch {epoch}, batch {n}, loss {loss.item()}")
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
stop = timeit.default_timer()
print('Time taken: ', stop - start)
if __name__ == '__main__':
@@ -1233,4 +1291,5 @@ if __name__ == '__main__':
_test_combined()
_test_compute_beta()
_test_soft_sample()
_test_knowledge_base_lookup_autocast()
_test_knowledge_base_lookup()
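
Besides the decorators, this file's diff adds a dtype round trip in WeightedMatrixLookupFunction.backward: under autocast the incoming ans_grad can arrive in float16, so it is cast to the dtype of the saved weights for the gradient computation and the returned gradients are cast back to the original dtype. A hedged sketch of that pattern with a toy lookup Function (ToyLookup and its internals are illustrative, not this repository's code):

# Sketch only: cast the incoming grad to the saved dtype, compute, cast back.
import torch
from torch import Tensor
from torch.cuda.amp import custom_fwd, custom_bwd

class ToyLookup(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, weights: Tensor, table: Tensor) -> Tensor:
        ctx.save_for_backward(weights.detach(), table.detach())
        return weights @ table                    # (B, K) @ (K, D) -> (B, D)

    @staticmethod
    @custom_bwd
    def backward(ctx, ans_grad: Tensor):
        weights, table = ctx.saved_tensors
        out_dtype = ans_grad.dtype                # e.g. torch.float16 under autocast
        ans_grad = ans_grad.to(weights.dtype)     # compute in the saved dtype
        weights_grad = ans_grad @ table.t()       # (B, D) @ (D, K) -> (B, K)
        table_grad = weights.t() @ ans_grad       # (K, B) @ (B, D) -> (K, D)
        return weights_grad.to(out_dtype), table_grad.to(out_dtype)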

View File

@@ -18,6 +18,7 @@
import collections
from itertools import repeat
from typing import Optional, Tuple
from torch.cuda.amp import custom_fwd, custom_bwd
import torch
import torch.nn as nn
@@ -39,6 +40,7 @@ _pair = _ntuple(2)
class ActivationBalancerFunction(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(
ctx,
x: Tensor,
@@ -85,6 +87,7 @@ class ActivationBalancerFunction(torch.autograd.Function):
return x
@staticmethod
@custom_bwd
def backward(
ctx, x_grad: Tensor
) -> Tuple[Tensor, None, None, None, None, None, None]:
@@ -426,6 +429,7 @@ class DoubleSwishFunction(torch.autograd.Function):
"""
@staticmethod
@custom_fwd
def forward(ctx, x: Tensor) -> Tensor:
x = x.detach()
s = torch.sigmoid(x - 1.0)
@@ -434,6 +438,7 @@ class DoubleSwishFunction(torch.autograd.Function):
return y
@staticmethod
@custom_bwd
def backward(ctx, y_grad: Tensor) -> Tensor:
s, y = ctx.saved_tensors
return (y * (1 - s) + s) * y_grad
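
The Functions decorated in this second file (presumably the scaling module imported by the first file) are exercised under mixed precision the same way the new _test_knowledge_base_lookup_autocast test drives the lookup module: compute the loss in an autocast region and push the backward pass through a GradScaler. A hedged, self-contained sketch of that driver loop, using a plain nn.Linear model and a CUDA device as stand-ins:

# Sketch only: standard autocast + GradScaler training loop with an illustrative model.
import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast

model = nn.Linear(256, 256).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = GradScaler(enabled=True)

x = torch.randn(30, 256, device='cuda')
y = torch.randn(30, 256, device='cuda')

for step in range(10):
    with autocast(enabled=True):        # forward and loss in float16 where safe
        loss = ((model(x) - y) ** 2).mean()
    scaler.scale(loss).backward()       # scale the loss to avoid fp16 grad underflow
    scaler.step(optimizer)              # unscales grads; skips the step on inf/nan
    scaler.update()
    optimizer.zero_grad()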