From aea116ea25ceea5098dfb2f9c561fca71fbac43e Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Mon, 25 Apr 2022 14:02:43 +0800
Subject: [PATCH] Change printing-prob, initial scales

---
 egs/librispeech/ASR/pruned2_knowledge/sampling.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/egs/librispeech/ASR/pruned2_knowledge/sampling.py b/egs/librispeech/ASR/pruned2_knowledge/sampling.py
index 02cac6748..fa9502d20 100644
--- a/egs/librispeech/ASR/pruned2_knowledge/sampling.py
+++ b/egs/librispeech/ASR/pruned2_knowledge/sampling.py
@@ -916,11 +916,11 @@ class KnowledgeBaseLookup(nn.Module):
         super(KnowledgeBaseLookup, self).__init__()
         self.knowledge_base = knowledge_base # shared!
         self.in_proj = ScaledLinear(embedding_dim, M * N,
-                                    initial_scale=5.0)
+                                    initial_scale=1.0)
         # initial_scale = 4.0 because the knowlege_base activations are
         # quite small -- if we use our optimizer they'll have stddev <= 0.1.
         self.out_proj = ScaledLinear(D, embedding_dim,
-                                     initial_scale = 10.0)
+                                     initial_scale = 4.0)
         self.M = M
         self.N = N
         self.K = K
@@ -938,7 +938,7 @@ class KnowledgeBaseLookup(nn.Module):
         x = self.in_proj(x) # now (*, M*N)
         x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M)
         x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M)
-        if random.random() < 0.01:
+        if random.random() < 0.001:
             entropy = (x * x.exp()).sum(dim=-1).mean()
             print("Entropy = ", entropy)
         weights, indexes, = sample_combined(x, self.K, input_is_log=True)