Add negentropy_penalty, on individual dims.

Daniel Povey 2022-05-10 13:20:10 +08:00
parent 551786b9bd
commit 0bf538a4a3


@@ -935,6 +935,48 @@ class WeightedMatrixLookupFunction(torch.autograd.Function):
        return weights_grad.to(dtype), None, knowledge_base.grad.to(dtype)

class PenalizeNegentropyFunction(torch.autograd.Function):
    """
    Function that does nothing in the forward pass, but in backprop it is as
    if you had added `- tot_entropy * alpha` to the loss function, where
    tot_entropy is the entropy of the average of the input distributions,
    times the number of input distributions.  (We multiply by this because
    our overall loss function is proportional to the number of frames.)

    This will tend to push the entropy to become as large as possible,
    making (-tot_entropy * alpha) as negative as possible.

    Args:
       logprobs: Tensor of shape (*, num_classes), should be the result of
             calling some_tensor.log_softmax(dim=-1)
    Returns:
       logprobs
    """
    @staticmethod
    def forward(ctx, logprobs: Tensor, alpha: float):
        ctx.save_for_backward(logprobs.detach())
        ctx.alpha = alpha
        return logprobs

    @staticmethod
    def backward(ctx, logprobs_grad: Tensor) -> Tuple[Tensor, None]:
        logprobs, = ctx.saved_tensors
        with torch.enable_grad():
            logprobs.requires_grad = True
            # `negentropy` is the negative entropy of the average of the
            # input distributions.  It will be <= 0.
            l = logprobs.reshape(-1, logprobs.shape[-1])
            scale = ctx.alpha * l.shape[0]
            avg_dist = l.exp().mean(dim=0)
            negentropy = (avg_dist * (avg_dist + 1.0e-20).log()).sum()
            if random.random() < 0.0005:
                negentropy_individual = (l * l.exp()).sum(dim=-1).mean()
                print("Negentropy[individual,combined] = ",
                      negentropy_individual.item(), ", ", negentropy.item())
            loss = negentropy * scale
            loss.backward()
        return logprobs_grad + logprobs.grad, None
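
A minimal sanity-check sketch of how this behaves (an illustration assuming PenalizeNegentropyFunction is importable from this file; not part of the diff above): with a zero upstream gradient, the only gradient that reaches the input is the penalty term added in backward(), and it should match differentiating alpha * num_rows * negentropy(avg_dist) directly.

import torch

# Hypothetical check: the forward pass is the identity; backward() adds the
# gradient of `alpha * num_rows * negentropy(avg_dist)`.
logits = torch.randn(4, 6, requires_grad=True)
alpha = 0.001

logprobs = logits.log_softmax(dim=-1)
out = PenalizeNegentropyFunction.apply(logprobs, alpha)
out.backward(gradient=torch.zeros_like(out))  # zero upstream grad isolates the penalty
grad_via_function = logits.grad.clone()

# Reference: write the penalty out explicitly and let autograd differentiate it.
logits.grad = None
logprobs = logits.log_softmax(dim=-1)
avg_dist = logprobs.exp().mean(dim=0)
negentropy = (avg_dist * (avg_dist + 1.0e-20).log()).sum()
(alpha * logprobs.shape[0] * negentropy).backward()

assert torch.allclose(grad_via_function, logits.grad, atol=1.0e-6)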

class KnowledgeBaseLookup(nn.Module):
    """
    Create knowledge-base lookup module. (The knowledge-base parameter, which is
@@ -949,7 +991,8 @@ class KnowledgeBaseLookup(nn.Module):
"""
def __init__(self, M: int, N: int, D: int,
K: int, embedding_dim: int,
knowledge_base: nn.Parameter):
knowledge_base: nn.Parameter,
negentropy_penalty: float = 0.001):
super(KnowledgeBaseLookup, self).__init__()
self.knowledge_base = knowledge_base # shared!
self.in_proj = ScaledLinear(embedding_dim, M * N,
@@ -961,6 +1004,7 @@ class KnowledgeBaseLookup(nn.Module):
        self.M = M
        self.N = N
        self.K = K
        self.negentropy_penalty = negentropy_penalty

    def forward(self, x: Tensor) -> Tensor:
        """
@@ -972,15 +1016,10 @@ class KnowledgeBaseLookup(nn.Module):
        # TODO: later we can try multiplying by a projection of x or something like that.
        """
        assert torch.all(x - x == 0)
        x = self.in_proj(x)  # now (*, M*N)
        assert torch.all(x - x == 0)
        x = x.reshape(*x.shape[:-1], self.N, self.M)  # now (*, N, M)
        x = x.log_softmax(dim=-1)  # now normalized logprobs, dim= (*, N, M)
        assert torch.all(x - x == 0)
        if random.random() < 0.001:
            entropy = (x * x.exp()).sum(dim=-1).mean()
            print("Entropy = ", entropy)
        x = PenalizeNegentropyFunction.apply(x, self.negentropy_penalty)
        weights, indexes, = sample_combined(x, self.K, input_is_log=True)
        indexes = join_indexes(indexes, self.M)
        x = WeightedMatrixLookupFunction.apply(weights, indexes, self.knowledge_base)  # now (*, D)
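
For context, a rough usage sketch of the module with the new argument (sizes are illustrative assumptions, not from the diff; the reshape to (*, N, M) above suggests the knowledge base has M ** N rows of dimension D, and sample_combined, join_indexes and ScaledLinear live in the same file):

import torch
import torch.nn as nn

# Illustrative sizes (assumptions).
M, N, D, K, embedding_dim = 4, 3, 16, 2, 32
knowledge_base = nn.Parameter(torch.randn(M ** N, D) * 0.1)

lookup = KnowledgeBaseLookup(M, N, D, K, embedding_dim,
                             knowledge_base,
                             negentropy_penalty=0.001)

x = torch.randn(5, 10, embedding_dim)  # e.g. (batch, time, embedding_dim)
y = lookup(x)  # the hunk above ends at shape (*, D); later lines of forward() are not shown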