From 0bf538a4a37ff47bd95fccd5bff2783ca9a1c21d Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Tue, 10 May 2022 13:20:10 +0800
Subject: [PATCH] Add negentropy_penalty, on individual dims.

---
 .../ASR/pruned2_knowledge/sampling.py | 53 ++++++++++++++++---
 1 file changed, 46 insertions(+), 7 deletions(-)

diff --git a/egs/librispeech/ASR/pruned2_knowledge/sampling.py b/egs/librispeech/ASR/pruned2_knowledge/sampling.py
index a2b7b3a9e..a9e0a1848 100644
--- a/egs/librispeech/ASR/pruned2_knowledge/sampling.py
+++ b/egs/librispeech/ASR/pruned2_knowledge/sampling.py
@@ -935,6 +935,48 @@ class WeightedMatrixLookupFunction(torch.autograd.Function):
         return weights_grad.to(dtype), None, knowledge_base.grad.to(dtype)
 
 
+class PenalizeNegentropyFunction(torch.autograd.Function):
+    """
+    Function that does nothing in the forward pass, but in backprop it is as
+    if you had added `- tot_entropy * alpha` to the loss function, where
+    tot_entropy is the entropy of the average of the input distributions,
+    times the number of input distributions.  (We multiply by this because
+    our overall loss function is proportional to the number of frames.)
+
+    This will tend to make the entropy become as large as possible, making
+    (-tot_entropy * alpha) as negative as possible.
+
+    Args:
+       logprobs: Tensor of shape (*, num_classes), should be the result of
+             calling some_tensor.log_softmax(dim=-1)
+    Returns:
+       logprobs
+    """
+    @staticmethod
+    def forward(ctx, logprobs: Tensor, alpha: float):
+        ctx.save_for_backward(logprobs.detach())
+        ctx.alpha = alpha
+        return logprobs
+
+    @staticmethod
+    def backward(ctx, logprobs_grad: Tensor) -> Tuple[Tensor, None]:
+        logprobs, = ctx.saved_tensors
+        with torch.enable_grad():
+            logprobs.requires_grad = True
+            # `negentropy` is the negative entropy of the average of the input
+            # distributions.  It will be <= 0.
+            l = logprobs.reshape(-1, logprobs.shape[-1])
+            scale = ctx.alpha * l.shape[0]
+            avg_dist = l.exp().mean(dim=0)
+            negentropy = (avg_dist * (avg_dist + 1.0e-20).log()).sum()
+            if random.random() < 0.0005:
+                negentropy_individual = (l * l.exp()).sum(dim=-1).mean()
+                print("Negentropy[individual,combined] = ", negentropy_individual.item(), ", ", negentropy.item())
+            loss = negentropy * scale
+            loss.backward()
+        return logprobs_grad + logprobs.grad, None
+
+
 class KnowledgeBaseLookup(nn.Module):
     """
     Create knowledge-base lookup module.  (The knowledge-base parameter, which is
@@ -949,7 +991,8 @@ class KnowledgeBaseLookup(nn.Module):
     """
     def __init__(self, M: int, N: int, D: int, K: int, embedding_dim: int,
-                 knowledge_base: nn.Parameter):
+                 knowledge_base: nn.Parameter,
+                 negentropy_penalty: float = 0.001):
         super(KnowledgeBaseLookup, self).__init__()
         self.knowledge_base = knowledge_base  # shared!
         self.in_proj = ScaledLinear(embedding_dim, M * N,
@@ -961,6 +1004,7 @@ class KnowledgeBaseLookup(nn.Module):
         self.M = M
         self.N = N
         self.K = K
+        self.negentropy_penalty = negentropy_penalty
 
     def forward(self, x: Tensor) -> Tensor:
         """
@@ -972,15 +1016,10 @@ class KnowledgeBaseLookup(nn.Module):
         # TODO: later we can try multiplying by a projection of x or something like that.
""" - assert torch.all(x - x == 0) x = self.in_proj(x) # now (*, M*N) - assert torch.all(x - x == 0) x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M) x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M) - assert torch.all(x - x == 0) - if random.random() < 0.001: - entropy = (x * x.exp()).sum(dim=-1).mean() - print("Entropy = ", entropy) + x = PenalizeNegentropyFunction.apply(x, self.negentropy_penalty) weights, indexes, = sample_combined(x, self.K, input_is_log=True) indexes = join_indexes(indexes, self.M) x = WeightedMatrixLookupFunction.apply(weights, indexes, self.knowledge_base) # now (*, D)