Penalize attention-weight entropies above a limit.
parent 1812f6cb28
commit a780984e6b
@@ -805,6 +805,127 @@ class RelPositionalEncoding(torch.nn.Module):
        return self.dropout(pos_emb)


class EntropyPenaltyFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx,
                attn_weights: Tensor,
                num_heads: int,
                entropy_limit: float,
                grad_scale: float) -> Tensor:
        logging.info("Here3")
        ctx.save_for_backward(attn_weights)
        ctx.num_heads = num_heads
        ctx.entropy_limit = entropy_limit
        ctx.grad_scale = grad_scale
        return attn_weights

    @staticmethod
    def backward(ctx,
                 attn_weights_grad: Tensor):
        attn_weights, = ctx.saved_tensors
        num_heads = ctx.num_heads
        entropy_limit = ctx.entropy_limit
        grad_scale = ctx.grad_scale
        logging.info("Here4")
        with torch.enable_grad():
            with torch.cuda.amp.autocast(enabled=False):
                attn_weights_orig = attn_weights.to(torch.float32).detach()
                attn_weights_orig.requires_grad = True
                bsz = attn_weights_orig.shape[0] // num_heads
                seq_len = attn_weights_orig.shape[2]
                attn_weights = attn_weights_orig.reshape(bsz, num_heads,
                                                         seq_len, seq_len)
                entropy = ((attn_weights + 1.0e-20).log() * attn_weights).sum(dim=-1)
                # entropy: (bsz, num_heads, seq_len)
                entropy = -entropy.mean(dim=(0, 2))
                # entropy: (num_heads,)
                assert entropy.shape == (num_heads,)
                excess_entropy = (entropy - entropy_limit).relu()
                above_cutoff = (excess_entropy != 0)  # tensor of shape (num_heads,)
                if random.random() < 0.001 or __name__ == "__main__":
                    logging.info(f"entropy = {entropy}, entropy_limit={entropy_limit}, above_cutoff={above_cutoff}")
                above_cutoff_sum = above_cutoff.to(torch.float32).sum()
                above_cutoff_sum = above_cutoff_sum.item()
                if above_cutoff_sum == 0:
                    # grad would be 0.  I'm guessing that checking this, and
                    # incurring a CUDA sync, may save time relative to doing the
                    # backprop of the entropy, but I'm not sure.
                    return attn_weights_grad, None, None, None
                # Treat `excess_entropy` as a loss, to be minimized.
                excess_entropy.backward(gradient=torch.ones_like(excess_entropy))
                entropy_grad = attn_weights_orig.grad
                scale = ((grad_scale * above_cutoff_sum / num_heads) *
                         (attn_weights_grad.to(torch.float32).norm() /
                          (entropy_grad.norm() + 1.0e-20)))
                entropy_grad = entropy_grad * scale
                return attn_weights_grad + entropy_grad.to(attn_weights_grad.dtype), None, None, None


class EntropyPenalty(nn.Module):
    def __init__(
            self,
            num_heads: int,
            entropy_delta: float,
            prob: float,
            grad_scale: float):
        """
        Args:
          num_heads: the number of attention heads in the self-attention module that
              this is attached to.
          entropy_delta: the delta below the maximum possible entropy that we aim to
              decrease the entropy to if it is above that point.  So the entropy limit
              will be max(log(seq_len) - entropy_delta, 0.5 * log(seq_len));
              the second term is to make sure the limit never becomes tiny or
              negative in the case of short sequences.
          prob: the probability with which we apply this object.
          grad_scale: determines the scale on the gradient term from this object,
              relative to the rest of the gradient on the attention weights;
              will be divided by `prob`.
        """
        super(EntropyPenalty, self).__init__()
        self.num_heads = num_heads
        self.entropy_delta = entropy_delta
        self.prob = prob
        self.grad_scale = grad_scale

    def forward(self,
                attn_weights: Tensor) -> Tensor:
        """
        In the forward pass, this function just returns the attention weights.
        In the backward pass, it will modify the gradients to ensure that the
        entropy of the attention heads is not too large.  (We have noticed
        that too-large/almost-maximal entropy in the attention distribution
        is associated with heads that are not doing anything useful.)

        Args:
          attn_weights: the attention weights, after the softmax, with shape
              (batch_size * num_heads, seq_len, seq_len), satisfying:
              attn_weights.sum(dim=-1) == 1.
        Returns:
          the attn_weights, without any change.  You should make sure
          you use the returned attention weights, or the graph will be freed
          and nothing will happen in backprop.
        """
        logging.info("Here1")
        if not attn_weights.requires_grad or random.random() > self.prob:
            logging.info("Here2")
            return attn_weights
        else:
            seq_len = attn_weights.shape[2]
            max_entropy = math.log(seq_len)
            entropy_limit = max(max_entropy - self.entropy_delta,
                                0.5 * max_entropy)
            return EntropyPenaltyFunction.apply(attn_weights,
                                                self.num_heads,
                                                entropy_limit,
                                                self.grad_scale / self.prob)


class RelPositionMultiheadAttention(nn.Module):
    r"""Multi-Head Attention layer with relative position encoding
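Editor's note (not part of the commit): the following standalone sketch, which assumes only PyTorch and the standard library, computes the per-head attention entropy the same way EntropyPenaltyFunction.backward() does, and the entropy limit the same way EntropyPenalty.forward() does, on random post-softmax weights. It is only an illustration of the quantities the classes above penalize.

import math
import torch

num_heads, bsz, seq_len = 4, 2, 10
entropy_delta = 0.8

# Random post-softmax attention weights with the shape the module expects:
# (batch_size * num_heads, seq_len, seq_len), rows summing to 1.
attn_weights = torch.randn(bsz * num_heads, seq_len, seq_len).softmax(dim=-1)

# Per-head entropy, as in EntropyPenaltyFunction.backward().
w = attn_weights.reshape(bsz, num_heads, seq_len, seq_len)
entropy = -((w + 1.0e-20).log() * w).sum(dim=-1).mean(dim=(0, 2))  # shape: (num_heads,)

# Entropy limit, as in EntropyPenalty.forward().
max_entropy = math.log(seq_len)
entropy_limit = max(max_entropy - entropy_delta, 0.5 * max_entropy)

print(entropy)                             # each value lies in [0, log(seq_len)]
print(entropy_limit)                       # max(log(10) - 0.8, 0.5 * log(10)) ~= 1.50
print((entropy - entropy_limit).relu())    # the excess entropy that would be penalized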
@@ -851,6 +972,11 @@ class RelPositionMultiheadAttention(nn.Module):
        self.out_proj2 = ScaledLinear(embed_dim // 2, embed_dim, bias=True,
                                      initial_scale=0.05)

        self.entropy_penalty = EntropyPenalty(num_heads,
                                              entropy_delta=0.8,
                                              prob=1.0 if __name__ == "__main__" else 0.2,
                                              grad_scale=0.01)

        self.attn_scores_proj_in = nn.Parameter(torch.eye(num_heads))
        self.attn_scores_proj_out = nn.Parameter(torch.zeros(num_heads, num_heads))

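Editor's note on the constructor arguments above (not part of the commit): because EntropyPenalty.forward() applies the penalty with probability prob and passes grad_scale / prob to the autograd function, the penalty gradient is larger on the calls where it fires, so its average size across batches stays at roughly grad_scale. With the defaults in this hunk:

prob = 0.2          # penalty applied on roughly 20% of training forward passes
grad_scale = 0.01   # target size of the penalty gradient relative to the main gradient

scale_when_applied = grad_scale / prob        # 0.05, the value passed to EntropyPenaltyFunction
expected_scale = prob * scale_when_applied    # 0.01, i.e. back to grad_scale on average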
@@ -1204,6 +1330,8 @@ class RelPositionMultiheadAttention(nn.Module):
        Returns:
           output of the same shape as x, i.e. (seq_len, batch_size, embed_dim)
        """
        attn_weights = self.entropy_penalty(attn_weights)

        num_heads = self.num_heads
        (seq_len, bsz, embed_dim) = x.shape
        head_dim = embed_dim // (num_heads * 2)

@@ -1631,7 +1759,7 @@ def _test_conformer_main():
        torch.randn(batch_size, seq_len, feature_dim),
        torch.full((batch_size,), seq_len, dtype=torch.int64),
    )
    f  # to remove flake8 warnings
    f[0].sum().backward()
    c.eval()
    f = c(
        torch.randn(batch_size, seq_len, feature_dim),
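Editor's note (not part of the commit): a minimal isolated check of the new penalty, assuming the EntropyPenalty class from the first hunk is importable or defined in the same file. With prob=1.0 the forward pass should return the weights unchanged, and backward should still produce a gradient of the expected shape.

import torch

num_heads, bsz, seq_len = 4, 2, 10
penalty = EntropyPenalty(num_heads, entropy_delta=0.8, prob=1.0, grad_scale=0.01)

scores = torch.randn(bsz * num_heads, seq_len, seq_len, requires_grad=True)
attn_weights = scores.softmax(dim=-1)

out = penalty(attn_weights)            # forward pass is the identity
assert torch.equal(out, attn_weights)

out.sum().backward()                   # backward adds the scaled entropy-penalty gradient
assert scores.grad is not None and scores.grad.shape == scores.shape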