Penalize attention-weight entropies above a limit.

Daniel Povey 2022-10-14 23:01:30 +08:00
parent 1812f6cb28
commit a780984e6b
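As background on the quantity being penalized: each row of the post-softmax attention weights is a probability distribution over seq_len positions, so its entropy is at most log(seq_len), attained when the weights are uniform; near-maximal entropy means a head attends almost uniformly rather than selecting anything. A minimal standalone sketch (illustrative only, not part of this commit):

    import torch

    seq_len = 100
    # Uniform attention row: entropy equals the maximum, log(seq_len) ~= 4.605.
    uniform = torch.full((seq_len,), 1.0 / seq_len)
    print(-(uniform * (uniform + 1.0e-20).log()).sum().item())

    # A sharply peaked attention row has much lower entropy.
    peaked = torch.softmax(10.0 * torch.randn(seq_len), dim=0)
    print(-(peaked * (peaked + 1.0e-20).log()).sum().item())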


@@ -805,6 +805,127 @@ class RelPositionalEncoding(torch.nn.Module):
        return self.dropout(pos_emb)


class EntropyPenaltyFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx,
                attn_weights: Tensor,
                num_heads: int,
                entropy_limit: float,
                grad_scale: float) -> Tensor:
        logging.info("Here3")
        ctx.save_for_backward(attn_weights)
        ctx.num_heads = num_heads
        ctx.entropy_limit = entropy_limit
        ctx.grad_scale = grad_scale
        return attn_weights
    @staticmethod
    def backward(ctx,
                 attn_weights_grad: Tensor):
        attn_weights, = ctx.saved_tensors
        num_heads = ctx.num_heads
        entropy_limit = ctx.entropy_limit
        grad_scale = ctx.grad_scale
        logging.info("Here4")
        with torch.enable_grad():
            with torch.cuda.amp.autocast(enabled=False):
                attn_weights_orig = attn_weights.to(torch.float32).detach()
                attn_weights_orig.requires_grad = True
                bsz = attn_weights_orig.shape[0] // num_heads
                seq_len = attn_weights_orig.shape[2]
                attn_weights = attn_weights_orig.reshape(bsz, num_heads,
                                                         seq_len, seq_len)
                entropy = ((attn_weights + 1.0e-20).log() * attn_weights).sum(dim=-1)
                # entropy: (bsz, num_heads, seq_len)
                entropy = -entropy.mean(dim=(0, 2))
                # entropy: (num_heads,)
                assert entropy.shape == (num_heads,)
                excess_entropy = (entropy - entropy_limit).relu()
                above_cutoff = (excess_entropy != 0)  # tensor of shape (num_heads,)
                if random.random() < 0.001 or __name__ == "__main__":
                    logging.info(f"entropy = {entropy}, entropy_limit={entropy_limit}, "
                                 f"above_cutoff={above_cutoff}")
                above_cutoff_sum = above_cutoff.to(torch.float32).sum()
                above_cutoff_sum = above_cutoff_sum.item()
                if above_cutoff_sum == 0:
                    # grad would be 0.  I'm guessing that checking this, and
                    # incurring a CUDA sync, may save time relative to doing
                    # the backprop of the entropy, but I'm not sure.
                    return attn_weights_grad, None, None, None
                # Treat `excess_entropy` as a loss, to be minimized.
                excess_entropy.backward(gradient=torch.ones_like(excess_entropy))
                entropy_grad = attn_weights_orig.grad
                scale = ((grad_scale * above_cutoff_sum / num_heads) *
                         (attn_weights_grad.to(torch.float32).norm() /
                          (entropy_grad.norm() + 1.0e-20)))
                entropy_grad = entropy_grad * scale
                return (attn_weights_grad + entropy_grad.to(attn_weights_grad.dtype),
                        None, None, None)
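A quick way to see what this Function does (a standalone sketch, assuming the class above and its module imports are in scope; the shapes mirror the (batch_size * num_heads, seq_len, seq_len) layout that forward() expects):

    import math
    import torch

    bsz, num_heads, seq_len = 2, 4, 10
    scores = torch.randn(bsz * num_heads, seq_len, seq_len, requires_grad=True)
    attn_weights = scores.softmax(dim=-1)

    # The forward pass is the identity on the attention weights...
    out = EntropyPenaltyFunction.apply(attn_weights, num_heads,
                                       0.5 * math.log(seq_len),  # entropy_limit
                                       0.01)                     # grad_scale
    assert torch.allclose(out, attn_weights)

    # ...while the backward pass mixes in a term that pushes the entropy of
    # any head above the limit back down, rescaled to match the norm of the
    # incoming gradient.
    out.sum().backward()  # stand-in for a real training loss
    print(scores.grad.norm())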
class EntropyPenalty(nn.Module):
    def __init__(
            self,
            num_heads: int,
            entropy_delta: float,
            prob: float,
            grad_scale: float):
        """
        Args:
          num_heads: the number of attention heads in the self-attention
             module that this is attached to.
          entropy_delta: the delta below the maximum possible entropy that we
             aim to decrease the entropy to, if it is above the limit.  So the
             entropy limit is max(log(seq_len) - entropy_delta,
             0.5 * log(seq_len)); the second term makes sure the limit never
             becomes tiny or negative for short sequences.
          prob: the probability with which we apply this object.
          grad_scale: determines the scale on the gradient term from this
             object, relative to the rest of the gradient on the attention
             weights; will be divided by `prob`.
        """
        super(EntropyPenalty, self).__init__()
        self.num_heads = num_heads
        self.entropy_delta = entropy_delta
        self.prob = prob
        self.grad_scale = grad_scale
    def forward(self,
                attn_weights: Tensor) -> Tensor:
        """
        In the forward pass, this function just returns the attention weights.
        In the backward pass, it will modify the gradients to ensure that the
        entropy of the attention heads is not too large.  (We have noticed
        that too-large/almost-maximal entropy in the attention distribution
        is associated with heads that are not doing anything useful.)

        Args:
          attn_weights: the attention weights, after the softmax, with shape
             (batch_size * num_heads, seq_len, seq_len), satisfying:
             attn_weights.sum(dim=-1) == 1.
        Returns:
           the attn_weights, without any change.  You should make sure you
           use the returned attention weights, or the graph will be freed
           and nothing will happen in backprop.
        """
        logging.info("Here1")
        if not attn_weights.requires_grad or random.random() > self.prob:
            logging.info("Here2")
            return attn_weights
        else:
            seq_len = attn_weights.shape[2]
            max_entropy = math.log(seq_len)
            entropy_limit = max(max_entropy - self.entropy_delta,
                                0.5 * max_entropy)
            return EntropyPenaltyFunction.apply(attn_weights,
                                                self.num_heads,
                                                entropy_limit,
                                                self.grad_scale / self.prob)
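To make the limit concrete (a worked example for illustration): with entropy_delta=0.8 and seq_len=100, max_entropy = log(100) ≈ 4.605, so entropy_limit = max(4.605 - 0.8, 0.5 * 4.605) ≈ 3.805 nats; with seq_len=4, max_entropy = log(4) ≈ 1.386 and the 0.5 * max_entropy term takes over, giving a limit of ≈ 0.693 rather than the near-zero 0.586.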

class RelPositionMultiheadAttention(nn.Module):
    r"""Multi-Head Attention layer with relative position encoding
@@ -851,6 +972,11 @@ class RelPositionMultiheadAttention(nn.Module):
        self.out_proj2 = ScaledLinear(embed_dim // 2, embed_dim, bias=True,
                                      initial_scale=0.05)

        self.entropy_penalty = EntropyPenalty(num_heads,
                                              entropy_delta=0.8,
                                              prob=1.0 if __name__ == "__main__" else 0.2,
                                              grad_scale=0.01)

        self.attn_scores_proj_in = nn.Parameter(torch.eye(num_heads))
        self.attn_scores_proj_out = nn.Parameter(torch.zeros(num_heads, num_heads))
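A note on the settings above (illustrative reasoning, not stated in the commit): since forward() divides grad_scale by prob, the penalty fires on roughly 20% of training forward passes with an effective scale of 0.01 / 0.2 = 0.05, so its expected gradient contribution per step matches the nominal grad_scale of 0.01; prob=1.0 under `__main__` simply makes the penalty fire deterministically when the file is run as a test.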
@@ -1204,6 +1330,8 @@ class RelPositionMultiheadAttention(nn.Module):
        Returns:
           output of the same shape as x, i.e. (seq_len, batch_size, embed_dim)
        """
        attn_weights = self.entropy_penalty(attn_weights)

        num_heads = self.num_heads
        (seq_len, bsz, embed_dim) = x.shape
        head_dim = embed_dim // (num_heads * 2)
@@ -1631,7 +1759,7 @@ def _test_conformer_main():
         torch.randn(batch_size, seq_len, feature_dim),
         torch.full((batch_size,), seq_len, dtype=torch.int64),
     )
-    f  # to remove flake8 warnings
+    f[0].sum().backward()
     c.eval()
     f = c(
         torch.randn(batch_size, seq_len, feature_dim),