Changes to avoid a bug in backward hooks, affecting diagnostics.

Daniel Povey 2022-10-19 11:01:48 +08:00
parent b37564c9c9
commit 6b3f9e5036
2 changed files with 20 additions and 6 deletions


@@ -32,6 +32,7 @@ from scaling import (
     ScaledConv1d,
     ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
     Whiten,
+    Identity,
     _diag,
 )
 from torch import Tensor, nn
@@ -864,8 +865,8 @@ class RelPositionMultiheadAttention(nn.Module):
                                      initial_scale=0.05)
         # the following are for diagnosics only, see --print-diagnostics option
-        self.copy_pos_query = nn.Identity()
-        self.copy_query = nn.Identity()
+        self.copy_pos_query = Identity()
+        self.copy_query = Identity()
         self.in_balancer = ActivationBalancer(3 * attention_dim,
                                               channel_dim=-1, max_abs=5.0)
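The copy_pos_query / copy_query modules above exist only so that the --print-diagnostics machinery can hook them and watch the query tensors. A plain nn.Identity() hands back the very tensor object it receives, so anything hooked through it ends up observing a tensor that is shared with every other consumer in the graph. The sketch below is a hypothetical illustration of that effect, not icefall's actual diagnostics code (attach_probe and stats are invented names), and it assumes the diagnostics attach a tensor hook to the module's output; the Identity helper added to scaling.py in this commit avoids the problem by routing the tensor through a real autograd op.

import torch
import torch.nn as nn


def attach_probe(module: nn.Module, name: str, stats: dict):
    # Hypothetical diagnostic: record the RMS of the gradient arriving at
    # this module's output on every backward pass.
    def forward_hook(mod, inputs, output):
        if output.requires_grad:
            output.register_hook(
                lambda grad: stats.setdefault(name, []).append(
                    grad.pow(2).mean().sqrt().item()))
    module.register_forward_hook(forward_hook)


stats = {}
copy_q = nn.Identity()            # the pre-commit choice
attach_probe(copy_q, "copy_q", stats)

x = torch.randn(4, 8, requires_grad=True)
q = copy_q(x)                     # q is literally x, so the hook sits on x itself
loss = q.sum() + (3.0 * x).sum()  # x also feeds a second branch
loss.backward()

# x's gradient sums both branches (1 + 3 = 4), so the probe reports 4.0
# rather than the 1.0 that actually flows through copy_q.
print(stats["copy_q"])            # [4.0]

With the scaling.Identity defined below, q would be a distinct tensor whose gradient along this path is 1.0, which is presumably what a per-module diagnostic wants to see.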


@@ -382,7 +382,7 @@ class ActivationBalancer(torch.nn.Module):
     def forward(self, x: Tensor) -> Tensor:
         if torch.jit.is_scripting() or not x.requires_grad:
-            return x
+            return _no_op(x)
 
         count = self.cpu_count
         self.cpu_count += 1
@@ -418,7 +418,7 @@ class ActivationBalancer(torch.nn.Module):
                 x, scale_factor, sign_factor, self.channel_dim,
             )
         else:
-            return x
+            return _no_op(x)
@@ -567,7 +567,7 @@ class Whiten(nn.Module):
         and nothing will happen in backprop.
         """
         if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0:
-            return x
+            return _no_op(x)
         else:
             if hasattr(self, 'min_prob') and random.random() < 0.25:
                 # occasionally switch between min_prob and max_prob, based on whether
@@ -583,9 +583,22 @@ class Whiten(nn.Module):
                                                   self.whitening_limit,
                                                   self.grad_scale)
+
+def _no_op(x: Tensor) -> Tensor:
+    if (torch.jit.is_scripting()):
+        return x
+    else:
+        # a no-op function that will have a node in the autograd graph,
+        # to avoid certain bugs relating to backward hooks
+        return x.chunk(1, dim=-1)[0]
+
+
+class Identity(torch.nn.Module):
+    def __init__(self):
+        super(Identity, self).__init__()
+
+    def forward(self, x):
+        return _no_op(x)
+
+
 class MaxEig(torch.nn.Module):
     """
@@ -643,7 +656,7 @@ class MaxEig(torch.nn.Module):
         if (torch.jit.is_scripting() or
             self.max_var_per_eig <= 0 or
             random.random() > self.cur_prob):
-            return x
+            return _no_op(x)
 
         with torch.cuda.amp.autocast(enabled=False):
             eps = 1.0e-20
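As a standalone check of what _no_op changes (this snippet is written for this note, not part of the commit): the output is numerically identical to the input and even shares its storage, but it is a distinct tensor produced by a real autograd op, so it carries its own grad_fn instead of being the same object that was passed in, which is presumably the property the backward-hook diagnostics need.

import torch
import torch.nn as nn
from torch import Tensor


def _no_op(x: Tensor) -> Tensor:
    # Mirrors the _no_op added to scaling.py above.
    if torch.jit.is_scripting():
        return x
    else:
        # a no-op that nevertheless creates a node in the autograd graph
        return x.chunk(1, dim=-1)[0]


x = torch.randn(2, 3, requires_grad=True)

plain = nn.Identity()(x)  # returns the input object itself
routed = _no_op(x)        # returns a view produced by a real autograd op

print(plain is x)              # True
print(routed is x)             # False
print(routed.grad_fn is None)  # False: a Split-style backward node
print(torch.equal(routed, x))  # True: values are unchanged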