From b90d8aabdea59ff525393136bf97744e0eb78c13 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Tue, 29 Nov 2022 15:38:55 +0800
Subject: [PATCH] Revert the alternate-layers-only thing for nonlin_attention
 and attention_squeeze

---
 .../ASR/pruned_transducer_stateless7/zipformer.py | 14 ++------------
 1 file changed, 2 insertions(+), 12 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 700f64963..0a002a53c 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -451,12 +451,6 @@ class ZipformerEncoderLayer(nn.Module):
     def remove_attention_weights(self):
         self.self_attn_weights = None
 
-    def remove_nonlin_attention(self):
-        self.nonlin_attention_module = None
-
-    def remove_attention_squeeze(self):
-        self.attention_squeeze = None
-
     def get_bypass_scale(self):
         if torch.jit.is_scripting() or not self.training:
             return self.bypass_scale
@@ -527,14 +521,14 @@ class ZipformerEncoderLayer(nn.Module):
             first_attn_weights = first_attn_weights * (1.0 / first_attn_weights.sum(dim=-1, keepdim=True))
             first_attn_weights = first_attn_weights.expand(3, -1, -1, -1)
 
-        if (torch.jit.is_scripting() or use_self_attn) and self.nonlin_attention_module is not None:
+        if torch.jit.is_scripting() or use_self_attn:
             src = src + self.nonlin_attention_module(src, first_attn_weights[0:1])
 
         src = src + self.feed_forward1(src)
 
         # pooling module
-        if (torch.jit.is_scripting() or use_self_attn) and self.attention_squeeze is not None:
+        if torch.jit.is_scripting() or use_self_attn:
            src = src + self.attention_squeeze(src, first_attn_weights[1:2])
 
         if torch.jit.is_scripting() or use_self_attn:
@@ -605,10 +599,6 @@ class ZipformerEncoder(nn.Module):
             cur_begin = cur_end
             if i % attention_share_layers != 0:
                 self.layers[i].remove_attention_weights()
-            if i % attention_share_layers == 0:
-                self.layers[i].remove_nonlin_attention()
-            else:
-                self.layers[i].remove_attention_squeeze()
 
     def forward(
         self,
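
The revert restores the behavior in which every ZipformerEncoderLayer applies both
nonlin_attention_module and attention_squeeze whenever self-attention is in use,
instead of alternating the two modules across layers. Below is a minimal sketch of
that restored control flow; the placeholder nn.Linear modules stand in for the real
nonlin-attention, attention-squeeze, and feed-forward implementations in
zipformer.py (which also consume attention weights, omitted here).

import torch
import torch.nn as nn


class ToyEncoderLayer(nn.Module):
    """Sketch of the restored per-layer control flow (placeholder modules)."""

    def __init__(self, d_model: int):
        super().__init__()
        # Placeholders for the real modules in zipformer.py.
        self.nonlin_attention_module = nn.Linear(d_model, d_model)
        self.attention_squeeze = nn.Linear(d_model, d_model)
        self.feed_forward1 = nn.Linear(d_model, d_model)

    def forward(self, src: torch.Tensor, use_self_attn: bool = True) -> torch.Tensor:
        # After the revert there are no `is not None` guards: both modules run
        # on every layer whenever self-attention is active (or when scripting).
        if torch.jit.is_scripting() or use_self_attn:
            src = src + self.nonlin_attention_module(src)
        src = src + self.feed_forward1(src)
        # pooling module
        if torch.jit.is_scripting() or use_self_attn:
            src = src + self.attention_squeeze(src)
        return src


if __name__ == "__main__":
    layer = ToyEncoderLayer(d_model=16)
    x = torch.randn(10, 2, 16)  # (seq_len, batch, d_model)
    print(layer(x, use_self_attn=True).shape)  # torch.Size([10, 2, 16])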