Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-07 16:14:17 +00:00)
Add some training code. Seems to be training successfully...

parent df39fc6783
commit f8c7e6ffb3
@@ -8,6 +8,7 @@ from torch import Tensor
 from torch import nn
 from typing import Tuple, Optional
 from scaling import ScaledLinear
+import random
 
 # The main export of this file is the function sample_combined().
 
@@ -908,7 +909,8 @@ class KnowledgeBaseLookup(nn.Module):
                  knowledge_base: nn.Parameter):
         super(KnowledgeBaseLookup, self).__init__()
         self.knowledge_base = knowledge_base  # shared!
-        self.in_proj = ScaledLinear(embedding_dim, M * N)
+        self.in_proj = ScaledLinear(embedding_dim, M * N,
+                                    initial_scale=5.0)
         # initial_scale = 4.0 because the knowlege_base activations are
         # quite small -- if we use our optimizer they'll have stddev <= 0.1.
         self.out_proj = ScaledLinear(D, embedding_dim,
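For readers unfamiliar with icefall's ScaledLinear: roughly speaking it behaves like nn.Linear but carries a learnable scale on its parameters, and initial_scale controls how large the layer's outputs start out. A rough plain-PyTorch illustration of the effect only (not the actual ScaledLinear implementation; the sizes below are made up):

    import torch
    import torch.nn as nn

    # Rough stand-in for initial_scale: make a linear layer whose initial
    # outputs are about 5x larger than PyTorch's default initialization.
    embedding_dim, M, N = 512, 4, 256        # placeholder sizes
    in_proj = nn.Linear(embedding_dim, M * N)
    with torch.no_grad():
        in_proj.weight.mul_(5.0)
        in_proj.bias.mul_(5.0)

    x = torch.randn(8, embedding_dim)
    print(in_proj(x).std())                  # ~5x the std of an unscaled layer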
@@ -930,6 +932,9 @@ class KnowledgeBaseLookup(nn.Module):
         x = self.in_proj(x) # now (*, M*N)
         x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M)
         x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M)
+        if random.random() < 0.01:
+            entropy = (x * x.exp()).sum(dim=-1).mean()
+            print("Entropy = ", entropy)
         weights, indexes, = sample_combined(x, self.K, input_is_log=True)
         indexes = join_indexes(indexes, self.M)
         x = WeightedMatrixLookupFunction.apply(weights, indexes, self.knowledge_base) # now (*, D)
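One detail worth noting about the new diagnostic: since x holds log-probabilities, (x * x.exp()).sum(dim=-1) is sum_i p_i * log p_i, i.e. the negative entropy, so the printed "Entropy" value is <= 0 (values near 0 indicate the per-group distributions have collapsed onto one of the M entries). A minimal standalone sketch of the same computation, with made-up shapes:

    import torch

    logits = torch.randn(2, 10, 4, 8)                # (*, N, M); shapes made up
    logp = logits.log_softmax(dim=-1)                # normalized log-probs over M

    neg_entropy = (logp * logp.exp()).sum(dim=-1).mean()   # what the diff prints
    entropy = -neg_entropy                                  # conventional H(p) >= 0
    print("printed value:", neg_entropy.item(), " entropy:", entropy.item())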
@@ -1151,7 +1156,19 @@ def _test_knowledge_base_lookup():
     print("knowlege_base.grad norm = ", knowledge_base.grad.norm())
 
 
-    #train_pairs = [ (torch.randn(B, T, E), torch.randn(B, T, E)) for _ in range(100) ]
+    train_pairs = [ (torch.randn(B, T, E), torch.randn(B, T, E)) for _ in range(11) ]
+    from optim import Eve
+    optimizer = Eve(m.parameters(), lr=0.005)
+
+    for epoch in range(100):
+        for n, (x,y) in enumerate(train_pairs):
+            y_out = m(x)
+            loss = ((y_out - y)**2).mean()
+            if n % 10 == 0:
+                print(f"Epoch {epoch}, batch {n}, loss {loss.item()}")
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
 
 
 if __name__ == '__main__':
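The added test code simply overfits the lookup module on 11 random (x, y) pairs with icefall's Eve optimizer (from optim.py) and prints the MSE loss, which should fall steadily if gradients flow through the sampled lookup. Below is a self-contained sketch of the same smoke test with torch.optim.Adam standing in for Eve; the module and sizes are placeholders:

    import torch
    import torch.nn as nn

    def overfit_smoke_test(m: nn.Module, B=2, T=10, E=256, epochs=100):
        # Overfit m on a handful of random pairs; the loss should decrease
        # if gradients flow correctly through the module.
        train_pairs = [(torch.randn(B, T, E), torch.randn(B, T, E)) for _ in range(11)]
        optimizer = torch.optim.Adam(m.parameters(), lr=0.005)  # Adam in place of Eve
        for epoch in range(epochs):
            for n, (x, y) in enumerate(train_pairs):
                loss = ((m(x) - y) ** 2).mean()
                if n % 10 == 0:
                    print(f"Epoch {epoch}, batch {n}, loss {loss.item()}")
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

    # Example: any module mapping (B, T, E) -> (B, T, E) works here,
    # e.g. a plain linear layer in place of KnowledgeBaseLookup.
    overfit_smoke_test(nn.Linear(256, 256))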