diff --git a/egs/librispeech/ASR/pruned2_knowledge/sampling.py b/egs/librispeech/ASR/pruned2_knowledge/sampling.py index 4871e5981..7b05e2f00 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/sampling.py +++ b/egs/librispeech/ASR/pruned2_knowledge/sampling.py @@ -214,9 +214,9 @@ class KnowledgeBaseLookup(nn.Module): x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M) x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M) x = PenalizeNegentropyFunction.apply(x, self.negentropy_penalty) - weights, indexes, = sample_combined(x, self.K, input_is_log=True) - indexes = join_indexes(indexes, self.M) - x = WeightedMatrixLookupFunction.apply(weights, indexes, self.knowledge_base) # now (*, D) + + _, indexes, weights = sample_combined(x, self.K, input_is_log=True) + x = weighted_matrix_lookup(weights, indexes, self.knowledge_base) # now (*, D) x = self.out_proj(x) # now (*, self.embedding_dim) return x