diff --git a/egs/librispeech/ASR/zipformer_ctc_attn/attention_decoder.py b/egs/librispeech/ASR/zipformer_ctc_attn/attention_decoder.py
index 6a040accd..8bb96c002 100644
--- a/egs/librispeech/ASR/zipformer_ctc_attn/attention_decoder.py
+++ b/egs/librispeech/ASR/zipformer_ctc_attn/attention_decoder.py
@@ -438,7 +438,7 @@ class DecoderLayer(nn.Module):
         )
 
     def get_bypass_scale(self):
-        if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing():
+        if not self.training or torch.jit.is_scripting() or torch.jit.is_tracing():
             return self.bypass_scale
         if random.random() < 0.1:
             # ensure we get grads if self.bypass_scale becomes out of range
@@ -565,11 +565,6 @@ class MultiHeadedAttention(nn.Module):
             grad_scale=0.025,
         )
 
-        # the following are for diagnosics only, see --print-diagnostics option.
-        # they only copy their inputs.
-        self.copy_pos_query = Identity()
-        self.copy_query = Identity()
-
         self.out_proj = ScaledLinear(
             attention_dim // 2, embed_dim, bias=True, initial_scale=0.05
         )
@@ -597,7 +592,6 @@ class MultiHeadedAttention(nn.Module):
         k = self.linear_k(key)
         v = self.linear_v(value)
 
-        q = self.copy_query(q)  # for diagnostics only, does nothing.
         k = self.whiten_k(k)  # does nothing in the forward pass.
         v = self.whiten_v(v)  # does nothing in the forward pass.
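
For context, here is a minimal sketch of the bypass-scale pattern that the first hunk reorders; the class name, clamp range, and probability below are hypothetical stand-ins, not taken from this file. The new condition puts `not self.training` first, presumably so the common inference path short-circuits before the `torch.jit` checks, and the random unclamped branch can then only run in training mode:

    import random

    import torch
    import torch.nn as nn


    class BypassSketch(nn.Module):
        """Hypothetical stand-in for the DecoderLayer bypass logic."""

        def __init__(self, min_scale: float = 0.1, max_scale: float = 1.0):
            super().__init__()
            self.bypass_scale = nn.Parameter(torch.full((1,), 0.5))
            self.min_scale = min_scale
            self.max_scale = max_scale

        def get_bypass_scale(self) -> torch.Tensor:
            # Eval mode (and scripting/tracing) returns the raw parameter,
            # with no randomness and no clamping.
            if not self.training or torch.jit.is_scripting() or torch.jit.is_tracing():
                return self.bypass_scale
            if random.random() < 0.1:
                # Occasionally skip the clamp so the parameter still receives
                # gradients if it drifts outside [min_scale, max_scale].
                return self.bypass_scale
            return self.bypass_scale.clamp(min=self.min_scale, max=self.max_scale)

The exact range and probability in attention_decoder.py may differ; the sketch only illustrates the evaluation order of the condition.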