Add some training code. Seems to be training successfully...

Daniel Povey 2022-04-24 23:19:46 +08:00
parent df39fc6783
commit f8c7e6ffb3


@@ -8,6 +8,7 @@ from torch import Tensor
 from torch import nn
 from typing import Tuple, Optional
 from scaling import ScaledLinear
+import random

 # The main export of this file is the function sample_combined().
@@ -908,7 +909,8 @@ class KnowledgeBaseLookup(nn.Module):
                  knowledge_base: nn.Parameter):
         super(KnowledgeBaseLookup, self).__init__()
         self.knowledge_base = knowledge_base # shared!
-        self.in_proj = ScaledLinear(embedding_dim, M * N)
+        self.in_proj = ScaledLinear(embedding_dim, M * N,
+                                    initial_scale=5.0)
         # initial_scale = 4.0 because the knowlege_base activations are
         # quite small -- if we use our optimizer they'll have stddev <= 0.1.
         self.out_proj = ScaledLinear(D, embedding_dim,
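The initial_scale=5.0 added here boosts how large in_proj's outputs are at initialization (the "initial_scale = 4.0" comment refers to the out_proj line below it). ScaledLinear is this repo's own layer from scaling.py; purely as an illustration of the idea, not the repo's actual implementation, a linear layer with a learned output scale might look like this sketch (the name ToyScaledLinear and the log-space parameterization are assumptions):

import torch
from torch import nn

class ToyScaledLinear(nn.Linear):
    # Illustrative sketch only, NOT the ScaledLinear from scaling.py:
    # a linear layer whose output is multiplied by a separately learned
    # scale, initialized to `initial_scale`.
    def __init__(self, in_features: int, out_features: int,
                 initial_scale: float = 1.0):
        super().__init__(in_features, out_features)
        # Keep the scale in log space so it stays positive during training.
        self.log_scale = nn.Parameter(torch.tensor(float(initial_scale)).log())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return super().forward(x) * self.log_scale.exp()

proj = ToyScaledLinear(256, 30 * 16, initial_scale=5.0)
print(proj(torch.randn(2, 256)).std())  # roughly 5x a default nn.Linear

Presumably the larger initial scale keeps the logits feeding the softmax below from starting out too flat, given that (per the comment) the optimizer keeps activations around stddev <= 0.1.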
@@ -930,6 +932,9 @@ class KnowledgeBaseLookup(nn.Module):
         x = self.in_proj(x) # now (*, M*N)
         x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M)
         x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M)
+        if random.random() < 0.01:
+            entropy = (x * x.exp()).sum(dim=-1).mean()
+            print("Entropy = ", entropy)
         weights, indexes, = sample_combined(x, self.K, input_is_log=True)
         indexes = join_indexes(indexes, self.M)
         x = WeightedMatrixLookupFunction.apply(weights, indexes, self.knowledge_base) # now (*, D)
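One detail worth noting about the new diagnostic: since x holds log-probabilities, (x * x.exp()).sum(dim=-1) is sum_j p_j * log(p_j), i.e. the negative of the usual entropy, so the printed "Entropy" lies in [-log M, 0]. Values near 0 mean the distributions have collapsed onto a few knowledge-base indexes; values near -log M mean they are close to uniform. A standalone check (the shapes here are made up):

import torch

x = torch.randn(4, 3, 8).log_softmax(dim=-1)    # log-probs with M = 8
neg_entropy = (x * x.exp()).sum(dim=-1).mean()  # sum_j p_j * log(p_j) <= 0
print(neg_entropy)   # in [-log(8), 0]; near 0 => peaked, near -2.08 => uniform
print(-neg_entropy)  # the conventional non-negative entropy H(p)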
@@ -1151,7 +1156,19 @@ def _test_knowledge_base_lookup():
     print("knowlege_base.grad norm = ", knowledge_base.grad.norm())
-    #train_pairs = [ (torch.randn(B, T, E), torch.randn(B, T, E)) for _ in range(100) ]
+    train_pairs = [ (torch.randn(B, T, E), torch.randn(B, T, E)) for _ in range(11) ]
+    from optim import Eve
+    optimizer = Eve(m.parameters(), lr=0.005)
+
+    for epoch in range(100):
+        for n, (x,y) in enumerate(train_pairs):
+            y_out = m(x)
+            loss = ((y_out - y)**2).mean()
+            if n % 10 == 0:
+                print(f"Epoch {epoch}, batch {n}, loss {loss.item()}")
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()

 if __name__ == '__main__':
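Eve is the project's own optimizer (imported from its optim.py), so the new test only runs inside this repo. As a rough, self-contained approximation using only stock PyTorch, with a plain MLP standing in for the KnowledgeBaseLookup model m and torch.optim.Adam standing in for Eve (all names and dimensions below are illustrative assumptions):

import torch
from torch import nn

B, T, E = 2, 10, 16  # batch, time, embedding dims, chosen arbitrarily
m = nn.Sequential(nn.Linear(E, 64), nn.ReLU(), nn.Linear(64, E))
# A tiny fixed dataset, as in the test above: 11 random (input, target) pairs.
train_pairs = [(torch.randn(B, T, E), torch.randn(B, T, E)) for _ in range(11)]
optimizer = torch.optim.Adam(m.parameters(), lr=0.005)

for epoch in range(100):
    for n, (x, y) in enumerate(train_pairs):
        loss = ((m(x) - y) ** 2).mean()
        if n % 10 == 0:
            print(f"Epoch {epoch}, batch {n}, loss {loss.item():.4f}")
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

With only 11 fixed pairs the model should be able to memorize them, so a steadily falling loss is the "training successfully" signal the commit message refers to.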