Fix for half precision

Daniel Povey 2022-04-25 23:03:34 +08:00
parent e718c7ac88
commit 2c4478b6d1


@@ -7,6 +7,7 @@ import timeit
 import torch
 from torch import Tensor
 from torch import nn
+from torch.cuda.amp import GradScaler
 from typing import Tuple, Optional
 from scaling import ScaledLinear
 import random
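
The newly imported GradScaler is PyTorch's dynamic loss-scaling helper for mixed precision. None of the hunks below actually call it (this commit uses a fixed loss scale instead; see the last hunk), but for orientation the usual pattern looks like this. A sketch with an illustrative model and optimizer, not code from this file:

    import torch
    from torch.cuda.amp import GradScaler, autocast

    model = torch.nn.Linear(16, 16).cuda()          # illustrative model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    scaler = GradScaler()

    x = torch.randn(8, 16, device='cuda')
    with autocast():                       # run the forward pass in fp16 where safe
        loss = model(x).pow(2).mean()
    scaler.scale(loss).backward()          # scale the loss so fp16 grads don't underflow
    scaler.step(optimizer)                 # unscales grads; skips the step on inf/nan
    scaler.update()                        # grows or shrinks the scale for the next step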
@@ -565,7 +566,7 @@ def sample_combined_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tens
     # the + 1 is because we need all elements of P to be nonzero (this will avoid
     # some nasty edge cases)
-    P = (p * (2**(num_bits_per_sample)) + 1).to(dtype=torch.int64)
+    P = (p.to(torch.float32) * (2**(num_bits_per_sample)) + 1).to(dtype=torch.int64)
     values, indexes = compute_k_largest(P, K)
     prod_values, prod_indexes = compute_products(values, indexes)
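
The upcast to float32 is the core of the fix: fp16's largest finite value is about 65504, so with an illustrative num_bits_per_sample of 16 any p close to 1.0 scales past the representable range and becomes inf, and the "+ 1" that guarantees nonzero entries of P is absorbed by rounding once the product exceeds 2**11 (where fp16's integer spacing reaches 2). A minimal sketch of the failure mode:

    import torch

    num_bits_per_sample = 16                            # illustrative value
    p = torch.tensor([1.0, 0.0], dtype=torch.float16)   # a one-hot distribution

    # 1.0 * 2**16 = 65536 exceeds fp16's max (~65504) and becomes inf, making
    # the int64 conversion undefined; the "+ 1" is also lost to rounding
    # wherever the product lands above 2**11.
    bad = (p * (2**(num_bits_per_sample)) + 1).to(dtype=torch.int64)

    # Upcasting first keeps the arithmetic exact and well inside float32 range:
    good = (p.to(torch.float32) * (2**(num_bits_per_sample)) + 1).to(dtype=torch.int64)
    # good == tensor([65537, 1])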
@@ -823,7 +824,7 @@ def soft_sample_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tensor,
     i = (i * s) % M  # Reverse the pseudo-random reordering
     y = torch.maximum(torch.gather(p, dim=-1, index=i), beta)
     assert torch.all(is_ok.sum(dim=-1) == K)
-    assert torch.all((y.sum(dim=-1) - 1.0).abs() < 0.01)
+    assert torch.all((y.sum(dim=-1) - 1.0).abs() < 0.1)
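
The normalization tolerance is widened tenfold because half precision has little resolution near 1.0: every upstream op here (the gather, the clamp against beta) rounds to fp16, so the sum of a nominally normalized y drifts far more than in float32. The available precision, for scale:

    import torch

    print(torch.finfo(torch.float16).eps)   # 0.000977: spacing of fp16 just above 1.0
    print(torch.finfo(torch.float32).eps)   # 1.19e-07: the float32 equivalent

With per-element error on the order of 1e-3 rather than 1e-7, the old 0.01 bound trips spuriously in fp16, while 0.1 still catches genuine normalization bugs.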
@@ -961,12 +962,14 @@ class KnowledgeBaseLookup(nn.Module):
         # TODO: later we can try multiplying by a projection of x or something like that.
         """
         assert torch.all(x - x == 0)
         x = self.in_proj(x) # now (*, M*N)
+        assert torch.all(x - x == 0)
         x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M)
         x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M)
-        if random.random() < 0.001:
+        assert torch.all(x - x == 0)
+        if random.random() < 0.001 or x.dtype == torch.float16:
             entropy = (x * x.exp()).sum(dim=-1).mean()
             print("Entropy = ", entropy)
         weights, indexes, = sample_combined(x, self.K, input_is_log=True)
         indexes = join_indexes(indexes, self.M)
         x = WeightedMatrixLookupFunction.apply(weights, indexes, self.knowledge_base) # now (*, D)
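
The repeated assert torch.all(x - x == 0) is a compact finiteness check: for finite x, x - x is exactly 0, while inf - inf and nan - nan are both nan, so the assert fires at the first overflow or nan, which is exactly what tends to appear first under fp16. A quick illustration:

    import torch

    x = torch.tensor([1.0, float('inf'), float('nan')], dtype=torch.float16)
    print(x - x)                     # tensor([0., nan, nan]): only finite entries give 0
    print(torch.all(x - x == 0))     # tensor(False): the assert would fire here
    print(torch.isfinite(x).all())   # the more explicit spelling of the same check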
@@ -1186,12 +1189,12 @@ def _test_knowledge_base_lookup():
     print("x.grad = ", x.grad)
     print("knowledge_base.grad norm = ", knowledge_base.grad.norm())
+    dtype = torch.float16
     device = torch.device('cuda')
-    train_pairs = [ (torch.randn(B, T, E, device=device), torch.randn(B, T, E, device=device)) for _ in range(10) ]
+    train_pairs = [ (torch.randn(B, T, E, device=device, dtype=dtype), torch.randn(B, T, E, device=device, dtype=dtype)) for _ in range(10) ]
     from optim import Eve
-    optimizer = Eve(m.parameters(), lr=0.005)
-    m = m.to(device)
+    optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04)
+    m = m.to(device).to(dtype)
     start = timeit.default_timer()
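
Raising the optimizer's eps to 1.0e-04 is the subtle fp16 fix in this test: an Adam-family default near 1e-8 is below even fp16's subnormal range, so if optimizer state is kept in half precision it rounds to zero and the denominator sqrt(v) + eps can vanish. A sketch of the numerics (assuming Eve uses Adam-style eps semantics):

    import torch

    print(torch.finfo(torch.float16).tiny)                    # 6.1035e-05: smallest normal fp16
    print(torch.tensor(1.0e-08, dtype=torch.float16).item())  # 0.0: the usual eps vanishes
    print(torch.tensor(1.0e-04, dtype=torch.float16).item())  # ~1e-04: stays representable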
@@ -1212,7 +1215,7 @@ def _test_knowledge_base_lookup():
     for epoch in range(120):
         for n, (x,y) in enumerate(train_pairs):
             y_out = m(x)
-            loss = ((y_out - y)**2).mean()
+            loss = ((y_out - y)**2).mean() * 100.0
             if n % 10 == 0 and epoch % 10 == 0:
                 print(f"Epoch {epoch}, batch {n}, loss {loss.item()}")
             loss.backward()
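
The * 100.0 on the loss is a fixed loss scale, the manual counterpart of the GradScaler imported above: gradients below fp16's subnormal floor (about 6e-8) flush to zero during backward, and scaling the loss by S scales every gradient by S, lifting them back into range. With an Adam-style optimizer such as Eve, the update m / (sqrt(v) + eps) is nearly invariant to a constant gradient scale, which is presumably why no explicit unscaling is needed here. A sketch of the underflow being avoided:

    import torch

    g = torch.tensor(2.0e-08)                      # a gradient below the fp16 floor
    print(g.to(torch.float16).item())              # 0.0: underflows, learning stalls
    print((g * 100.0).to(torch.float16).item())    # ~2e-06: survives as an fp16 subnormal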