Mirror of https://github.com/k2-fsa/icefall.git
Change cutoff for small_grad_norm

commit 80d51efd15
parent 822465f73b
@@ -844,7 +844,7 @@ class EntropyPenaltyFunction(torch.autograd.Function):
         assert entropy.shape == (num_heads,)
         excess_entropy = (entropy - entropy_limit).relu()
         above_cutoff = (entropy > 0)  # tensor of shape (num_heads,)
-        small_grad_norm = (grad_norms < 0.5 * grad_norms.mean())
+        small_grad_norm = (grad_norms < grad_norms.mean())
         will_penalize = torch.logical_and(above_cutoff, small_grad_norm)
         if random.random() < 0.005 or __name__ == "__main__":
             logging.info(f"entropy = {entropy}, entropy_limit={entropy_limit}, above_cutoff={above_cutoff}, small_grad_norm={small_grad_norm}, will_penalize={will_penalize}")
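The changed line relaxes the cutoff used to select heads with a relatively small gradient norm. Below is a minimal sketch, not taken from the repository: the grad_norms and entropy values are invented for illustration and the surrounding EntropyPenaltyFunction machinery is omitted. It shows how moving the cutoff from 0.5 * grad_norms.mean() to grad_norms.mean() lets more heads satisfy small_grad_norm and therefore become candidates for the entropy penalty.

import torch

# Hypothetical per-head statistics (4 heads); values chosen only to show the effect.
grad_norms = torch.tensor([0.2, 0.6, 1.0, 2.2])  # mean is 1.0
entropy = torch.tensor([1.5, 2.0, 0.0, 0.5])

above_cutoff = (entropy > 0)  # heads with positive entropy, as in the diff

# Old cutoff: only heads well below the average gradient norm qualify.
old_small_grad_norm = (grad_norms < 0.5 * grad_norms.mean())
# New cutoff: any head below the average gradient norm qualifies.
new_small_grad_norm = (grad_norms < grad_norms.mean())

print(torch.logical_and(above_cutoff, old_small_grad_norm))  # only head 0 penalized
print(torch.logical_and(above_cutoff, new_small_grad_norm))  # heads 0 and 1 penalized

With the old cutoff only head 0 (grad norm 0.2 < 0.5) would be penalized; with the new cutoff head 1 (grad norm 0.6 < 1.0) also qualifies, so the penalty applies to more heads.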