mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 16:44:20 +00:00
Add negentropy_penalty, on individual dims.
This commit is contained in:
parent
551786b9bd
commit
0bf538a4a3
@ -935,6 +935,48 @@ class WeightedMatrixLookupFunction(torch.autograd.Function):
|
|||||||
return weights_grad.to(dtype), None, knowledge_base.grad.to(dtype)
|
return weights_grad.to(dtype), None, knowledge_base.grad.to(dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class PenalizeNegentropyFunction(torch.autograd.Function):
    """
    Function that does nothing in forward pass, but in backprop, it is as
    if you had added: `- tot_entropy * alpha` to the loss function, where
    tot_entropy is the entropy of the average of the input distributions,
    times the number of input distributions.  (We multiply by this because
    our overall loss function is proportional to the number of frames).

    This will tend to make the entropy want to become as large as possible,
    making (-tot_entropy * alpha) as negative as possible.

    Args:
        logprobs: Tensor of shape (*, num_classes), should be the result of
            calling some_tensor.log_softmax(dim=-1).  All leading dimensions
            are flattened together, i.e. every row of
            logprobs.reshape(-1, num_classes) counts as one input
            distribution.
        alpha: Scale on the penalty; stored on the context and multiplied
            into the penalty term in the backward pass.
    Returns:
        logprobs
    """

    @staticmethod
    def forward(ctx, logprobs: Tensor, alpha: float) -> Tensor:
        # Keep a detached copy so that the backward pass can rebuild the
        # penalty term without holding on to the caller's graph.
        ctx.save_for_backward(logprobs.detach())
        ctx.alpha = alpha
        return logprobs

    @staticmethod
    def backward(ctx, logprobs_grad: Tensor) -> Tuple[Tensor, None]:
        (saved_logprobs,) = ctx.saved_tensors
        with torch.enable_grad():
            # Differentiate w.r.t. a fresh leaf instead of flipping
            # requires_grad on the saved tensor itself: the saved tensor may
            # be handed back again if this node is backpropagated more than
            # once (retain_graph=True), and any `.grad` left on it would then
            # be accumulated into the next result.
            logprobs = saved_logprobs.detach().requires_grad_(True)

            flat = logprobs.reshape(-1, logprobs.shape[-1])
            # Multiply by the number of distributions, see class docstring.
            scale = ctx.alpha * flat.shape[0]
            avg_dist = flat.exp().mean(dim=0)
            # `negentropy` is the negative entropy of the average of the
            # input distributions.  It should be <= 0 as long as the inputs
            # are normalized log-probs.  The 1.0e-20 keeps log() finite for
            # classes whose average probability is exactly zero.
            negentropy = (avg_dist * (avg_dist + 1.0e-20).log()).sum()

            # Occasional diagnostic output.
            if random.random() < 0.0005:
                negentropy_individual = (flat * flat.exp()).sum(dim=-1).mean()
                print("Negentropy[individual,combined] = ", negentropy_individual.item(), ", ", negentropy.item())

            loss = negentropy * scale
            # torch.autograd.grad returns the gradient directly and leaves no
            # `.grad` state behind on any tensor.
            (penalty_grad,) = torch.autograd.grad(loss, logprobs)

        # No gradient for `alpha` (a plain float).
        return logprobs_grad + penalty_grad, None
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeBaseLookup(nn.Module):
|
class KnowledgeBaseLookup(nn.Module):
|
||||||
"""
|
"""
|
||||||
Create knowledge-base lookup module. (The knowledge-base parameter, which is
|
Create knowledge-base lookup module. (The knowledge-base parameter, which is
|
||||||
@ -949,7 +991,8 @@ class KnowledgeBaseLookup(nn.Module):
|
|||||||
"""
|
"""
|
||||||
def __init__(self, M: int, N: int, D: int,
|
def __init__(self, M: int, N: int, D: int,
|
||||||
K: int, embedding_dim: int,
|
K: int, embedding_dim: int,
|
||||||
knowledge_base: nn.Parameter):
|
knowledge_base: nn.Parameter,
|
||||||
|
negentropy_penalty: float = 0.001):
|
||||||
super(KnowledgeBaseLookup, self).__init__()
|
super(KnowledgeBaseLookup, self).__init__()
|
||||||
self.knowledge_base = knowledge_base # shared!
|
self.knowledge_base = knowledge_base # shared!
|
||||||
self.in_proj = ScaledLinear(embedding_dim, M * N,
|
self.in_proj = ScaledLinear(embedding_dim, M * N,
|
||||||
@ -961,6 +1004,7 @@ class KnowledgeBaseLookup(nn.Module):
|
|||||||
self.M = M
|
self.M = M
|
||||||
self.N = N
|
self.N = N
|
||||||
self.K = K
|
self.K = K
|
||||||
|
self.negentropy_penalty = negentropy_penalty
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
"""
|
"""
|
||||||
@ -972,15 +1016,10 @@ class KnowledgeBaseLookup(nn.Module):
|
|||||||
|
|
||||||
# TODO: later we can try multiplying by a projection of x or something like that.
|
# TODO: later we can try multiplying by a projection of x or something like that.
|
||||||
"""
|
"""
|
||||||
assert torch.all(x - x == 0)
|
|
||||||
x = self.in_proj(x) # now (*, M*N)
|
x = self.in_proj(x) # now (*, M*N)
|
||||||
assert torch.all(x - x == 0)
|
|
||||||
x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M)
|
x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M)
|
||||||
x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M)
|
x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M)
|
||||||
assert torch.all(x - x == 0)
|
x = PenalizeNegentropyFunction.apply(x, self.negentropy_penalty)
|
||||||
if random.random() < 0.001:
|
|
||||||
entropy = (x * x.exp()).sum(dim=-1).mean()
|
|
||||||
print("Entropy = ", entropy)
|
|
||||||
weights, indexes, = sample_combined(x, self.K, input_is_log=True)
|
weights, indexes, = sample_combined(x, self.K, input_is_log=True)
|
||||||
indexes = join_indexes(indexes, self.M)
|
indexes = join_indexes(indexes, self.M)
|
||||||
x = WeightedMatrixLookupFunction.apply(weights, indexes, self.knowledge_base) # now (*, D)
|
x = WeightedMatrixLookupFunction.apply(weights, indexes, self.knowledge_base) # now (*, D)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user