From f8c7e6ffb36f92ad616ea131062f4dff0e315755 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sun, 24 Apr 2022 23:19:46 +0800
Subject: [PATCH] Add some training code. Seems to be training successfully...

---
 .../ASR/pruned2_knowledge/sampling.py         | 21 +++++++++++++++++--
 1 file changed, 19 insertions(+), 2 deletions(-)

diff --git a/egs/librispeech/ASR/pruned2_knowledge/sampling.py b/egs/librispeech/ASR/pruned2_knowledge/sampling.py
index bfa3d0768..e68ea153b 100644
--- a/egs/librispeech/ASR/pruned2_knowledge/sampling.py
+++ b/egs/librispeech/ASR/pruned2_knowledge/sampling.py
@@ -8,6 +8,7 @@ from torch import Tensor
 from torch import nn
 from typing import Tuple, Optional
 from scaling import ScaledLinear
+import random
 
 # The main export of this file is the function sample_combined().
 
@@ -908,7 +909,8 @@ class KnowledgeBaseLookup(nn.Module):
                  knowledge_base: nn.Parameter):
         super(KnowledgeBaseLookup, self).__init__()
         self.knowledge_base = knowledge_base  # shared!
-        self.in_proj = ScaledLinear(embedding_dim, M * N)
+        self.in_proj = ScaledLinear(embedding_dim, M * N,
+                                    initial_scale=5.0)
         # initial_scale = 4.0 because the knowlege_base activations are
         # quite small -- if we use our optimizer they'll have stddev <= 0.1.
         self.out_proj = ScaledLinear(D, embedding_dim,
@@ -930,6 +932,9 @@ class KnowledgeBaseLookup(nn.Module):
         x = self.in_proj(x) # now (*, M*N)
         x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M)
         x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M)
+        if random.random() < 0.01:
+            entropy = (x * x.exp()).sum(dim=-1).mean()
+            print("Entropy = ", entropy)
         weights, indexes, = sample_combined(x, self.K, input_is_log=True)
         indexes = join_indexes(indexes, self.M)
         x = WeightedMatrixLookupFunction.apply(weights, indexes, self.knowledge_base)  # now (*, D)
@@ -1151,7 +1156,19 @@ def _test_knowledge_base_lookup():
 
     print("knowlege_base.grad norm = ", knowledge_base.grad.norm())
 
-    #train_pairs = [ (torch.randn(B, T, E), torch.randn(B, T, E)) for _ in range(100) ]
+    train_pairs = [ (torch.randn(B, T, E), torch.randn(B, T, E)) for _ in range(11) ]
+    from optim import Eve
+    optimizer = Eve(m.parameters(), lr=0.005)
+
+    for epoch in range(100):
+        for n, (x,y) in enumerate(train_pairs):
+            y_out = m(x)
+            loss = ((y_out - y)**2).mean()
+            if n % 10 == 0:
+                print(f"Epoch {epoch}, batch {n}, loss {loss.item()}")
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
 
 
 if __name__ == '__main__':
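
Note on the entropy diagnostic added in the third hunk: after log_softmax, x
holds normalized log-probabilities, so (x * x.exp()).sum(dim=-1) computes
sum_j p_j * log p_j, i.e. the negative entropy of each of the N per-group
distributions. The printed value therefore lies in [-log(M), 0] and moves
toward 0 as the distributions sharpen. A standalone sketch of the same
computation, with tensor sizes invented purely for illustration:

    import torch

    # logp plays the role of `x` after `x.log_softmax(dim=-1)` in forward():
    # normalized log-probabilities of shape (*, N, M), here with N=2, M=8.
    logp = torch.randn(4, 2, 8).log_softmax(dim=-1)

    # Sum of p * log(p) over the class dim, averaged over everything else.
    neg_entropy = (logp * logp.exp()).sum(dim=-1).mean()

    # Uniform distributions give -log(M); one-hot distributions give 0.
    print("Entropy = ", neg_entropy)  # value in [-log(8), 0]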
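The forward() path touched by the second and third hunks follows a fixed
shape contract: (*, embedding_dim) -> in_proj -> (*, M*N) -> reshape ->
(*, N, M) -> log_softmax -> sample_combined()/join_indexes() -> weighted
lookup into knowledge_base -> (*, D) -> out_proj -> (*, embedding_dim).
The mock below reproduces only the shapes: torch.multinomial with uniform
weights stands in for the real sample_combined(), and the index folding is
only a plausible guess at what join_indexes() does, not the recipe's actual
algorithm; all sizes are invented.

    import torch

    B, T, E = 2, 5, 16         # batch, time, embedding_dim (invented)
    M, N, K, D = 8, 2, 4, 32   # classes per group, groups, samples, KB dim
    knowledge_base = torch.randn(M ** N, D)
    in_proj = torch.nn.Linear(E, M * N)
    out_proj = torch.nn.Linear(D, E)

    x = torch.randn(B, T, E)
    logp = in_proj(x).reshape(B, T, N, M).log_softmax(dim=-1)  # as in forward()

    # Mock of sample_combined(): K independent draws per group, uniform
    # weights, instead of the patch's combined weighted sampling.
    indexes = torch.multinomial(logp.exp().reshape(-1, M), K, replacement=True)
    indexes = indexes.reshape(B, T, N, K).transpose(-2, -1)   # (B, T, K, N)
    weights = torch.full((B, T, K), 1.0 / K)

    # Mock of join_indexes(): fold N per-group indexes into one row index.
    powers = M ** torch.arange(N)                             # (N,)
    rows = (indexes * powers).sum(dim=-1)                     # (B, T, K)

    looked_up = knowledge_base[rows]                          # (B, T, K, D)
    y = (weights.unsqueeze(-1) * looked_up).sum(dim=-2)       # (B, T, D)
    y = out_proj(y)                                           # (B, T, E)
    print(y.shape)                                            # [2, 5, 16]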
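The smoke-test loop in the last hunk imports Eve from the recipe-local
optim.py. To run an equivalent test without an icefall checkout, here is a
sketch that keeps the loop structure but substitutes torch.optim.Adam for
Eve and a plain Linear for the KnowledgeBaseLookup module `m` under test;
both substitutions are assumptions for illustration only:

    import torch

    B, T, E = 2, 5, 16         # invented; the patch's B, T, E come from the test
    m = torch.nn.Linear(E, E)  # stand-in for the KnowledgeBaseLookup module `m`
    train_pairs = [(torch.randn(B, T, E), torch.randn(B, T, E)) for _ in range(11)]
    optimizer = torch.optim.Adam(m.parameters(), lr=0.005)  # Adam in place of Eve

    for epoch in range(100):
        for n, (x, y) in enumerate(train_pairs):
            y_out = m(x)
            loss = ((y_out - y) ** 2).mean()  # MSE against the random targets
            if n % 10 == 0:
                print(f"Epoch {epoch}, batch {n}, loss {loss.item()}")
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

Since the targets are fixed random tensors, a steadily falling loss only
shows that gradients flow end to end, which is all the patch's "Seems to be
training successfully" claim requires of this test.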