From 3ba081e6d9822070e2d3f082a5538f1060da95c8 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Mon, 25 Apr 2022 23:58:34 +0800
Subject: [PATCH] Add more custom_fwd,custom_bwd

---
 .../ASR/pruned2_knowledge/sampling.py | 67 +++++++++++++++++--
 .../ASR/pruned2_knowledge/scaling.py  |  5 ++
 2 files changed, 68 insertions(+), 4 deletions(-)

diff --git a/egs/librispeech/ASR/pruned2_knowledge/sampling.py b/egs/librispeech/ASR/pruned2_knowledge/sampling.py
index 9e4c66e7c..32f191258 100644
--- a/egs/librispeech/ASR/pruned2_knowledge/sampling.py
+++ b/egs/librispeech/ASR/pruned2_knowledge/sampling.py
@@ -7,7 +7,7 @@ import timeit
 import torch
 from torch import Tensor
 from torch import nn
-from torch.cuda.amp import GradScaler
+from torch.cuda.amp import GradScaler, custom_fwd, custom_bwd
 from typing import Tuple, Optional
 from scaling import ScaledLinear
 import random
@@ -651,7 +651,7 @@ def sample_combined_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tens
                                               samples)
 
     # TODO: could remove the next call
-    if random.random() < 0.01:
+    if random.random() < 0.0005:
         check_shifted_samples(combined_cumsums, delta_P,
                               shifted_samples, P_sum_product)
 
@@ -727,7 +727,10 @@ class SampleCombinedFunction(torch.autograd.Function):
     # please see sample_combined() or sample_combined_forward() or
     # sample_combined_backward() for documentation
     @staticmethod
+    @custom_fwd
     def forward(ctx, p: Tensor, K: int, input_is_log: bool) -> Tuple[Tensor, Tensor]:
+        if random.random() < 0.0005:
+            print("dtype[1] = ", p.dtype)
         with torch.no_grad():
             weights, indexes = sample_combined_forward(p, K, input_is_log)
         ctx.save_for_backward(p, indexes, weights)
@@ -735,6 +738,7 @@ class SampleCombinedFunction(torch.autograd.Function):
         return weights, indexes
 
     @staticmethod
+    @custom_bwd
     def backward(ctx, weights_grad: Optional[Tensor], indexes_grad: Optional[Tensor]) -> Tuple[Tensor, None, None]:
         p, indexes, weights = ctx.saved_tensors
         p_grad = sample_combined_backward(p, ctx.input_is_log, indexes,
@@ -877,6 +881,7 @@ def weighted_matrix_lookup(weights: Tensor,
 
 class WeightedMatrixLookupFunction(torch.autograd.Function):
     @staticmethod
+    @custom_fwd
     def forward(ctx, weights: Tensor, indexes: Tensor, knowledge_base: Tensor) -> Tensor:
         """
         Weighted combination of specified rows of a matrix.
@@ -887,6 +892,8 @@ class WeightedMatrixLookupFunction(torch.autograd.Function):
           tensor of shape (*, D), containing weighted sums of rows of
           `knowledge_base`
         """
+        if random.random() < 0.0005:
+            print("dtype[1] = ", weights.dtype)
         ctx.save_for_backward(weights.detach(), indexes.detach(),
                               knowledge_base.detach())
         with torch.no_grad():
@@ -899,10 +906,13 @@ class WeightedMatrixLookupFunction(torch.autograd.Function):
         return ans
 
     @staticmethod
+    @custom_bwd
     def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, Tensor]:
         # ans_grad: (*, D)
         weights, indexes, knowledge_base = ctx.saved_tensors
         knowledge_base.requires_grad = True
+        dtype = ans_grad.dtype
+        ans_grad = ans_grad.to(weights.dtype)
         assert weights.requires_grad == False
         D = knowledge_base.shape[-1]
         with torch.enable_grad():
@@ -922,7 +932,7 @@ class WeightedMatrixLookupFunction(torch.autograd.Function):
         weights_grad = weights_grad.squeeze(-1) # (*, K, 1) -> (*, K)
         lookup_grad = weights * ans_grad.unsqueeze(-2) # (*, K, 1) * (*, 1, D) = (*, K, D)
         lookup.backward(gradient=lookup_grad)
-        return weights_grad, None, knowledge_base.grad
+        return weights_grad.to(dtype), None, knowledge_base.grad.to(dtype)
 
 
 class KnowledgeBaseLookup(nn.Module):
@@ -968,8 +978,9 @@ class KnowledgeBaseLookup(nn.Module):
         x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M)
         x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M)
         assert torch.all(x - x == 0)
-        if random.random() < 0.001 or x.dtype == torch.float16:
+        if random.random() < 0.001:
             entropy = (x * x.exp()).sum(dim=-1).mean()
+            print("Entropy = ", entropy)
         weights, indexes, = sample_combined(x, self.K, input_is_log=True)
         indexes = join_indexes(indexes, self.M)
         x = WeightedMatrixLookupFunction.apply(weights, indexes, self.knowledge_base) # now (*, D)
@@ -1225,6 +1236,53 @@ def _test_knowledge_base_lookup():
     stop = timeit.default_timer()
     print('Time taken: ', stop - start)
 
+def _test_knowledge_base_lookup_autocast():
+    K = 16
+    N = 2
+    M = 128
+    D = 256
+    E = 255
+
+    knowledge_base: nn.Parameter = create_knowledge_base(M, N, D)
+    m = KnowledgeBaseLookup(M, N, D, K, E, knowledge_base)
+
+    B = 30
+    T = 40
+    x = torch.randn(B, T, E)
+    x.requires_grad = True
+    y = m(x)
+    assert y.shape == x.shape
+    y.sum().backward() # make sure backward doesn't crash..
+ print("y = ", y) + print("x.grad = ", x.grad) + print("knowlege_base.grad norm = ", knowledge_base.grad.norm()) + + device = torch.device('cuda') + train_pairs = [ (torch.randn(B, T, E, device=device), torch.randn(B, T, E, device=device)) for _ in range(10) ] + from optim import Eve + optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04) + m = m.to(device) + + scaler = GradScaler(enabled=True) + + start = timeit.default_timer() + + + for epoch in range(120): + for n, (x,y) in enumerate(train_pairs): + y_out = m(x) + with torch.cuda.amp.autocast(enabled=True): + loss = ((y_out - y)**2).mean() * 100.0 + if n % 10 == 0 and epoch % 10 == 0: + print(f"Epoch {epoch}, batch {n}, loss {loss.item()}") + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + + stop = timeit.default_timer() + print('Time taken: ', stop - start) + if __name__ == '__main__': @@ -1233,4 +1291,5 @@ if __name__ == '__main__': _test_combined() _test_compute_beta() _test_soft_sample() + _test_knowledge_base_lookup_autocast() _test_knowledge_base_lookup() diff --git a/egs/librispeech/ASR/pruned2_knowledge/scaling.py b/egs/librispeech/ASR/pruned2_knowledge/scaling.py index f89d2963e..f726c2583 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/scaling.py +++ b/egs/librispeech/ASR/pruned2_knowledge/scaling.py @@ -18,6 +18,7 @@ import collections from itertools import repeat from typing import Optional, Tuple +from torch.cuda.amp import custom_fwd, custom_bwd import torch import torch.nn as nn @@ -39,6 +40,7 @@ _pair = _ntuple(2) class ActivationBalancerFunction(torch.autograd.Function): @staticmethod + @custom_fwd def forward( ctx, x: Tensor, @@ -85,6 +87,7 @@ class ActivationBalancerFunction(torch.autograd.Function): return x @staticmethod + @custom_bwd def backward( ctx, x_grad: Tensor ) -> Tuple[Tensor, None, None, None, None, None, None]: @@ -426,6 +429,7 @@ class DoubleSwishFunction(torch.autograd.Function): """ @staticmethod + @custom_fwd def forward(ctx, x: Tensor) -> Tensor: x = x.detach() s = torch.sigmoid(x - 1.0) @@ -434,6 +438,7 @@ class DoubleSwishFunction(torch.autograd.Function): return y @staticmethod + @custom_bwd def backward(ctx, y_grad: Tensor) -> Tensor: s, y = ctx.saved_tensors return (y * (1 - s) + s) * y_grad