Remove debug statements

Daniel Povey 2022-10-14 23:09:05 +08:00
parent a780984e6b
commit 394d4c95f9

@@ -812,7 +812,6 @@ class EntropyPenaltyFunction(torch.autograd.Function):
                 num_heads: int,
                 entropy_limit: float,
                 grad_scale: float) -> Tensor:
-        logging.info("Here3")
         ctx.save_for_backward(attn_weights)
         ctx.num_heads = num_heads
         ctx.entropy_limit = entropy_limit
@@ -826,7 +825,6 @@ class EntropyPenaltyFunction(torch.autograd.Function):
         num_heads = ctx.num_heads
         entropy_limit = ctx.entropy_limit
         grad_scale = ctx.grad_scale
-        logging.info("Here4")
         with torch.enable_grad():
             with torch.cuda.amp.autocast(enabled=False):
                 attn_weights_orig = attn_weights.to(torch.float32).detach()
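For context, the two hunks above touch a custom autograd function that saves the attention weights in forward() and recomputes a penalty gradient under torch.enable_grad() in backward(). Below is a minimal, self-contained sketch of that pattern only; the class name, argument list, entropy formula, and sign convention are assumptions for illustration, not the actual icefall code.

import torch
from torch import Tensor


class EntropyPenaltySketch(torch.autograd.Function):
    # Sketch of the save_for_backward / recompute-in-backward pattern:
    # forward() passes attn_weights through unchanged; backward() adds the
    # gradient of an entropy-based penalty to the incoming gradient.

    @staticmethod
    def forward(ctx, attn_weights: Tensor, entropy_limit: float,
                grad_scale: float) -> Tensor:
        ctx.save_for_backward(attn_weights)
        ctx.entropy_limit = entropy_limit
        ctx.grad_scale = grad_scale
        return attn_weights

    @staticmethod
    def backward(ctx, attn_weights_grad: Tensor):
        (attn_weights,) = ctx.saved_tensors
        with torch.enable_grad():
            with torch.cuda.amp.autocast(enabled=False):
                # Recompute the penalty in float32 so it can be differentiated here.
                w = attn_weights.to(torch.float32).detach().requires_grad_(True)
                entropy = -(w * (w + 1.0e-20).log()).sum(dim=-1).mean()
                # Assumed convention: penalize entropy above the limit.
                penalty = ctx.grad_scale * (entropy - ctx.entropy_limit).relu()
                penalty.backward()
        return attn_weights_grad + w.grad.to(attn_weights_grad.dtype), None, None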
@@ -909,9 +907,7 @@ class EntropyPenalty(nn.Module):
         you use the returned attention weights, or the graph will be freed
         and nothing will happen in backprop.
         """
-        logging.info("Here1")
         if not attn_weights.requires_grad or random.random() > self.prob:
-            logging.info("Here2")
             return attn_weights
         else:
             seq_len = attn_weights.shape[2]
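The docstring fragment in this hunk warns that the caller must use the returned attention weights: if the returned tensor is discarded, the penalty never enters the autograd graph and its backward() is never invoked. A hypothetical usage, reusing the sketch class above (shapes and constants are placeholders):

attn_weights = torch.softmax(torch.randn(2, 4, 10, 10), dim=-1)
attn_weights.requires_grad_(True)

# Keep the *returned* tensor; dropping it would skip the penalty in backprop.
attn_weights = EntropyPenaltySketch.apply(attn_weights, 2.0, 1.0e-04)

loss = (attn_weights ** 2).sum()
loss.backward()  # the sketch's backward() runs here and adds the penalty gradient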