From 822465f73bf8f6a931ce524a840127dd07144328 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Fri, 14 Oct 2022 23:25:29 +0800
Subject: [PATCH] Bug fixes; change debug freq

---
 .../ASR/pruned_transducer_stateless7/conformer.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index a0f7ade04..6183a50d4 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -846,17 +846,16 @@ class EntropyPenaltyFunction(torch.autograd.Function):
         above_cutoff = (entropy > 0)  # tensor of shape (num_heads,)
         small_grad_norm = (grad_norms < 0.5 * grad_norms.mean())
         will_penalize = torch.logical_and(above_cutoff, small_grad_norm)
-        if random.random() < 0.001 or __name__ == "__main__":
+        if random.random() < 0.005 or __name__ == "__main__":
             logging.info(f"entropy = {entropy}, entropy_limit={entropy_limit}, above_cutoff={above_cutoff}, small_grad_norm={small_grad_norm}, will_penalize={will_penalize}")
-        will_penalize_sum = will_penalize.to(torch.float32).sum()
-        will_penalize_sum = will_penalize_sum.item()
+        will_penalize_sum = will_penalize.to(torch.float32).sum().item()
         if will_penalize_sum == 0:
             # grad would be 0.  I'm guessing that checking this, and
             # incurring a CUDA sync, may save time relative to doing the
             # backprop of the entropy, but I'm not sure.
             return attn_weights_grad, None, None, None
         # Treat `excess_entropy` as a loss, to be minimized.
-        excess_entropy.backward(gradient=will_penalize_sum.to(torch.float32))
+        excess_entropy.backward(gradient=will_penalize.to(torch.float32))
         entropy_grad = attn_weights_orig.grad
         scale = ((grad_scale * will_penalize_sum / num_heads) *
                  (attn_weights_grad.to(torch.float32).norm() /
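
Note on the last hunk, as a minimal sketch (names and values below are stand-ins,
not the actual icefall tensors): after `.item()`, `will_penalize_sum` is a plain
Python float, so the old call `will_penalize_sum.to(torch.float32)` would raise an
AttributeError, and `Tensor.backward(gradient=...)` in any case needs a tensor with
the same shape as `excess_entropy`; the fixed line passes the per-head mask
`will_penalize` instead, so only the penalized heads receive an entropy gradient.

    import torch

    # Hypothetical stand-ins for the tensors used in EntropyPenaltyFunction.backward():
    num_heads = 3
    will_penalize = torch.tensor([True, False, True])           # shape (num_heads,)
    attn_entropy = torch.randn(num_heads, requires_grad=True)   # stand-in leaf tensor
    excess_entropy = 2.0 * attn_entropy                         # per-head "loss", shape (num_heads,)

    # .sum().item() yields a Python float, which has no .to() method.
    will_penalize_sum = will_penalize.to(torch.float32).sum().item()
    if will_penalize_sum != 0:
        # Pass the boolean mask (cast to float) as the per-element gradient,
        # matching the shape of excess_entropy, as in the fixed line.
        excess_entropy.backward(gradient=will_penalize.to(torch.float32))
        print(attn_entropy.grad)  # nonzero only for the penalized heads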