Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-04 14:44:18 +00:00)
Fix for half precision

parent e718c7ac88
commit 2c4478b6d1
@@ -7,6 +7,7 @@ import timeit
 import torch
 from torch import Tensor
 from torch import nn
+from torch.cuda.amp import GradScaler
 from typing import Tuple, Optional
 from scaling import ScaledLinear
 import random
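Note: the newly imported GradScaler is torch's dynamic loss scaler for mixed-precision training; it is not used in the lines shown in this commit. For reference, a minimal sketch of the standard torch.cuda.amp pattern it enables (the model, optimizer, and data below are placeholders, not names from this file):

import torch
from torch.cuda.amp import GradScaler, autocast

# Placeholder model/optimizer/data for illustration only.
model = torch.nn.Linear(4, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = GradScaler()

x = torch.randn(8, 4, device='cuda')
y = torch.randn(8, 4, device='cuda')

optimizer.zero_grad()
with autocast():                   # run the forward pass in float16 where safe
    loss = ((model(x) - y) ** 2).mean()
scaler.scale(loss).backward()      # scale the loss so fp16 grads don't underflow
scaler.step(optimizer)             # unscales grads; skips the step on inf/nan
scaler.update()                    # adjusts the scale factor for the next step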
@@ -565,7 +566,7 @@ def sample_combined_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tens
 
     # the + 1 is because we need all elements of P to be nonzero (this will avoid
     # some nasty edge cases)
-    P = (p * (2**(num_bits_per_sample)) + 1).to(dtype=torch.int64)
+    P = (p.to(torch.float32) * (2**(num_bits_per_sample)) + 1).to(dtype=torch.int64)
     values, indexes = compute_k_largest(P, K)
     prod_values, prod_indexes = compute_products(values, indexes)
 
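Note: the upcast matters because float16 tops out at 65504 and its spacing grows with magnitude, so multiplying p by a large power of two can overflow to inf, and the "+ 1" that keeps P nonzero can be rounded away. A small illustration (num_bits_per_sample = 16 is a hypothetical value chosen for the demo):

import torch

num_bits_per_sample = 16   # hypothetical value, for illustration only
p = torch.tensor([1.0, 0.9], dtype=torch.float16)

bad = (p * (2 ** num_bits_per_sample) + 1).to(dtype=torch.int64)
good = (p.to(torch.float32) * (2 ** num_bits_per_sample) + 1).to(dtype=torch.int64)

print(bad)   # first entry overflows to inf in float16 before the int64 cast;
             # the second loses the "+ 1" to fp16 rounding (58976, not 58977)
print(good)  # tensor([65537, 58977]) -- exact, and strictly positive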
@@ -823,7 +824,7 @@ def soft_sample_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tensor,
     i = (i * s) % M  # Reverse the pseudo-random reordering
     y = torch.maximum(torch.gather(p, dim=-1, index=i), beta)
     assert torch.all(is_ok.sum(dim=-1) == K)
-    assert torch.all((y.sum(dim=-1) - 1.0).abs() < 0.01)
+    assert torch.all((y.sum(dim=-1) - 1.0).abs() < 0.1)
 
 
 
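Note: the tolerance is relaxed because a distribution that sums to 1 within float32 error no longer does after float16 rounding and float16 accumulation, and in the fp16 path the error can blow past the old 0.01 bound. A rough illustration of the precision gap (the size 256 is arbitrary):

import torch

torch.manual_seed(0)
p32 = torch.randn(256).softmax(dim=-1)   # sums to ~1 within fp32 rounding error
p16 = p32.to(torch.float16)

err32 = (p32.sum() - 1.0).abs().item()
err16 = (p16.sum().float() - 1.0).abs().item()
print(err32, err16)   # the fp16 error is orders of magnitude larger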
@@ -961,12 +962,14 @@ class KnowledgeBaseLookup(nn.Module):
 
         # TODO: later we can try multiplying by a projection of x or something like that.
         """
         assert torch.all(x - x == 0)
         x = self.in_proj(x) # now (*, M*N)
+        assert torch.all(x - x == 0)
         x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M)
         x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M)
-        if random.random() < 0.001:
+        assert torch.all(x - x == 0)
+        if random.random() < 0.001 or x.dtype == torch.float16:
             entropy = (x * x.exp()).sum(dim=-1).mean()
             print("Entropy = ", entropy)
         weights, indexes, = sample_combined(x, self.K, input_is_log=True)
         indexes = join_indexes(indexes, self.M)
         x = WeightedMatrixLookupFunction.apply(weights, indexes, self.knowledge_base) # now (*, D)
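Note: the repeated `assert torch.all(x - x == 0)` is a compact finiteness check. NaN - NaN and inf - inf are both NaN, which compares unequal to 0, so the assert fires exactly when x contains a NaN or inf -- the values float16 overflow produces. (The `or x.dtype == torch.float16` also makes the entropy diagnostic print on every call in half precision, not just 0.1% of calls.) A quick demonstration of the check:

import torch

x = torch.tensor([1.0, float('inf'), float('nan')], dtype=torch.float16)
print(x - x == 0)          # tensor([ True, False, False])
print(torch.isfinite(x))   # equivalent, more explicit spelling of the same check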
@@ -1186,12 +1189,12 @@ def _test_knowledge_base_lookup():
     print("x.grad = ", x.grad)
     print("knowlege_base.grad norm = ", knowledge_base.grad.norm())
 
 
+    dtype = torch.float16
     device = torch.device('cuda')
-    train_pairs = [ (torch.randn(B, T, E, device=device), torch.randn(B, T, E, device=device)) for _ in range(10) ]
+    train_pairs = [ (torch.randn(B, T, E, device=device, dtype=dtype), torch.randn(B, T, E, device=device, dtype=dtype)) for _ in range(10) ]
     from optim import Eve
-    optimizer = Eve(m.parameters(), lr=0.005)
-    m = m.to(device)
+    optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04)
+    m = m.to(device).to(dtype)
 
 
     start = timeit.default_timer()
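Note: raising eps to 1.0e-04 matters because Eve, like other Adam-style optimizers, adds eps in the update's denominator. In float16 anything below about 6e-5 is subnormal, and a tiny default eps simply flushes to zero, so the denominator can vanish for parameters with near-zero gradient statistics (assuming the optimizer state follows the model's dtype here). A quick check of the representable range:

import torch

print(torch.tensor(1.0e-08, dtype=torch.float16))   # tensor(0., dtype=torch.float16): underflows
print(torch.tensor(1.0e-04, dtype=torch.float16))   # ~1.0e-04: still representable
print(torch.finfo(torch.float16).tiny)              # 6.103515625e-05, smallest normal float16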
@@ -1212,7 +1215,7 @@ def _test_knowledge_base_lookup():
     for epoch in range(120):
         for n, (x,y) in enumerate(train_pairs):
             y_out = m(x)
-            loss = ((y_out - y)**2).mean()
+            loss = ((y_out - y)**2).mean() * 100.0
             if n % 10 == 0 and epoch % 10 == 0:
                 print(f"Epoch {epoch}, batch {n}, loss {loss.item()}")
             loss.backward()
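Note: multiplying the loss by 100.0 is static loss scaling: scaling the loss scales every gradient by the same factor, lifting small fp16 gradients out of the underflow region (the same idea GradScaler, imported above, automates with a dynamic factor). A toy example:

import torch

w = torch.tensor([1.0], dtype=torch.float16, requires_grad=True)
loss = (w * 1.0e-06).sum()   # the true gradient, 1e-6, sits in fp16's subnormal range

(loss * 100.0).backward()    # scale the loss; every gradient scales with it
print(w.grad)                # ~1e-4, comfortably representable in float16

In a full training setup the gradients (or the learning rate) would be divided back by the scale factor before the step; in this test the constant is simply absorbed into the effective learning rate.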
|
Loading…
x
Reference in New Issue
Block a user