Bug fixes; change debug freq

Daniel Povey 2022-10-14 23:25:29 +08:00
parent 0557dbb720
commit 822465f73b


@@ -846,17 +846,16 @@ class EntropyPenaltyFunction(torch.autograd.Function):
         above_cutoff = (entropy > 0) # tensor of shape (num_heads,)
         small_grad_norm = (grad_norms < 0.5 * grad_norms.mean())
         will_penalize = torch.logical_and(above_cutoff, small_grad_norm)
-        if random.random() < 0.001 or __name__ == "__main__":
+        if random.random() < 0.005 or __name__ == "__main__":
             logging.info(f"entropy = {entropy}, entropy_limit={entropy_limit}, above_cutoff={above_cutoff}, small_grad_norm={small_grad_norm}, will_penalize={will_penalize}")
-        will_penalize_sum = will_penalize.to(torch.float32).sum()
-        will_penalize_sum = will_penalize_sum.item()
+        will_penalize_sum = will_penalize.to(torch.float32).sum().item()
         if will_penalize_sum == 0:
             # grad would be 0. I'm guessing that checking this, and
             # incurring a CUDA sync, may save time relative to doing the
             # backprop of the entropy, but I'm not sure.
             return attn_weights_grad, None, None, None
         # Treat `excess_entropy` as a loss, to be minimized.
-        excess_entropy.backward(gradient=will_penalize_sum.to(torch.float32))
+        excess_entropy.backward(gradient=will_penalize.to(torch.float32))
         entropy_grad = attn_weights_orig.grad
         scale = ((grad_scale * will_penalize_sum / num_heads) *
                  (attn_weights_grad.to(torch.float32).norm() /
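
The last hunk fixes a real bug: after `.item()`, `will_penalize_sum` is a plain Python float, which has no `.to()` method, and `Tensor.backward(gradient=...)` in any case expects a tensor matching the shape of the non-scalar tensor being backpropagated. Passing the per-head mask instead weights each head's contribution, so heads not selected for penalization receive zero gradient. Below is a minimal, self-contained sketch of that pattern; the tensor names, shapes, and the entropy limit are illustrative assumptions, not the repository's code.

    # Sketch (assumed names/shapes): per-head entropy penalty with a masked backward().
    import torch

    num_heads, num_queries = 4, 10
    # Hypothetical attention weights, one distribution per head; made a leaf so .grad is populated.
    attn_weights = torch.rand(num_heads, num_queries).softmax(dim=-1).requires_grad_(True)

    # Per-head entropy, shape (num_heads,): a non-scalar tensor.
    entropy = -(attn_weights * (attn_weights + 1e-20).log()).sum(dim=-1)

    entropy_limit = 1.5                                        # illustrative value
    excess_entropy = (entropy - entropy_limit).clamp(min=0.0)  # (num_heads,)
    will_penalize = entropy > entropy_limit                    # bool mask, (num_heads,)

    # A float such as will_penalize.sum().item() would fail here: `gradient` must be a
    # tensor with excess_entropy's shape.  The float mask acts as per-head weights, so
    # only penalized heads contribute to attn_weights.grad.
    excess_entropy.backward(gradient=will_penalize.to(torch.float32))

    print(attn_weights.grad.shape)  # same shape as attn_weights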