Try to resolve merge issues etc

This commit is contained in:
Daniel Povey 2022-05-13 11:32:23 +08:00
parent 4f933f5413
commit 44f4aa5f66

View File

@ -214,9 +214,9 @@ class KnowledgeBaseLookup(nn.Module):
x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M)
x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M)
x = PenalizeNegentropyFunction.apply(x, self.negentropy_penalty)
weights, indexes, = sample_combined(x, self.K, input_is_log=True)
indexes = join_indexes(indexes, self.M)
x = WeightedMatrixLookupFunction.apply(weights, indexes, self.knowledge_base) # now (*, D)
_, indexes, weights = sample_combined(x, self.K, input_is_log=True)
x = weighted_matrix_lookup(weights, indexes, self.knowledge_base) # now (*, D)
x = self.out_proj(x) # now (*, self.embedding_dim)
return x