Fix for half precision

Daniel Povey 2022-04-25 23:03:34 +08:00
parent e718c7ac88
commit 2c4478b6d1


@@ -7,6 +7,7 @@ import timeit
 import torch
 from torch import Tensor
 from torch import nn
+from torch.cuda.amp import GradScaler
 from typing import Tuple, Optional
 from scaling import ScaledLinear
 import random
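The new GradScaler import is the standard torch.cuda.amp helper for mixed-precision training. As a reminder of the pattern it supports (a generic sketch, not code from this commit; the model, optimizer and shapes below are placeholders):

import torch
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(16, 16).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = GradScaler()

for _ in range(3):
    x = torch.randn(8, 16, device='cuda')
    optimizer.zero_grad()
    with autocast():                   # forward pass runs in float16 where safe
        loss = model(x).pow(2).mean()
    scaler.scale(loss).backward()      # scale the loss so float16 gradients do not underflow
    scaler.step(optimizer)             # unscales the gradients and skips the step on inf/nan
    scaler.update()                    # adjusts the scale factor for the next iteration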
@@ -565,7 +566,7 @@ def sample_combined_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tens
     # the + 1 is because we need all elements of P to be nonzero (this will avoid
     # some nasty edge cases)
-    P = (p * (2**(num_bits_per_sample)) + 1).to(dtype=torch.int64)
+    P = (p.to(torch.float32) * (2**(num_bits_per_sample)) + 1).to(dtype=torch.int64)
     values, indexes = compute_k_largest(P, K)
     prod_values, prod_indexes = compute_products(values, indexes)
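The float32 cast guards against float16 overflow: float16 tops out at 65504, so multiplying a probability by a large power of two can turn into inf before the int64 conversion. A minimal sketch of the failure mode (the value of num_bits_per_sample below is made up, only chosen to trigger the overflow):

import torch

p = torch.tensor([0.9], dtype=torch.float16)
num_bits_per_sample = 17                                # hypothetical value for illustration
print(p * (2 ** num_bits_per_sample))                   # tensor([inf], dtype=torch.float16)
print((p.to(torch.float32) * (2 ** num_bits_per_sample) + 1).to(torch.int64))   # finite integer weight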
@@ -823,7 +824,7 @@ def soft_sample_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tensor,
     i = (i * s) % M  # Reverse the pseudo-random reordering
     y = torch.maximum(torch.gather(p, dim=-1, index=i), beta)
     assert torch.all(is_ok.sum(dim=-1) == K)
-    assert torch.all((y.sum(dim=-1) - 1.0).abs() < 0.01)
+    assert torch.all((y.sum(dim=-1) - 1.0).abs() < 0.1)
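The tolerance on the sum check is relaxed from 0.01 to 0.1, presumably because float16 carries only about three decimal digits of precision, so the intermediate rounding in this routine can push the gathered probabilities further from summing to 1.0 than float32 would. For reference:

import torch

print(torch.finfo(torch.float16).eps)   # 0.0009765625, spacing of float16 values just above 1.0
print(torch.finfo(torch.float32).eps)   # about 1.19e-07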
@@ -961,12 +962,14 @@ class KnowledgeBaseLookup(nn.Module):
         # TODO: later we can try multiplying by a projection of x or something like that.
         """
+        assert torch.all(x - x == 0)
         x = self.in_proj(x) # now (*, M*N)
+        assert torch.all(x - x == 0)
         x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M)
         x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M)
-        if random.random() < 0.001:
+        assert torch.all(x - x == 0)
+        if random.random() < 0.001 or x.dtype == torch.float16:
             entropy = (x * x.exp()).sum(dim=-1).mean()
+            print("Entropy = ", entropy)
         weights, indexes, = sample_combined(x, self.K, input_is_log=True)
         indexes = join_indexes(indexes, self.M)
         x = WeightedMatrixLookupFunction.apply(weights, indexes, self.knowledge_base) # now (*, D)
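The repeated assert torch.all(x - x == 0) is a compact finiteness check: for finite entries x - x is exactly zero, while a NaN or infinite entry yields NaN, so the comparison fails and the assert fires. A minimal sketch of the idiom (torch.isfinite expresses the same test):

import torch

x = torch.tensor([1.0, float('inf'), float('nan')], dtype=torch.float16)
print(x - x)                    # tensor([0., nan, nan], dtype=torch.float16)
print(torch.all(x - x == 0))    # tensor(False)
print(torch.isfinite(x).all())  # tensor(False), the equivalent built-in check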
@@ -1186,12 +1189,12 @@ def _test_knowledge_base_lookup():
     print("x.grad = ", x.grad)
     print("knowlege_base.grad norm = ", knowledge_base.grad.norm())
+    dtype = torch.float16
     device = torch.device('cuda')
-    train_pairs = [ (torch.randn(B, T, E, device=device), torch.randn(B, T, E, device=device)) for _ in range(10) ]
+    train_pairs = [ (torch.randn(B, T, E, device=device, dtype=dtype), torch.randn(B, T, E, device=device, dtype=dtype)) for _ in range(10) ]
     from optim import Eve
-    optimizer = Eve(m.parameters(), lr=0.005)
-    m = m.to(device)
+    optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04)
+    m = m.to(device).to(dtype)
     start = timeit.default_timer()
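The larger eps=1.0e-04 passed to Eve fits the half-precision setup: a conventional eps of 1e-08 lies below float16's smallest subnormal (about 6e-08), so if the guarded denominator is ever held in float16 the term rounds to exactly zero. Eve itself is the local optimizer from optim.py; the snippet below only illustrates the dtype behaviour, not its internals:

import torch

print(torch.tensor(1.0e-08).to(torch.float16))   # tensor(0., dtype=torch.float16), the guard vanishes
print(torch.tensor(1.0e-04).to(torch.float16))   # still nonzero in half precision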
@@ -1212,7 +1215,7 @@ def _test_knowledge_base_lookup():
     for epoch in range(120):
         for n, (x,y) in enumerate(train_pairs):
             y_out = m(x)
-            loss = ((y_out - y)**2).mean()
+            loss = ((y_out - y)**2).mean() * 100.0
             if n % 10 == 0 and epoch % 10 == 0:
                 print(f"Epoch {epoch}, batch {n}, loss {loss.item()}")
             loss.backward()
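Multiplying the loss by 100.0 acts as a crude static loss scale for the half-precision test: gradients smaller than float16's tiniest subnormal flush to zero, and scaling the loss up front scales every gradient by the same factor, which an Adam-family optimizer such as Eve should largely cancel in its update. A sketch of the underflow being worked around:

import torch

g = torch.tensor(1.0e-08)              # a gradient magnitude float16 cannot represent
print(g.to(torch.float16))             # tensor(0., dtype=torch.float16), the gradient is lost
print((g * 100.0).to(torch.float16))   # nonzero (subnormal) float16 value, the gradient survives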