https://github.com/k2-fsa/icefall.git
commit 0bf538a4a3 (parent 551786b9bd)

Add negentropy_penalty, on individual dims.
@@ -935,6 +935,48 @@ class WeightedMatrixLookupFunction(torch.autograd.Function):
         return weights_grad.to(dtype), None, knowledge_base.grad.to(dtype)
 
 
+class PenalizeNegentropyFunction(torch.autograd.Function):
+    """
+    Function that does nothing in the forward pass, but in backprop it is as
+    if you had added `-tot_entropy * alpha` to the loss function, where
+    tot_entropy is the entropy of the average of the input distributions,
+    times the number of input distributions.  (We multiply by this because
+    our overall loss function is proportional to the number of frames.)
+
+    This will tend to make the entropy as large as possible, making
+    (-tot_entropy * alpha) as negative as possible.
+
+    Args:
+       logprobs: Tensor of shape (*, num_classes), should be the result of
+             calling some_tensor.log_softmax(dim=-1)
+    Returns:
+       logprobs
+    """
+    @staticmethod
+    def forward(ctx, logprobs: Tensor, alpha: float):
+        ctx.save_for_backward(logprobs.detach())
+        ctx.alpha = alpha
+        return logprobs
+
+    @staticmethod
+    def backward(ctx, logprobs_grad: Tensor) -> Tuple[Tensor, None]:
+        logprobs, = ctx.saved_tensors
+        with torch.enable_grad():
+            logprobs.requires_grad = True
+            # `negentropy` is the negative entropy of the average of the input
+            # distributions.  It will be <= 0.
+            l = logprobs.reshape(-1, logprobs.shape[-1])
+            scale = ctx.alpha * l.shape[0]
+            avg_dist = l.exp().mean(dim=0)
+            negentropy = (avg_dist * (avg_dist + 1.0e-20).log()).sum()
+            if random.random() < 0.0005:
+                negentropy_individual = (l * l.exp()).sum(dim=-1).mean()
+                print("Negentropy[individual,combined] = ", negentropy_individual.item(), ", ", negentropy.item())
+            loss = negentropy * scale
+            loss.backward()
+        return logprobs_grad + logprobs.grad, None
+
+
 class KnowledgeBaseLookup(nn.Module):
     """
     Create knowledge-base lookup module. (The knowledge-base parameter, which is
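A minimal sketch (not part of this commit) of what PenalizeNegentropyFunction is intended to do, assuming the class from the hunk above is in scope (e.g. pasted next to it in the same file). The helper negentropy_of_average, the alpha value, the shapes, and the stand-in weights w below are illustrative only. Routing logprobs through the function should yield the same input gradient as adding alpha times the row-count-scaled negentropy of the averaged distribution to the loss directly.

import torch

def negentropy_of_average(logprobs: torch.Tensor) -> torch.Tensor:
    # Negentropy of the averaged distribution, scaled by the number of rows,
    # matching the scale used inside PenalizeNegentropyFunction.backward().
    l = logprobs.reshape(-1, logprobs.shape[-1])
    avg = l.exp().mean(dim=0)
    return (avg * (avg + 1.0e-20).log()).sum() * l.shape[0]

torch.manual_seed(0)
alpha = 0.01
x = torch.randn(5, 3, 7)
w = torch.randn(5, 3, 7)  # stand-in for whatever consumes the logprobs downstream

# Path 1: penalty applied through the custom autograd function.
xa = x.clone().requires_grad_(True)
la = PenalizeNegentropyFunction.apply(xa.log_softmax(dim=-1), alpha)
(la * w).sum().backward()

# Path 2: the equivalent explicit loss term.
xb = x.clone().requires_grad_(True)
lb = xb.log_softmax(dim=-1)
((lb * w).sum() + alpha * negentropy_of_average(lb)).backward()

# Both paths should give the same gradient w.r.t. the input.
assert torch.allclose(xa.grad, xb.grad, atol=1.0e-5)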
@@ -949,7 +991,8 @@ class KnowledgeBaseLookup(nn.Module):
     """
     def __init__(self, M: int, N: int, D: int,
                  K: int, embedding_dim: int,
-                 knowledge_base: nn.Parameter):
+                 knowledge_base: nn.Parameter,
+                 negentropy_penalty: float = 0.001):
         super(KnowledgeBaseLookup, self).__init__()
         self.knowledge_base = knowledge_base # shared!
         self.in_proj = ScaledLinear(embedding_dim, M * N,
@@ -961,6 +1004,7 @@ class KnowledgeBaseLookup(nn.Module):
         self.M = M
         self.N = N
         self.K = K
+        self.negentropy_penalty = negentropy_penalty
 
     def forward(self, x: Tensor) -> Tensor:
         """
@@ -972,15 +1016,10 @@ class KnowledgeBaseLookup(nn.Module):
 
         # TODO: later we can try multiplying by a projection of x or something like that.
         """
-        assert torch.all(x - x == 0)
         x = self.in_proj(x) # now (*, M*N)
-        assert torch.all(x - x == 0)
         x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M)
         x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M)
-        assert torch.all(x - x == 0)
-        if random.random() < 0.001:
-            entropy = (x * x.exp()).sum(dim=-1).mean()
-            print("Entropy = ", entropy)
+        x = PenalizeNegentropyFunction.apply(x, self.negentropy_penalty)
         weights, indexes, = sample_combined(x, self.K, input_is_log=True)
         indexes = join_indexes(indexes, self.M)
         x = WeightedMatrixLookupFunction.apply(weights, indexes, self.knowledge_base) # now (*, D)
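A self-contained illustration (not part of this commit; shapes and values are arbitrary) of the two quantities that PenalizeNegentropyFunction.backward occasionally logs. "Individual" is the mean negentropy of each row's own distribution; "combined" is the negentropy of the row-averaged distribution, which is what the penalty acts on (inside the lookup, the rows are all (frame, n) pairs of the (*, N, M) logprobs). Since entropy is concave, the combined value is always <= the individual one, and its minimum of -log(M) is reached when the average class usage is uniform; so the penalty mainly encourages the M classes to be used evenly on average, without forcing every individual distribution to be flat.

import math

import torch

torch.manual_seed(0)
M = 8  # classes per distribution
logprobs = torch.randn(100, M).log_softmax(dim=-1)
p = logprobs.exp()

# Mean negentropy of the individual distributions: mean_i sum_j p_ij * log p_ij.
negentropy_individual = (p * logprobs).sum(dim=-1).mean()

# Negentropy of the averaged distribution q = mean_i p_i.
q = p.mean(dim=0)
negentropy_combined = (q * (q + 1.0e-20).log()).sum()

# Concavity of entropy: combined <= individual; the lower bound is -log(M).
assert negentropy_combined <= negentropy_individual + 1.0e-6
print(negentropy_individual.item(), negentropy_combined.item(), -math.log(M))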