Add some training code. Seems to be training successfully...

Daniel Povey 2022-04-24 23:19:46 +08:00
parent df39fc6783
commit f8c7e6ffb3


@@ -8,6 +8,7 @@ from torch import Tensor
 from torch import nn
 from typing import Tuple, Optional
 from scaling import ScaledLinear
+import random
 
 # The main export of this file is the function sample_combined().
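
For orientation: sample_combined() is called below as `weights, indexes = sample_combined(x, K, input_is_log=True)` on (*, N, M) normalized logprobs, i.e. N independent softmaxes over M symbols. A naive stand-in under those assumed shapes (independent sampling with replacement; the real function is presumably smarter about duplicates and weighting) might look like:

import torch

# Naive illustration only -- treat the shapes and semantics as assumptions,
# not a description of the real sample_combined().
def naive_sample_combined(logprobs, K, input_is_log=True):
    # logprobs: (*, N, M) normalized log-probabilities
    p = logprobs.exp() if input_is_log else logprobs
    *batch, N, M = p.shape
    flat = p.reshape(-1, M)                             # (B*N, M)
    idx = torch.multinomial(flat, K, replacement=True)  # (B*N, K)
    idx = idx.reshape(*batch, N, K).transpose(-2, -1)   # (*, K, N)
    # weight of each joint sample = product of its per-group probabilities
    w = p.gather(-1, idx.transpose(-2, -1)).transpose(-2, -1).prod(dim=-1)
    return w, idx                                       # (*, K), (*, K, N)
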
@@ -908,7 +909,8 @@ class KnowledgeBaseLookup(nn.Module):
                  knowledge_base: nn.Parameter):
         super(KnowledgeBaseLookup, self).__init__()
         self.knowledge_base = knowledge_base  # shared!
-        self.in_proj = ScaledLinear(embedding_dim, M * N)
+        self.in_proj = ScaledLinear(embedding_dim, M * N,
+                                    initial_scale=5.0)
         # initial_scale = 5.0 because the knowledge_base activations are
         # quite small -- if we use our optimizer they'll have stddev <= 0.1.
         self.out_proj = ScaledLinear(D, embedding_dim,
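
ScaledLinear lives in scaling.py and is not part of this diff. As a rough sketch of the idea behind initial_scale (an illustration, not the actual class): a linear layer whose weight carries a learned scalar multiplier, so initial_scale=5.0 makes the layer's outputs roughly 5x larger at initialization:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative only: a linear layer with a learned log-scale on the weight,
# initialized so the output starts out `initial_scale` times larger.
class SketchScaledLinear(nn.Linear):
    def __init__(self, in_features, out_features, initial_scale=1.0, **kwargs):
        super().__init__(in_features, out_features, **kwargs)
        self.log_scale = nn.Parameter(torch.tensor(float(initial_scale)).log())

    def forward(self, x):
        return F.linear(x, self.weight * self.log_scale.exp(), self.bias)
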
@@ -930,6 +932,9 @@ class KnowledgeBaseLookup(nn.Module):
         x = self.in_proj(x)  # now (*, M*N)
         x = x.reshape(*x.shape[:-1], self.N, self.M)  # now (*, N, M)
         x = x.log_softmax(dim=-1)  # now normalized logprobs, dim= (*, N, M)
+        if random.random() < 0.01:
+            entropy = -(x * x.exp()).sum(dim=-1).mean()  # -sum(p * log p), averaged
+            print("Entropy = ", entropy)
         weights, indexes = sample_combined(x, self.K, input_is_log=True)
         indexes = join_indexes(indexes, self.M)
         x = WeightedMatrixLookupFunction.apply(weights, indexes, self.knowledge_base)  # now (*, D)
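
To make the shape flow above concrete, here is a worked toy example with hypothetical sizes (N=3 groups of M=4 logits each, knowledge base of M**N rows of dimension D=5). It uses a greedy K=1 stand-in for sample_combined() and packs the per-group indexes as base-M digits as a guess at what join_indexes() does:

import torch

B, N, M, D = 2, 3, 4, 5
knowledge_base = torch.randn(M ** N, D)

logits = torch.randn(B, N, M)
logprobs = logits.log_softmax(dim=-1)                # (B, N, M), as in forward()

# Greedy stand-in for sample_combined() with K=1: the argmax in each group.
indexes = logprobs.argmax(dim=-1)                    # (B, N)
weights = logprobs.max(dim=-1).values.sum(-1).exp()  # product of the N chosen probs

# Stand-in for join_indexes(): combine the N per-group indexes into one row
# index, treating them as base-M digits (digit order is an assumption).
powers = M ** torch.arange(N - 1, -1, -1)
flat = (indexes * powers).sum(dim=-1)                # (B,), values in [0, M**N)

out = weights.unsqueeze(-1) * knowledge_base[flat]   # (B, D)
print(out.shape)  # torch.Size([2, 5])
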
@@ -1151,7 +1156,19 @@ def _test_knowledge_base_lookup():
print("knowlege_base.grad norm = ", knowledge_base.grad.norm())
#train_pairs = [ (torch.randn(B, T, E), torch.randn(B, T, E)) for _ in range(100) ]
train_pairs = [ (torch.randn(B, T, E), torch.randn(B, T, E)) for _ in range(11) ]
from optim import Eve
optimizer = Eve(m.parameters(), lr=0.005)
for epoch in range(100):
for n, (x,y) in enumerate(train_pairs):
y_out = m(x)
loss = ((y_out - y)**2).mean()
if n % 10 == 0:
print(f"Epoch {epoch}, batch {n}, loss {loss.item()}")
loss.backward()
optimizer.step()
optimizer.zero_grad()
if __name__ == '__main__':
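
For reference, the added overfitting test follows the standard PyTorch training pattern. A self-contained sketch of the same loop, with torch.optim.Adam standing in for Eve (defined in optim.py, not shown here) and a plain nn.Linear standing in for the lookup module so it runs on its own:

import torch
from torch import nn

B, T, E = 2, 8, 16           # hypothetical sizes for illustration
m = nn.Linear(E, E)          # stand-in for KnowledgeBaseLookup
train_pairs = [(torch.randn(B, T, E), torch.randn(B, T, E)) for _ in range(11)]
optimizer = torch.optim.Adam(m.parameters(), lr=0.005)

for epoch in range(100):
    for n, (x, y) in enumerate(train_pairs):
        y_out = m(x)
        loss = ((y_out - y) ** 2).mean()
        if n % 10 == 0:
            print(f"Epoch {epoch}, batch {n}, loss {loss.item()}")
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

Since the targets are random, the loss can only fall by memorizing the 11 fixed pairs; a steadily decreasing printout is what "seems to be training successfully" refers to in the commit message.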