Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00.
Changes to avoid a bug in backward hooks that affects diagnostics.
parent b37564c9c9
commit 6b3f9e5036
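The hunks below replace direct pass-throughs (bare "return x" and nn.Identity()) with a small _no_op() helper and a local Identity module that return x.chunk(1, dim=-1)[0] instead. That expression is numerically a no-op, but it produces a new tensor with its own node in the autograd graph, which gives the backward hooks used by the --print-diagnostics machinery a well-defined place to attach. The snippet below is an illustrative sketch, not part of the commit; it only checks the property the change relies on.

    import torch

    x = torch.randn(2, 4, requires_grad=True)

    passthrough = x                   # what the old code returned: the same object, no new graph node
    noop = x.chunk(1, dim=-1)[0]      # what _no_op() returns: same values, but a fresh autograd node

    assert passthrough is x and passthrough.grad_fn is None
    assert noop is not x and noop.grad_fn is not None
    assert torch.equal(noop, x)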
@@ -32,6 +32,7 @@ from scaling import (
     ScaledConv1d,
     ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
     Whiten,
+    Identity,
     _diag,
 )
 from torch import Tensor, nn
@@ -864,8 +865,8 @@ class RelPositionMultiheadAttention(nn.Module):
             initial_scale=0.05)

         # the following are for diagnostics only, see --print-diagnostics option
-        self.copy_pos_query = nn.Identity()
-        self.copy_query = nn.Identity()
+        self.copy_pos_query = Identity()
+        self.copy_query = Identity()

         self.in_balancer = ActivationBalancer(3 * attention_dim,
                                               channel_dim=-1, max_abs=5.0)
@@ -382,7 +382,7 @@ class ActivationBalancer(torch.nn.Module):

     def forward(self, x: Tensor) -> Tensor:
         if torch.jit.is_scripting() or not x.requires_grad:
-            return x
+            return _no_op(x)

         count = self.cpu_count
         self.cpu_count += 1
@@ -418,7 +418,7 @@ class ActivationBalancer(torch.nn.Module):
                 x, scale_factor, sign_factor, self.channel_dim,
             )
         else:
-            return x
+            return _no_op(x)



@@ -567,7 +567,7 @@ class Whiten(nn.Module):
         and nothing will happen in backprop.
         """
         if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0:
-            return x
+            return _no_op(x)
         else:
             if hasattr(self, 'min_prob') and random.random() < 0.25:
                 # occasionally switch between min_prob and max_prob, based on whether
@@ -583,9 +583,22 @@
                 self.whitening_limit,
                 self.grad_scale)


+def _no_op(x: Tensor) -> Tensor:
+    if (torch.jit.is_scripting()):
+        return x
+    else:
+        # a no-op function that will have a node in the autograd graph,
+        # to avoid certain bugs relating to backward hooks
+        return x.chunk(1, dim=-1)[0]
+
+
+class Identity(torch.nn.Module):
+    def __init__(self):
+        super(Identity, self).__init__()
+
+    def forward(self, x):
+        return _no_op(x)
+
+
 class MaxEig(torch.nn.Module):
     """
@@ -643,7 +656,7 @@ class MaxEig(torch.nn.Module):
         if (torch.jit.is_scripting() or
             self.max_var_per_eig <= 0 or
             random.random() > self.cur_prob):
-            return x
+            return _no_op(x)

         with torch.cuda.amp.autocast(enabled=False):
             eps = 1.0e-20
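For context on why the extra autograd node matters: the --print-diagnostics machinery gathers its statistics through forward/backward hooks registered on the model's submodules, and a backward hook can only observe gradients through autograd nodes created in the module's forward(), which is what _no_op() now guarantees each of these code paths provides. The sketch below is a rough, hypothetical stand-in for that pattern, not the icefall diagnostics code, and register_full_backward_hook is just one possible hook API; it shows the kind of per-module gradient summary such hooks produce.

    import torch
    import torch.nn as nn

    grad_stats = {}  # hypothetical collector, in the spirit of --print-diagnostics

    def make_hook(name):
        def hook(module, grad_input, grad_output):
            g = grad_output[0]
            if g is not None:
                grad_stats[name] = (g.abs().mean().item(), g.abs().max().item())
        return hook

    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
    for name, module in model.named_modules():
        if name:  # skip the top-level container itself
            module.register_full_backward_hook(make_hook(name))

    model(torch.randn(16, 8)).sum().backward()
    for name, (mean_abs, max_abs) in grad_stats.items():
        print(f"{name}: mean|grad|={mean_abs:.3g}  max|grad|={max_abs:.3g}")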