Use a larger entropy delta, but only penalize heads whose grad norm is small

This commit is contained in:
Daniel Povey 2022-10-14 23:23:20 +08:00
parent 394d4c95f9
commit 0557dbb720

View File

@ -833,26 +833,32 @@ class EntropyPenaltyFunction(torch.autograd.Function):
seq_len = attn_weights_orig.shape[2]
attn_weights = attn_weights_orig.reshape(bsz, num_heads,
seq_len, seq_len)
grad_norms = attn_weights_grad.detach().reshape(
bsz, num_heads, seq_len * seq_len).norm(dim=(0,2))
entropy = ((attn_weights + 1.0e-20).log() * attn_weights).sum(dim=-1)
# entropy: (bsz, num_heads, seq_len)
entropy = -entropy.mean(dim=(0,2))
# entropy: (num_heads,)
assert entropy.shape == (num_heads,)
excess_entropy = (entropy - entropy_limit).relu()
above_cutoff = (excess_entropy != 0) # tensor of shape (num_heads,)
above_cutoff = (entropy > 0) # tensor of shape (num_heads,)
small_grad_norm = (grad_norms < 0.5 * grad_norms.mean())
will_penalize = torch.logical_and(above_cutoff, small_grad_norm)
if random.random() < 0.001 or __name__ == "__main__":
logging.info(f"entropy = {entropy}, entropy_limit={entropy_limit}, above_cutoff={above_cutoff}")
above_cutoff_sum = above_cutoff.to(torch.float32).sum()
above_cutoff_sum = above_cutoff_sum.item()
if above_cutoff_sum == 0:
logging.info(f"entropy = {entropy}, entropy_limit={entropy_limit}, above_cutoff={above_cutoff}, small_grad_norm={small_grad_norm}, will_penalize={will_penalize}")
will_penalize_sum = will_penalize.to(torch.float32).sum()
will_penalize_sum = will_penalize_sum.item()
if will_penalize_sum == 0:
# grad would be 0. I'm guessing that checking this, and
# incurring a CUDA sync, may save time relative to doing the
# backprop of the entropy, but I'm not sure.
return attn_weights_grad, None, None, None
# Treat `excess_entropy` as a loss, to be minimized.
excess_entropy.backward(gradient=torch.ones_like(excess_entropy))
excess_entropy.backward(gradient=will_penalize_sum.to(torch.float32))
entropy_grad = attn_weights_orig.grad
scale = ((grad_scale * above_cutoff_sum / num_heads) *
scale = ((grad_scale * will_penalize_sum / num_heads) *
(attn_weights_grad.to(torch.float32).norm() /
(entropy_grad.norm() + 1.0e-20)))
entropy_grad = entropy_grad * scale
@ -969,7 +975,7 @@ class RelPositionMultiheadAttention(nn.Module):
initial_scale=0.05)
self.entropy_penalty = EntropyPenalty(num_heads,
entropy_delta=0.8,
entropy_delta=1.5,
prob=1.0 if __name__ == "__main__" else 0.2,
grad_scale=0.01)