Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00
use larger delta but only penalize if small grad norm
commit 0557dbb720
parent 394d4c95f9
@@ -833,26 +833,32 @@ class EntropyPenaltyFunction(torch.autograd.Function):
         seq_len = attn_weights_orig.shape[2]
         attn_weights = attn_weights_orig.reshape(bsz, num_heads,
                                                  seq_len, seq_len)
+
+        grad_norms = attn_weights_grad.detach().reshape(
+            bsz, num_heads, seq_len * seq_len).norm(dim=(0,2))
+
         entropy = ((attn_weights + 1.0e-20).log() * attn_weights).sum(dim=-1)
         # entropy: (bsz, num_heads, seq_len)
         entropy = -entropy.mean(dim=(0,2))
         # entropy: (num_heads,)
         assert entropy.shape == (num_heads,)
         excess_entropy = (entropy - entropy_limit).relu()
-        above_cutoff = (excess_entropy != 0) # tensor of shape (num_heads,)
+        above_cutoff = (entropy > 0) # tensor of shape (num_heads,)
+        small_grad_norm = (grad_norms < 0.5 * grad_norms.mean())
+        will_penalize = torch.logical_and(above_cutoff, small_grad_norm)
         if random.random() < 0.001 or __name__ == "__main__":
-            logging.info(f"entropy = {entropy}, entropy_limit={entropy_limit}, above_cutoff={above_cutoff}")
-        above_cutoff_sum = above_cutoff.to(torch.float32).sum()
-        above_cutoff_sum = above_cutoff_sum.item()
-        if above_cutoff_sum == 0:
+            logging.info(f"entropy = {entropy}, entropy_limit={entropy_limit}, above_cutoff={above_cutoff}, small_grad_norm={small_grad_norm}, will_penalize={will_penalize}")
+        will_penalize_sum = will_penalize.to(torch.float32).sum()
+        will_penalize_sum = will_penalize_sum.item()
+        if will_penalize_sum == 0:
             # grad would be 0. I'm guessing that checking this, and
             # incurring a CUDA sync, may save time relative to doing the
             # backprop of the entropy, but I'm not sure.
             return attn_weights_grad, None, None, None
         # Treat `excess_entropy` as a loss, to be minimized.
-        excess_entropy.backward(gradient=torch.ones_like(excess_entropy))
+        excess_entropy.backward(gradient=will_penalize.to(torch.float32))
         entropy_grad = attn_weights_orig.grad
-        scale = ((grad_scale * above_cutoff_sum / num_heads) *
+        scale = ((grad_scale * will_penalize_sum / num_heads) *
                  (attn_weights_grad.to(torch.float32).norm() /
                   (entropy_grad.norm() + 1.0e-20)))
         entropy_grad = entropy_grad * scale
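What this first hunk changes, in words: the entropy penalty on attention heads is now applied only to heads whose incoming gradient (the gradient reaching the attention weights from the main loss) has an unusually small norm, and the delta that sets the entropy limit is raised. Below is a minimal, standalone PyTorch sketch of that gating, not the icefall code itself; the function name is hypothetical, and it gates on positive excess entropy, a slight simplification of the bookkeeping in the diff above.

import torch

def entropy_penalty_grad(attn_weights, attn_grad, entropy_limit, grad_scale=0.01):
    # Hypothetical helper, for illustration only.
    # attn_weights, attn_grad: (batch, num_heads, seq_len, seq_len)
    bsz, num_heads, seq_len, _ = attn_weights.shape
    attn_weights = attn_weights.detach().requires_grad_(True)

    # Per-head norm of the gradient arriving from the main loss.
    grad_norms = attn_grad.detach().reshape(bsz, num_heads, -1).norm(dim=(0, 2))
    small_grad_norm = grad_norms < 0.5 * grad_norms.mean()

    # Mean attention entropy per head, and the amount above the limit.
    entropy = -((attn_weights + 1.0e-20).log() * attn_weights).sum(dim=-1)
    entropy = entropy.mean(dim=(0, 2))  # (num_heads,)
    excess_entropy = (entropy - entropy_limit).relu()

    will_penalize = torch.logical_and(small_grad_norm, excess_entropy > 0)
    if not will_penalize.any():
        return torch.zeros_like(attn_grad)

    # Backprop the excess entropy, but only for the penalized heads.
    excess_entropy.backward(gradient=will_penalize.to(torch.float32))
    entropy_grad = attn_weights.grad

    # Keep the penalty small relative to the main gradient's norm.
    scale = (grad_scale * will_penalize.to(torch.float32).sum() / num_heads) * (
        attn_grad.norm() / (entropy_grad.norm() + 1.0e-20))
    return entropy_grad * scale

The final rescaling mirrors the design in the diff: the penalty gradient is normalized against the norm of the real gradient, so grad_scale bounds its relative size and the count of penalized heads keeps it proportional to how many heads are actually being pushed.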
@@ -969,7 +975,7 @@ class RelPositionMultiheadAttention(nn.Module):
                                  initial_scale=0.05)

         self.entropy_penalty = EntropyPenalty(num_heads,
-                                              entropy_delta=0.8,
+                                              entropy_delta=1.5,
                                               prob=1.0 if __name__ == "__main__" else 0.2,
                                               grad_scale=0.01)

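The second hunk only raises entropy_delta from 0.8 to 1.5 where RelPositionMultiheadAttention builds its EntropyPenalty module, the "use larger delta" part of the commit message. The full EntropyPenaltyFunction is not shown in this diff; as a generic illustration of the pattern its backward above belongs to, here is a hypothetical torch.autograd.Function whose forward is the identity and whose backward adds a small extra term to the incoming gradient (the class name and the placeholder penalty are mine, not icefall's).

import torch

class IdentityWithPenalty(torch.autograd.Function):
    # Hypothetical sketch of the identity-forward / penalized-backward pattern.

    @staticmethod
    def forward(ctx, x, penalty_scale):
        ctx.save_for_backward(x)
        ctx.penalty_scale = penalty_scale
        return x

    @staticmethod
    def backward(ctx, x_grad):
        (x,) = ctx.saved_tensors
        # Placeholder penalty term; the real module instead derives a
        # per-head entropy gradient here, gated and rescaled as above.
        penalty_grad = ctx.penalty_scale * x.detach()
        # One gradient per forward() argument; penalty_scale gets None.
        return x_grad + penalty_grad, None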