Remove debug statements

This commit is contained in:
Daniel Povey 2022-10-14 23:09:05 +08:00
parent a780984e6b
commit 394d4c95f9

View File

@ -812,7 +812,6 @@ class EntropyPenaltyFunction(torch.autograd.Function):
num_heads: int, num_heads: int,
entropy_limit: float, entropy_limit: float,
grad_scale: float) -> Tensor: grad_scale: float) -> Tensor:
logging.info("Here3")
ctx.save_for_backward(attn_weights) ctx.save_for_backward(attn_weights)
ctx.num_heads = num_heads ctx.num_heads = num_heads
ctx.entropy_limit = entropy_limit ctx.entropy_limit = entropy_limit
@ -826,7 +825,6 @@ class EntropyPenaltyFunction(torch.autograd.Function):
num_heads = ctx.num_heads num_heads = ctx.num_heads
entropy_limit = ctx.entropy_limit entropy_limit = ctx.entropy_limit
grad_scale = ctx.grad_scale grad_scale = ctx.grad_scale
logging.info("Here4")
with torch.enable_grad(): with torch.enable_grad():
with torch.cuda.amp.autocast(enabled=False): with torch.cuda.amp.autocast(enabled=False):
attn_weights_orig = attn_weights.to(torch.float32).detach() attn_weights_orig = attn_weights.to(torch.float32).detach()
@ -909,9 +907,7 @@ class EntropyPenalty(nn.Module):
you use the returned attention weights, or the graph will be freed you use the returned attention weights, or the graph will be freed
and nothing will happen in backprop. and nothing will happen in backprop.
""" """
logging.info("Here1")
if not attn_weights.requires_grad or random.random() > self.prob: if not attn_weights.requires_grad or random.random() > self.prob:
logging.info("Here2")
return attn_weights return attn_weights
else: else:
seq_len = attn_weights.shape[2] seq_len = attn_weights.shape[2]