mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 14:44:18 +00:00
Add some training code. Seems to be training successfully...
This commit is contained in:
parent
df39fc6783
commit
f8c7e6ffb3
@ -8,6 +8,7 @@ from torch import Tensor
|
||||
from torch import nn
|
||||
from typing import Tuple, Optional
|
||||
from scaling import ScaledLinear
|
||||
import random
|
||||
|
||||
# The main export of this file is the function sample_combined().
|
||||
|
||||
@ -908,7 +909,8 @@ class KnowledgeBaseLookup(nn.Module):
|
||||
knowledge_base: nn.Parameter):
|
||||
super(KnowledgeBaseLookup, self).__init__()
|
||||
self.knowledge_base = knowledge_base # shared!
|
||||
self.in_proj = ScaledLinear(embedding_dim, M * N)
|
||||
self.in_proj = ScaledLinear(embedding_dim, M * N,
|
||||
initial_scale=5.0)
|
||||
# initial_scale = 4.0 because the knowlege_base activations are
|
||||
# quite small -- if we use our optimizer they'll have stddev <= 0.1.
|
||||
self.out_proj = ScaledLinear(D, embedding_dim,
|
||||
@ -930,6 +932,9 @@ class KnowledgeBaseLookup(nn.Module):
|
||||
x = self.in_proj(x) # now (*, M*N)
|
||||
x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M)
|
||||
x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M)
|
||||
if random.random() < 0.01:
|
||||
entropy = (x * x.exp()).sum(dim=-1).mean()
|
||||
print("Entropy = ", entropy)
|
||||
weights, indexes, = sample_combined(x, self.K, input_is_log=True)
|
||||
indexes = join_indexes(indexes, self.M)
|
||||
x = WeightedMatrixLookupFunction.apply(weights, indexes, self.knowledge_base) # now (*, D)
|
||||
@ -1151,7 +1156,19 @@ def _test_knowledge_base_lookup():
|
||||
print("knowlege_base.grad norm = ", knowledge_base.grad.norm())
|
||||
|
||||
|
||||
#train_pairs = [ (torch.randn(B, T, E), torch.randn(B, T, E)) for _ in range(100) ]
|
||||
train_pairs = [ (torch.randn(B, T, E), torch.randn(B, T, E)) for _ in range(11) ]
|
||||
from optim import Eve
|
||||
optimizer = Eve(m.parameters(), lr=0.005)
|
||||
|
||||
for epoch in range(100):
|
||||
for n, (x,y) in enumerate(train_pairs):
|
||||
y_out = m(x)
|
||||
loss = ((y_out - y)**2).mean()
|
||||
if n % 10 == 0:
|
||||
print(f"Epoch {epoch}, batch {n}, loss {loss.item()}")
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Loading…
x
Reference in New Issue
Block a user