From 0557dbb72076465a4edb35250d5f316a11331c6b Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Fri, 14 Oct 2022 23:23:20 +0800
Subject: [PATCH] use larger delta but only penalize if small grad norm

---
 .../pruned_transducer_stateless7/conformer.py | 22 ++++++++++++++--------
 1 file changed, 14 insertions(+), 8 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index 673470f59..a0f7ade04 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -833,26 +833,32 @@ class EntropyPenaltyFunction(torch.autograd.Function):
 
             seq_len = attn_weights_orig.shape[2]
             attn_weights = attn_weights_orig.reshape(bsz, num_heads, seq_len, seq_len)
+
+            grad_norms = attn_weights_grad.detach().reshape(
+                bsz, num_heads, seq_len * seq_len).norm(dim=(0,2))
+
             entropy = ((attn_weights + 1.0e-20).log() * attn_weights).sum(dim=-1)  # entropy: (bsz, num_heads, seq_len)
             entropy = -entropy.mean(dim=(0,2))  # entropy: (num_heads,)
             assert entropy.shape == (num_heads,)
             excess_entropy = (entropy - entropy_limit).relu()
-            above_cutoff = (excess_entropy != 0)  # tensor of shape (num_heads,)
+            above_cutoff = (entropy > 0)  # tensor of shape (num_heads,)
+            small_grad_norm = (grad_norms < 0.5 * grad_norms.mean())
+            will_penalize = torch.logical_and(above_cutoff, small_grad_norm)
 
             if random.random() < 0.001 or __name__ == "__main__":
-                logging.info(f"entropy = {entropy}, entropy_limit={entropy_limit}, above_cutoff={above_cutoff}")
-            above_cutoff_sum = above_cutoff.to(torch.float32).sum()
-            above_cutoff_sum = above_cutoff_sum.item()
-            if above_cutoff_sum == 0:
+                logging.info(f"entropy = {entropy}, entropy_limit={entropy_limit}, above_cutoff={above_cutoff}, small_grad_norm={small_grad_norm}, will_penalize={will_penalize}")
+            will_penalize_sum = will_penalize.to(torch.float32).sum()
+            will_penalize_sum = will_penalize_sum.item()
+            if will_penalize_sum == 0:
                 # grad would be 0.  I'm guessing that checking this, and
                 # incurring a CUDA sync, may save time relative to doing the
                 # backprop of the entropy, but I'm not sure.
                 return attn_weights_grad, None, None, None
 
             # Treat `excess_entropy` as a loss, to be minimized.
-            excess_entropy.backward(gradient=torch.ones_like(excess_entropy))
+            excess_entropy.backward(gradient=will_penalize.to(torch.float32))
             entropy_grad = attn_weights_orig.grad
-            scale = ((grad_scale * above_cutoff_sum / num_heads) *
+            scale = ((grad_scale * will_penalize_sum / num_heads) *
                      (attn_weights_grad.to(torch.float32).norm() /
                       (entropy_grad.norm() + 1.0e-20)))
             entropy_grad = entropy_grad * scale
@@ -969,7 +975,7 @@ class RelPositionMultiheadAttention(nn.Module):
                                     initial_scale=0.05)
 
         self.entropy_penalty = EntropyPenalty(num_heads,
-                                              entropy_delta=0.8,
+                                              entropy_delta=1.5,
                                               prob=1.0 if __name__ == "__main__" else 0.2,
                                               grad_scale=0.01)
 
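
For reference, a minimal standalone sketch of the gating mechanism this
patch introduces: the entropy penalty is applied only to heads whose mean
attention entropy exceeds the limit AND whose incoming gradient norm is
well below the average over heads. This is illustrative, not the committed
code: the function name entropy_penalty_grad, the 4-D input shapes, and
the final step of adding the scaled entropy gradient back into the
incoming gradient are assumptions; the real logic lives inside
EntropyPenaltyFunction.backward(), where entropy_limit is presumably
derived from the entropy_delta passed to EntropyPenalty.

import torch

def entropy_penalty_grad(attn_weights: torch.Tensor,
                         attn_weights_grad: torch.Tensor,
                         entropy_limit: float,
                         grad_scale: float = 0.01) -> torch.Tensor:
    """Sketch: returns attn_weights_grad plus a small entropy-penalty
    gradient for heads above the entropy limit that are receiving
    unusually small gradients from the main loss."""
    # attn_weights, attn_weights_grad: (bsz, num_heads, seq_len, seq_len)
    bsz, num_heads, seq_len, _ = attn_weights.shape

    # Per-head norm of the incoming gradient, pooled over batch and positions.
    grad_norms = attn_weights_grad.detach().reshape(
        bsz, num_heads, seq_len * seq_len).norm(dim=(0, 2))

    attn_weights = attn_weights.detach().requires_grad_(True)
    # Mean per-head entropy of the attention distributions: (num_heads,)
    entropy = -((attn_weights + 1.0e-20).log() * attn_weights).sum(dim=-1)
    entropy = entropy.mean(dim=(0, 2))
    excess_entropy = (entropy - entropy_limit).relu()

    # The gate this patch adds: only heads whose gradient norm is less than
    # half the mean over heads are eligible for the penalty.
    small_grad_norm = grad_norms < 0.5 * grad_norms.mean()
    will_penalize = torch.logical_and(excess_entropy > 0, small_grad_norm)
    if not will_penalize.any():
        # No head is gated in; pass the gradient through unchanged.
        return attn_weights_grad

    # Backprop excess entropy as a loss, masked to the gated heads only.
    excess_entropy.backward(gradient=will_penalize.to(torch.float32))
    entropy_grad = attn_weights.grad

    # Scale the penalty to stay small relative to the main gradient.
    scale = ((grad_scale * will_penalize.sum() / num_heads) *
             (attn_weights_grad.norm() / (entropy_grad.norm() + 1.0e-20)))
    return attn_weights_grad + scale * entropy_grad

Example (hypothetical shapes): two heads where only the second receives
tiny gradients from the main loss; a random softmax over 10 positions
typically has entropy above 1.5, so only that head is penalized:

    w = torch.softmax(torch.randn(4, 2, 10, 10), dim=-1)
    g = torch.randn(4, 2, 10, 10) * torch.tensor([1.0, 0.01]).view(1, 2, 1, 1)
    new_g = entropy_penalty_grad(w, g, entropy_limit=1.5)

Gating on gradient norm means the penalty mainly touches heads the main
loss is barely training, nudging them toward sharper attention while
leaving heads that already receive useful gradient signal alone.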