Remove debug statements
parent a780984e6b
commit 394d4c95f9
@@ -812,7 +812,6 @@ class EntropyPenaltyFunction(torch.autograd.Function):
                 num_heads: int,
                 entropy_limit: float,
                 grad_scale: float) -> Tensor:
-        logging.info("Here3")
         ctx.save_for_backward(attn_weights)
         ctx.num_heads = num_heads
         ctx.entropy_limit = entropy_limit
@@ -826,7 +825,6 @@ class EntropyPenaltyFunction(torch.autograd.Function):
         num_heads = ctx.num_heads
         entropy_limit = ctx.entropy_limit
         grad_scale = ctx.grad_scale
-        logging.info("Here4")
         with torch.enable_grad():
             with torch.cuda.amp.autocast(enabled=False):
                 attn_weights_orig = attn_weights.to(torch.float32).detach()
@@ -909,9 +907,7 @@ class EntropyPenalty(nn.Module):
         you use the returned attention weights, or the graph will be freed
         and nothing will happen in backprop.
         """
-        logging.info("Here1")
         if not attn_weights.requires_grad or random.random() > self.prob:
-            logging.info("Here2")
             return attn_weights
         else:
             seq_len = attn_weights.shape[2]
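For context on the code this patch touches: the surviving context lines show an autograd.Function that saves the attention weights in forward() and recomputes a penalty under torch.enable_grad() with autocast disabled in backward(), plus a wrapper nn.Module whose docstring warns that the returned attention weights must actually be used downstream. The sketch below is a minimal, self-contained reconstruction of that pattern for illustration only; the entropy formula, the class name, the grad_scale value, and the call site are assumptions, not code from this repository.

```python
import torch
from torch import Tensor


class EntropyPenaltySketch(torch.autograd.Function):
    """Illustrative stand-in: identity in forward(), extra gradient in backward()."""

    @staticmethod
    def forward(ctx, attn_weights: Tensor, grad_scale: float) -> Tensor:
        ctx.save_for_backward(attn_weights)
        ctx.grad_scale = grad_scale
        return attn_weights

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        (attn_weights,) = ctx.saved_tensors
        with torch.enable_grad():
            with torch.cuda.amp.autocast(enabled=False):
                # Recompute the penalty in float32 on a detached copy, mirroring
                # the context lines of the patch; the formula here is a guess.
                w = attn_weights.to(torch.float32).detach().requires_grad_(True)
                entropy = -(w * w.clamp(min=1e-20).log()).sum(dim=-1).mean()
                (penalty_grad,) = torch.autograd.grad(entropy, w)
        extra = ctx.grad_scale * penalty_grad.to(grad_output.dtype)
        # Gradients for (attn_weights, grad_scale); the float gets None.
        return grad_output + extra, None


# Usage, reflecting the docstring kept in the last hunk: the *returned*
# attention weights must be fed to the rest of the model, otherwise the
# penalty's graph is freed and backprop is unaffected.
scores = torch.randn(2, 4, 10, 10, requires_grad=True)   # (batch, heads, tgt, src)
values = torch.randn(2, 4, 10, 8)
attn_weights = torch.softmax(scores, dim=-1)

attn_weights = EntropyPenaltySketch.apply(attn_weights, 0.1)  # keep the result
output = torch.matmul(attn_weights, values)                   # and use it downstream
output.sum().backward()                                       # penalty now affects scores.grad
```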