From 6b3f9e50362b11ca9473469c7e58cbd2d9c941a1 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Wed, 19 Oct 2022 11:01:48 +0800
Subject: [PATCH] Changes to avoid bug in backward hooks, affecting diagnostics.

---
 .../pruned_transducer_stateless7/conformer.py |  5 +++--
 .../pruned_transducer_stateless7/scaling.py   | 21 +++++++++++++++----
 2 files changed, 20 insertions(+), 6 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index e1a91bae9..f15991b20 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -32,6 +32,7 @@ from scaling import (
     ScaledConv1d,
     ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
     Whiten,
+    Identity,
     _diag,
 )
 from torch import Tensor, nn
@@ -864,8 +865,8 @@ class RelPositionMultiheadAttention(nn.Module):
                                        initial_scale=0.05)
 
         # the following are for diagnosics only, see --print-diagnostics option
-        self.copy_pos_query = nn.Identity()
-        self.copy_query = nn.Identity()
+        self.copy_pos_query = Identity()
+        self.copy_query = Identity()
 
         self.in_balancer = ActivationBalancer(3 * attention_dim, channel_dim=-1,
                                               max_abs=5.0)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 5866ee517..d65b5659a 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -382,7 +382,7 @@ class ActivationBalancer(torch.nn.Module):
     def forward(self, x: Tensor) -> Tensor:
         if torch.jit.is_scripting() or not x.requires_grad:
-            return x
+            return _no_op(x)
 
         count = self.cpu_count
         self.cpu_count += 1
@@ -418,7 +418,7 @@
                 x, scale_factor, sign_factor, self.channel_dim,
             )
         else:
-            return x
+            return _no_op(x)
@@ -567,7 +567,7 @@ class Whiten(nn.Module):
            and nothing will happen in backprop.
         """
         if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0:
-            return x
+            return _no_op(x)
         else:
             if hasattr(self, 'min_prob') and random.random() < 0.25:
                 # occasionally switch between min_prob and max_prob, based on whether
@@ -583,9 +583,22 @@
                                               self.whitening_limit,
                                               self.grad_scale)
 
 
+def _no_op(x: Tensor) -> Tensor:
+    if (torch.jit.is_scripting()):
+        return x
+    else:
+        # a no-op function that will have a node in the autograd graph,
+        # to avoid certain bugs relating to backward hooks
+        return x.chunk(1, dim=-1)[0]
+
+
+class Identity(torch.nn.Module):
+    def __init__(self):
+        super(Identity, self).__init__()
+
+    def forward(self, x):
+        return _no_op(x)
 
 
 class MaxEig(torch.nn.Module):
     """
@@ -643,7 +656,7 @@ class MaxEig(torch.nn.Module):
         if (torch.jit.is_scripting() or
             self.max_var_per_eig <= 0 or
             random.random() > self.cur_prob):
-            return x
+            return _no_op(x)
 
         with torch.cuda.amp.autocast(enabled=False):
             eps = 1.0e-20
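
Background sketch (not part of the patch): the modules touched here, such as copy_query and copy_pos_query, exist only so that the --print-diagnostics machinery can attach hooks to them, and the patch's comment explains the fix: _no_op() returns x.chunk(1, dim=-1)[0], which is numerically the identity but leaves a real node in the autograd graph, avoiding certain bugs relating to backward hooks. The standalone Python sketch below only illustrates that mechanism under stated assumptions; the backward_hook function, the captured list, and the demo shapes are illustrative and are not code from this repository.

# Sketch assuming PyTorch >= 1.8 (for register_full_backward_hook).
# _no_op()/Identity mirror what the patch adds to scaling.py; the rest is a
# hypothetical diagnostics-style demo.
import torch
import torch.nn as nn


def _no_op(x: torch.Tensor) -> torch.Tensor:
    # Numerically the identity, but chunk() records a grad_fn, so the call
    # leaves a node in the autograd graph (the point of the patch).
    return x.chunk(1, dim=-1)[0]


class Identity(nn.Module):
    # Pass-through module whose output is a distinct tensor with its own
    # grad_fn, rather than the very same object that was passed in.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return _no_op(x)


if __name__ == "__main__":
    captured = []

    def backward_hook(module, grad_input, grad_output):
        # Diagnostics-style hook: record the gradient w.r.t. the module output.
        captured.append(grad_output[0].detach())

    m = Identity()
    m.register_full_backward_hook(backward_hook)

    x = torch.randn(2, 4, requires_grad=True)
    (m(x) ** 2).sum().backward()

    # The hook fired and saw a well-defined gradient with the same shape as x.
    print(len(captured), captured[0].shape)

Note on the design choice, as far as the code itself shows: chunk(1, dim=-1) splits the tensor into a single view, so no data is copied and no values change; it is simply a cheap way to obtain an output tensor that is distinct from the input and carries its own autograd node.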