mirror of https://github.com/k2-fsa/icefall.git
Remove debug statements
commit 394d4c95f9 (parent a780984e6b)
@@ -812,7 +812,6 @@ class EntropyPenaltyFunction(torch.autograd.Function):
                 num_heads: int,
                 entropy_limit: float,
                 grad_scale: float) -> Tensor:
-        logging.info("Here3")
         ctx.save_for_backward(attn_weights)
         ctx.num_heads = num_heads
         ctx.entropy_limit = entropy_limit
@@ -826,7 +825,6 @@ class EntropyPenaltyFunction(torch.autograd.Function):
         num_heads = ctx.num_heads
         entropy_limit = ctx.entropy_limit
         grad_scale = ctx.grad_scale
-        logging.info("Here4")
         with torch.enable_grad():
             with torch.cuda.amp.autocast(enabled=False):
                 attn_weights_orig = attn_weights.to(torch.float32).detach()
@@ -909,9 +907,7 @@ class EntropyPenalty(nn.Module):
         you use the returned attention weights, or the graph will be freed
         and nothing will happen in backprop.
         """
-        logging.info("Here1")
         if not attn_weights.requires_grad or random.random() > self.prob:
-            logging.info("Here2")
             return attn_weights
         else:
             seq_len = attn_weights.shape[2]
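For readers without the full file in front of them, below is a minimal, self-contained sketch of the pattern the touched code follows: a custom torch.autograd.Function that returns the attention weights unchanged in forward and injects an entropy-based penalty into the gradient in backward, wrapped by an nn.Module that applies it only with some probability and only when gradients are needed. Everything not visible in the context lines above is an assumption made for illustration, not the icefall implementation; in particular the entropy formula, the direction of the penalty relative to entropy_limit, how num_heads is used, and the exact return values are guesses.

# Hedged sketch reconstructed from the diff context only; not the real icefall code.
import logging
import random

import torch
from torch import Tensor, nn


class EntropyPenaltyFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx,
                attn_weights: Tensor,
                num_heads: int,
                entropy_limit: float,
                grad_scale: float) -> Tensor:
        # Save the attention weights for backward; forward returns them
        # unchanged, so the penalty acts purely through gradients.
        ctx.save_for_backward(attn_weights)
        ctx.num_heads = num_heads          # stored as in the diff; unused in this sketch
        ctx.entropy_limit = entropy_limit
        ctx.grad_scale = grad_scale
        return attn_weights

    @staticmethod
    def backward(ctx, attn_weights_grad: Tensor):
        attn_weights, = ctx.saved_tensors
        entropy_limit = ctx.entropy_limit
        grad_scale = ctx.grad_scale
        with torch.enable_grad():
            # Compute the penalty in float32 with autocast disabled, as in the diff.
            with torch.cuda.amp.autocast(enabled=False):
                attn_weights_orig = attn_weights.to(torch.float32).detach()
                attn_weights_orig.requires_grad_(True)
                # (assumed) per-row entropy of the attention distribution.
                entropy = -(attn_weights_orig * (attn_weights_orig + 1e-20).log()).sum(dim=-1)
                # (assumed) penalize only rows whose entropy exceeds the limit.
                penalty = (entropy - entropy_limit).clamp(min=0.0).mean()
                (penalty_grad,) = torch.autograd.grad(penalty, attn_weights_orig)
        # Add the scaled penalty gradient to whatever gradient already flows back;
        # None for the three non-tensor arguments of forward().
        return (attn_weights_grad + grad_scale * penalty_grad.to(attn_weights_grad.dtype),
                None, None, None)


class EntropyPenalty(nn.Module):
    def __init__(self, num_heads: int, entropy_limit: float,
                 grad_scale: float, prob: float = 0.5):
        super().__init__()
        self.num_heads = num_heads
        self.entropy_limit = entropy_limit
        self.grad_scale = grad_scale
        self.prob = prob

    def forward(self, attn_weights: Tensor) -> Tensor:
        # Matches the control flow visible in the diff: skip the penalty when
        # gradients are not required, and otherwise apply it with probability `prob`.
        if not attn_weights.requires_grad or random.random() > self.prob:
            return attn_weights
        return EntropyPenaltyFunction.apply(attn_weights, self.num_heads,
                                            self.entropy_limit, self.grad_scale)

A call site would look like attn_weights = self.entropy_penalty(attn_weights), and the returned tensor must be the one fed onward, as the docstring in the diff insists: if the return value is discarded, the autograd graph is freed and the penalty gradient is never applied.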