Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-04 06:34:20 +00:00
Add more custom_fwd,custom_bwd
This commit is contained in:
parent 2c4478b6d1
commit 3ba081e6d9
@@ -7,7 +7,7 @@ import timeit
 import torch
 from torch import Tensor
 from torch import nn
-from torch.cuda.amp import GradScaler
+from torch.cuda.amp import GradScaler, custom_fwd, custom_bwd
 from typing import Tuple, Optional
 from scaling import ScaledLinear
 import random
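Note: custom_fwd and custom_bwd are the torch.cuda.amp helpers for custom autograd Functions. custom_bwd runs backward under the same autocast state that forward saw, and custom_fwd can optionally pin the forward dtype via its cast_inputs argument. A minimal standalone sketch of the pattern (MulConstant is illustrative, not code from this commit):

import torch
from torch.cuda.amp import custom_fwd, custom_bwd

class MulConstant(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, x: torch.Tensor, scale: float) -> torch.Tensor:
        ctx.scale = scale
        return x * scale

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out: torch.Tensor):
        # d(x * scale)/dx = scale; the non-tensor argument gets no gradient
        return grad_out * ctx.scale, None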
@@ -651,7 +651,7 @@ def sample_combined_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tensor, Tensor]:
                               samples)
 
     # TODO: could remove the next call
-    if random.random() < 0.01:
+    if random.random() < 0.0005:
         check_shifted_samples(combined_cumsums, delta_P,
                               shifted_samples, P_sum_product)
 
@@ -727,7 +727,10 @@ class SampleCombinedFunction(torch.autograd.Function):
     # please see sample_combined() or sample_combined_forward() or
     # sample_combined_backward() for documentation
     @staticmethod
+    @custom_fwd
     def forward(ctx, p: Tensor, K: int, input_is_log: bool) -> Tuple[Tensor, Tensor]:
+        if random.random() < 0.0005:
+            print("dtype[1] = ", p.dtype)
         with torch.no_grad():
             weights, indexes = sample_combined_forward(p, K, input_is_log)
             ctx.save_for_backward(p, indexes, weights)
@@ -735,6 +738,7 @@ class SampleCombinedFunction(torch.autograd.Function):
         return weights, indexes
 
     @staticmethod
+    @custom_bwd
     def backward(ctx, weights_grad: Optional[Tensor], indexes_grad: Optional[Tensor]) -> Tuple[Tensor, None, None]:
         p, indexes, weights = ctx.saved_tensors
         p_grad = sample_combined_backward(p, ctx.input_is_log, indexes,
@@ -877,6 +881,7 @@ def weighted_matrix_lookup(weights: Tensor,
 
 class WeightedMatrixLookupFunction(torch.autograd.Function):
     @staticmethod
+    @custom_fwd
     def forward(ctx, weights: Tensor, indexes: Tensor, knowledge_base: Tensor) -> Tensor:
         """
         Weighted combination of specified rows of a matrix.
@@ -887,6 +892,8 @@ class WeightedMatrixLookupFunction(torch.autograd.Function):
            tensor of shape (*, D), containing weighted sums of rows of
            `knowledge_base`
         """
+        if random.random() < 0.0005:
+            print("dtype[1] = ", weights.dtype)
         ctx.save_for_backward(weights.detach(), indexes.detach(),
                               knowledge_base.detach())
         with torch.no_grad():
@@ -899,10 +906,13 @@ class WeightedMatrixLookupFunction(torch.autograd.Function):
         return ans
 
     @staticmethod
+    @custom_bwd
     def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, Tensor]:
         # ans_grad: (*, D)
         weights, indexes, knowledge_base = ctx.saved_tensors
         knowledge_base.requires_grad = True
+        dtype = ans_grad.dtype
+        ans_grad = ans_grad.to(weights.dtype)
         assert weights.requires_grad == False
         D = knowledge_base.shape[-1]
         with torch.enable_grad():
@@ -922,7 +932,7 @@ class WeightedMatrixLookupFunction(torch.autograd.Function):
             weights_grad = weights_grad.squeeze(-1) # (*, K, 1) -> (*, K)
         lookup_grad = weights * ans_grad.unsqueeze(-2) # (*, K, 1) * (*, 1, D) = (*, K, D)
         lookup.backward(gradient=lookup_grad)
-        return weights_grad, None, knowledge_base.grad
+        return weights_grad.to(dtype), None, knowledge_base.grad.to(dtype)
 
 
 class KnowledgeBaseLookup(nn.Module):
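Note: the two hunks above implement a dtype round-trip for autocast. Backward may receive a float16 gradient while the saved tensors are float32, so the incoming dtype is remembered, the computation runs in the saved dtype, and the returned gradients are cast back. A standalone sketch of that pattern with illustrative names (ScaleRows is not from this repo):

import torch
from torch.cuda.amp import custom_fwd, custom_bwd

class ScaleRows(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        # x: (*, D), w: (*); scales each row of x by its weight
        ctx.save_for_backward(x.detach(), w.detach())
        return x * w.unsqueeze(-1)

    @staticmethod
    @custom_bwd
    def backward(ctx, g: torch.Tensor):
        x, w = ctx.saved_tensors
        out_dtype = g.dtype          # dtype the caller expects back
        g = g.to(x.dtype)            # compute in the saved (e.g. fp32) dtype
        x_grad = g * w.unsqueeze(-1)
        w_grad = (g * x).sum(dim=-1)
        return x_grad.to(out_dtype), w_grad.to(out_dtype)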
@@ -968,8 +978,9 @@
         x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M)
         x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M)
+        assert torch.all(x - x == 0)
-        if random.random() < 0.001 or x.dtype == torch.float16:
+        if random.random() < 0.001:
             entropy = (x * x.exp()).sum(dim=-1).mean()
             print("Entropy = ", entropy)
         weights, indexes, = sample_combined(x, self.K, input_is_log=True)
         indexes = join_indexes(indexes, self.M)
         x = WeightedMatrixLookupFunction.apply(weights, indexes, self.knowledge_base) # now (*, D)
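Note: the added assert torch.all(x - x == 0) is a finiteness check: (t - t == 0) is False exactly where t is NaN or infinite, which presumably catches non-finite log-probs under half precision. Tiny standalone illustration (not part of the commit):

import torch

t = torch.tensor([1.0, float('inf'), float('nan')])
print(t - t == 0)             # tensor([ True, False, False])
print(torch.all(t - t == 0))  # tensor(False)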
@@ -1225,6 +1236,53 @@ def _test_knowledge_base_lookup():
     stop = timeit.default_timer()
     print('Time taken: ', stop - start)
 
+def _test_knowledge_base_lookup_autocast():
+    K = 16
+    N = 2
+    M = 128
+    D = 256
+    E = 255
+
+    knowledge_base: nn.Parameter = create_knowledge_base(M, N, D)
+    m = KnowledgeBaseLookup(M, N, D, K, E, knowledge_base)
+
+    B = 30
+    T = 40
+    x = torch.randn(B, T, E)
+    x.requires_grad = True
+    y = m(x)
+    assert y.shape == x.shape
+    y.sum().backward() # make sure backward doesn't crash..
+    print("y = ", y)
+    print("x.grad = ", x.grad)
+    print("knowlege_base.grad norm = ", knowledge_base.grad.norm())
+
+    device = torch.device('cuda')
+    train_pairs = [ (torch.randn(B, T, E, device=device), torch.randn(B, T, E, device=device)) for _ in range(10) ]
+    from optim import Eve
+    optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04)
+    m = m.to(device)
+
+    scaler = GradScaler(enabled=True)
+
+    start = timeit.default_timer()
+
+
+    for epoch in range(120):
+        for n, (x,y) in enumerate(train_pairs):
+            y_out = m(x)
+            with torch.cuda.amp.autocast(enabled=True):
+                loss = ((y_out - y)**2).mean() * 100.0
+            if n % 10 == 0 and epoch % 10 == 0:
+                print(f"Epoch {epoch}, batch {n}, loss {loss.item()}")
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad()
+
+    stop = timeit.default_timer()
+    print('Time taken: ', stop - start)
+
 
 
 if __name__ == '__main__':
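Note: in the added test the model forward, y_out = m(x), runs before entering autocast, so only the loss arithmetic is computed under mixed precision. For comparison, the usual torch.cuda.amp training step wraps the forward pass as well; a generic sketch (model, optimizer and shapes are placeholders, not code from this commit):

import torch
from torch.cuda.amp import GradScaler, autocast

def train_step(model, optimizer, scaler: GradScaler, x, y):
    optimizer.zero_grad()
    with autocast(enabled=True):
        y_out = model(x)                  # forward runs in mixed precision
        loss = ((y_out - y) ** 2).mean()
    scaler.scale(loss).backward()         # scale the loss to avoid fp16 grad underflow
    scaler.step(optimizer)                # unscales grads; skips the step on inf/nan
    scaler.update()
    return loss.detach()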
@@ -1233,4 +1291,5 @@ if __name__ == '__main__':
     _test_combined()
     _test_compute_beta()
     _test_soft_sample()
+    _test_knowledge_base_lookup_autocast()
     _test_knowledge_base_lookup()
@@ -18,6 +18,7 @@
 import collections
 from itertools import repeat
 from typing import Optional, Tuple
+from torch.cuda.amp import custom_fwd, custom_bwd
 
 import torch
 import torch.nn as nn
@@ -39,6 +40,7 @@ _pair = _ntuple(2)
 
 class ActivationBalancerFunction(torch.autograd.Function):
     @staticmethod
+    @custom_fwd
     def forward(
         ctx,
         x: Tensor,
@@ -85,6 +87,7 @@ class ActivationBalancerFunction(torch.autograd.Function):
         return x
 
     @staticmethod
+    @custom_bwd
     def backward(
         ctx, x_grad: Tensor
     ) -> Tuple[Tensor, None, None, None, None, None, None]:
@@ -426,6 +429,7 @@ class DoubleSwishFunction(torch.autograd.Function):
     """
 
     @staticmethod
+    @custom_fwd
     def forward(ctx, x: Tensor) -> Tensor:
         x = x.detach()
         s = torch.sigmoid(x - 1.0)
@@ -434,6 +438,7 @@ class DoubleSwishFunction(torch.autograd.Function):
         return y
 
     @staticmethod
+    @custom_bwd
     def backward(ctx, y_grad: Tensor) -> Tensor:
         s, y = ctx.saved_tensors
         return (y * (1 - s) + s) * y_grad
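Note: the DoubleSwishFunction backward decorated above returns (y * (1 - s) + s) * y_grad, which follows from d/dx [x * sigmoid(x - 1)] = s + x * s * (1 - s) = s + y * (1 - s), with s = sigmoid(x - 1) and y = x * s. A standalone numeric check of that identity (illustrative, not part of the commit):

import torch

x = torch.randn(8, dtype=torch.double, requires_grad=True)
s = torch.sigmoid(x - 1.0)
y = x * s
manual = s + y * (1 - s)                    # the closed-form derivative
auto, = torch.autograd.grad(y.sum(), x)     # what autograd computes
assert torch.allclose(manual, auto)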