diff --git a/egs/librispeech/ASR/pruned2_knowledge/sampling.py b/egs/librispeech/ASR/pruned2_knowledge/sampling.py
index b9c2703b4..9e4c66e7c 100644
--- a/egs/librispeech/ASR/pruned2_knowledge/sampling.py
+++ b/egs/librispeech/ASR/pruned2_knowledge/sampling.py
@@ -7,6 +7,7 @@ import timeit
 import torch
 from torch import Tensor
 from torch import nn
+from torch.cuda.amp import GradScaler
 from typing import Tuple, Optional
 from scaling import ScaledLinear
 import random
@@ -565,7 +566,7 @@ def sample_combined_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tens
 
     # the + 1 is because we need all elements of P to be nonzero (this will avoid
     # some nasty edge cases)
-    P = (p * (2**(num_bits_per_sample)) + 1).to(dtype=torch.int64)
+    P = (p.to(torch.float32) * (2**(num_bits_per_sample)) + 1).to(dtype=torch.int64)
 
     values, indexes = compute_k_largest(P, K)
     prod_values, prod_indexes = compute_products(values, indexes)
@@ -823,7 +824,7 @@ def soft_sample_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tensor,
         i = (i * s) % M  # Reverse the pseudo-random reordering
         y = torch.maximum(torch.gather(p, dim=-1, index=i), beta)
         assert torch.all(is_ok.sum(dim=-1) == K)
-        assert torch.all((y.sum(dim=-1) - 1.0).abs() < 0.01)
+        assert torch.all((y.sum(dim=-1) - 1.0).abs() < 0.1)
 
 
 
@@ -961,12 +962,14 @@ class KnowledgeBaseLookup(nn.Module):
         # TODO: later we can try multiplying by a projection of x or something like that.
         """
+        assert torch.all(x - x == 0)
         x = self.in_proj(x) # now (*, M*N)
+        assert torch.all(x - x == 0)
         x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M)
         x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M)
-        if random.random() < 0.001:
+        assert torch.all(x - x == 0)
+        if random.random() < 0.001 or x.dtype == torch.float16:
             entropy = (x * x.exp()).sum(dim=-1).mean()
-            print("Entropy = ", entropy)
         weights, indexes, = sample_combined(x, self.K, input_is_log=True)
         indexes = join_indexes(indexes, self.M)
         x = WeightedMatrixLookupFunction.apply(weights, indexes, self.knowledge_base) # now (*, D)
 
@@ -1186,12 +1189,12 @@ def _test_knowledge_base_lookup():
     print("x.grad = ", x.grad)
     print("knowlege_base.grad norm = ", knowledge_base.grad.norm())
 
-
+    dtype = torch.float16
     device = torch.device('cuda')
-    train_pairs = [ (torch.randn(B, T, E, device=device), torch.randn(B, T, E, device=device)) for _ in range(10) ]
+    train_pairs = [ (torch.randn(B, T, E, device=device, dtype=dtype), torch.randn(B, T, E, device=device, dtype=dtype)) for _ in range(10) ]
     from optim import Eve
-    optimizer = Eve(m.parameters(), lr=0.005)
-    m = m.to(device)
+    optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04)
+    m = m.to(device).to(dtype)
 
     start = timeit.default_timer()
 
@@ -1212,7 +1215,7 @@ def _test_knowledge_base_lookup():
     for epoch in range(120):
         for n, (x,y) in enumerate(train_pairs):
             y_out = m(x)
-            loss = ((y_out - y)**2).mean()
+            loss = ((y_out - y)**2).mean() * 100.0
             if n % 10 == 0 and epoch % 10 == 0:
                 print(f"Epoch {epoch}, batch {n}, loss {loss.item()}")
             loss.backward()
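
The `sample_combined_forward` hunk upcasts `p` to float32 before building the integer table `P`. A minimal sketch of why, assuming a made-up `num_bits` of 16 (the real `num_bits_per_sample` comes from the surrounding function): in half precision the scaled product can overflow to inf, and the `+ 1` that is supposed to keep every entry nonzero gets rounded away.

```python
# Hedged sketch, not part of the patch: why the fp32 upcast matters
# when p arrives in float16.  num_bits = 16 is a made-up stand-in.
import torch

num_bits = 16
p = torch.tensor([1.0, 0.6667], dtype=torch.float16)

fp16_path = p * (2 ** num_bits) + 1                    # arithmetic in fp16
fp32_path = p.to(torch.float32) * (2 ** num_bits) + 1  # arithmetic in fp32

print(fp16_path)  # [inf, 43680.] -- 65536 overflows fp16, and the +1 rounds away
print(fp32_path)  # [65537., 43681.] -- exact
print(fp16_path.to(torch.int64))  # the inf turns into a meaningless int64 value
```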
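
The three `assert torch.all(x - x == 0)` lines added in `KnowledgeBaseLookup.forward` are a compact non-finiteness guard: for any finite `x`, `x - x == 0`, while `NaN - NaN` and `inf - inf` are NaN, and NaN compares unequal to everything. A quick sketch of the idiom (the tensor values are arbitrary):

```python
# Hedged sketch: what `assert torch.all(x - x == 0)` actually checks.
import torch

finite = torch.tensor([1.5, -2.0], dtype=torch.float16)
broken = torch.tensor([1.5, float("inf")], dtype=torch.float16)

assert torch.all(finite - finite == 0)      # passes: finite values only
assert not torch.all(broken - broken == 0)  # inf - inf is NaN, NaN == 0 is False
# torch.isfinite(x).all() is the more explicit spelling of the same test.
```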
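
The import hunk pulls in `torch.cuda.amp.GradScaler`, but none of the hunks shown here call it yet; the test loop instead multiplies the loss by a fixed 100.0, a hand-rolled form of the same loss-scaling idea (keeping fp16 gradients away from underflow). For reference, a hedged sketch of the standard GradScaler pattern, reusing the `m`, `optimizer`, and `train_pairs` names from `_test_knowledge_base_lookup()`; the wiring below is an assumption, not part of this patch:

```python
# Hedged sketch, not part of the patch: the usual GradScaler loop that
# the new import suggests as the eventual replacement for loss * 100.0.
from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()
for x, y in train_pairs:
    optimizer.zero_grad()
    with autocast():                  # run the forward pass in mixed precision
        loss = ((m(x) - y) ** 2).mean()
    scaler.scale(loss).backward()     # scale the loss so fp16 grads stay finite
    scaler.step(optimizer)            # unscales grads, skips step if inf/NaN
    scaler.update()                   # adapts the scale factor over time
```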