#!/usr/bin/env python3 # This was copied from /ceph-dan/torch-sampling/torch_sampling/sampling_ref.py, # its git history is there. import timeit import torch from torch import Tensor from torch import nn from torch.cuda.amp import GradScaler, custom_fwd, custom_bwd from typing import Tuple, Optional from scaling import ScaledLinear import random from torch_scheduled_sampling import sample_combined # The main exports of this file are the module KnowledgeBaseLookup and the # function create_knowledge_base. def create_knowledge_base(M: int, N: int, D: int) -> nn.Parameter: std = 0.1 a = (3 ** 0.5) * std # this sqrt(3) thing is intended to get variance of # 0.1 from uniform distribution ans = nn.Parameter(torch.ones(M ** N, D)) nn.init.uniform_(ans, -a, a) return ans def join_indexes(indexes: Tensor, M: int) -> Tensor: """ Combines N-tuples of indexes into single indexes that can be used for lookup in the knowledge base. Args: indexes: tensor of torch.int64 of shape (*, K, N), with elements in {0..M-1} M: the size of the original softmaxes, is upper bound on elements in indexes Returns: joined_indexes: of shape (*, K), joined_indexes[...,k] equals joined_indexes[...,0,k] + joined_indexes[...,1,k]*(M**1) ... + joined_indexes[...,1,k]*(M**(N-1))] """ N = indexes.shape[-1] n_powers = M ** torch.arange(N, device=indexes.device) # [ 1, M, ..., M**(N-1) ] return (indexes * n_powers).sum(dim=-1) # Note, we don't use this, we def weighted_matrix_lookup(weights: Tensor, indexes: Tensor, knowledge_base: Tensor) -> Tensor: """ Weighted combination of specified rows of a matrix. weights: Tensor of shape (*, K), can contain any value but probably in [0..1]. indexes: Tensor of shape (*, K), with elements in [0..C-1] knowledge_base: Tensor of shape (C-1, D), whose rows we'll be looking up Returns: tensor of shape (*, D), containing weighted sums of rows of `knowledge_base` """ if True: return WeightedMatrixLookupFunction.apply(weights, indexes, knowledge_base) else: # simpler but less memory-efficient implementation lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten()) D = knowledge_base.shape[-1] weights = weights.unsqueeze(-2) # (*, 1, K) lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) ans = torch.matmul(weights, lookup) # ans: (*, 1, D) ans = ans.squeeze(-2) assert list(ans.shape) == list(weights.shape[:-2]) + [D] return ans class WeightedMatrixLookupFunction(torch.autograd.Function): @staticmethod @custom_fwd def forward(ctx, weights: Tensor, indexes: Tensor, knowledge_base: Tensor) -> Tensor: """ Weighted combination of specified rows of a matrix. weights: Tensor of shape (*, K), can contain any value but probably in [0..1]. indexes: Tensor of shape (*, K), with elements in [0..C-1] knowledge_base: Tensor of shape (C, D), whose rows we'll be looking up Returns: tensor of shape (*, D), containing weighted sums of rows of `knowledge_base` """ # if random.random() < 0.001: # print("dtype[1] = ", weights.dtype) ctx.save_for_backward(weights.detach(), indexes.detach(), knowledge_base.detach()) with torch.no_grad(): lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten()) D = knowledge_base.shape[-1] weights = weights.unsqueeze(-2) # (*, 1, K) lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) ans = torch.matmul(weights, lookup) # ans: (*, 1, D) ans = ans.squeeze(-2) #(*, D) return ans @staticmethod @custom_bwd def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, Tensor]: # ans_grad: (*, D) weights, indexes, knowledge_base = ctx.saved_tensors knowledge_base.requires_grad = True dtype = ans_grad.dtype ans_grad = ans_grad.to(weights.dtype) assert weights.requires_grad == False D = knowledge_base.shape[-1] with torch.enable_grad(): # we'll use torch's autograd to differentiate this operation, which # is nontrivial [and anyway we need `lookup` to compute weight grad. # We don't save `lookup` because it's large, that is the reason # we override Torch autograd. lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten()) lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) weights = weights.unsqueeze(-1) # (*, K, 1) # forward pass: was: ## ans = torch.matmul(weights, lookup) ## ans: (*, 1, D) ## ans = ans.squeeze(-2) # ans, ans_grad: (*, D) weights_grad = torch.matmul(lookup, # (*, K, D) ans_grad.unsqueeze(-1)) # (*, D, 1) weights_grad = weights_grad.squeeze(-1) # (*, K, 1) -> (*, K) lookup_grad = weights * ans_grad.unsqueeze(-2) # (*, K, 1) * (*, 1, D) = (*, K, D) lookup.backward(gradient=lookup_grad) return weights_grad.to(dtype), None, knowledge_base.grad.to(dtype) class KnowledgeBaseLookup(nn.Module): """ Create knowledge-base lookup module. (The knowledge-base parameter, which is large, is shared between these modules). Args: M: int, softmax size, e.g. in [32..128] N: int, number of softmaxes, in [2..3] D: int, embedding dimension in knowledge base, e.g. 256 K: number of samples (affects speed/accuracy tradeoff), e.g. 16. embedding_dim: the dimension to project from and to, e.g. the d_model of the conformer. """ def __init__(self, M: int, N: int, D: int, K: int, embedding_dim: int, knowledge_base: nn.Parameter): super(KnowledgeBaseLookup, self).__init__() self.knowledge_base = knowledge_base # shared! self.in_proj = ScaledLinear(embedding_dim, M * N, initial_scale=1.0) # initial_scale = 4.0 because the knowlege_base activations are # quite small -- if we use our optimizer they'll have stddev <= 0.1. self.out_proj = ScaledLinear(D, embedding_dim, initial_scale = 4.0) self.M = M self.N = N self.K = K def forward(self, x: Tensor) -> Tensor: """ Forward function that does knowledge-base lookup. Args: x: input, of shape (*, E) where E is embedding_dim as passed to constructor y: output of knowledge-base lookup, of shape (*, E) # TODO: later we can try multiplying by a projection of x or something like that. """ assert torch.all(x - x == 0) x = self.in_proj(x) # now (*, M*N) assert torch.all(x - x == 0) x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M) x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M) assert torch.all(x - x == 0) if random.random() < 0.001: entropy = (x * x.exp()).sum(dim=-1).mean() # print("Entropy = ", entropy) # only need 'combined_indexes', call them 'indexes'. _, indexes, weights = sample_combined(x, self.K, input_is_log=True) x = weighted_matrix_lookup(weights, indexes, self.knowledge_base) # now (*, D) x = self.out_proj(x) # now (*, self.embedding_dim) return x def _test_knowledge_base_lookup(): K = 16 N = 2 M = 128 D = 256 E = 255 knowledge_base: nn.Parameter = create_knowledge_base(M, N, D) m = KnowledgeBaseLookup(M, N, D, K, E, knowledge_base) B = 30 T = 40 x = torch.randn(B, T, E) x.requires_grad = True y = m(x) assert y.shape == x.shape y.sum().backward() # make sure backward doesn't crash.. print("y = ", y) print("x.grad = ", x.grad) print("knowlege_base.grad norm = ", knowledge_base.grad.norm()) dtype = torch.float32 device = torch.device('cuda') train_pairs = [ (torch.randn(B, T, E, device=device, dtype=dtype), torch.randn(B, T, E, device=device, dtype=dtype)) for _ in range(10) ] from optim import Eve optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04) m = m.to(device).to(dtype) start = timeit.default_timer() # Epoch 0, batch 0, loss 1.0109944343566895 # Epoch 10, batch 0, loss 1.0146660804748535 # Epoch 20, batch 0, loss 1.0119813680648804 # Epoch 30, batch 0, loss 1.0105408430099487 # Epoch 40, batch 0, loss 1.0077732801437378 # Epoch 50, batch 0, loss 1.0050103664398193 # Epoch 60, batch 0, loss 1.0033129453659058 # Epoch 70, batch 0, loss 1.0014232397079468 # Epoch 80, batch 0, loss 0.9977912306785583 # Epoch 90, batch 0, loss 0.8274348974227905 # Epoch 100, batch 0, loss 0.3368612825870514 # Epoch 110, batch 0, loss 0.11323091387748718 # Time taken: 17.591704960912466 for epoch in range(150): for n, (x,y) in enumerate(train_pairs): y_out = m(x) loss = ((y_out - y)**2).mean() * 100.0 if n % 10 == 0 and epoch % 10 == 0: print(f"Epoch {epoch}, batch {n}, loss {loss.item()}") loss.backward() optimizer.step() optimizer.zero_grad() stop = timeit.default_timer() print('Time taken: ', stop - start) def _test_knowledge_base_lookup_autocast(): K = 16 N = 2 M = 128 D = 256 E = 255 knowledge_base: nn.Parameter = create_knowledge_base(M, N, D) m = KnowledgeBaseLookup(M, N, D, K, E, knowledge_base) B = 30 T = 40 x = torch.randn(B, T, E) x.requires_grad = True y = m(x) assert y.shape == x.shape y.sum().backward() # make sure backward doesn't crash.. print("y = ", y) print("x.grad = ", x.grad) print("knowlege_base.grad norm = ", knowledge_base.grad.norm()) device = torch.device('cuda') train_pairs = [ (torch.randn(B, T, E, device=device), torch.randn(B, T, E, device=device)) for _ in range(10) ] from optim import Eve optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04) m = m.to(device) scaler = GradScaler(enabled=True) start = timeit.default_timer() for epoch in range(150): for n, (x,y) in enumerate(train_pairs): y_out = m(x) with torch.cuda.amp.autocast(enabled=True): loss = ((y_out - y)**2).mean() * 100.0 if n % 10 == 0 and epoch % 10 == 0: print(f"Epoch {epoch}, batch {n}, loss {loss.item()}") scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() optimizer.zero_grad() stop = timeit.default_timer() print('Time taken: ', stop - start) if __name__ == '__main__': _test_knowledge_base_lookup() _test_knowledge_base_lookup_autocast()