From 9cf5d92f39485dfd30c3f6ccf9a45c4bcc397b6e Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Mon, 28 Nov 2022 16:24:24 +0800
Subject: [PATCH] Have nonlin_attention and attention_squeeze operate only on every other layer.

---
 .../ASR/pruned_transducer_stateless7/zipformer.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 6c7b73f2b..51ab2946c 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -450,6 +450,12 @@ class ZipformerEncoderLayer(nn.Module):
     def remove_attention_weights(self):
         self.self_attn_weights = None
 
+    def remove_nonlin_attention(self):
+        self.nonlin_attention_module = None
+
+    def remove_attention_squeeze(self):
+        self.attention_squeeze = None
+
     def get_bypass_scale(self):
         if torch.jit.is_scripting() or not self.training:
             return self.bypass_scale
@@ -520,14 +526,14 @@ class ZipformerEncoderLayer(nn.Module):
 
             first_attn_weights = first_attn_weights * (1.0 / first_attn_weights.sum(dim=-1, keepdim=True))
             first_attn_weights = first_attn_weights.expand(3, -1, -1, -1)
 
-        if torch.jit.is_scripting() or use_self_attn:
+        if (torch.jit.is_scripting() or use_self_attn) and self.nonlin_attention_module is not None:
             src = src + self.nonlin_attention_module(src, first_attn_weights[0:1])
 
         src = src + self.feed_forward1(src)
 
         # pooling module
-        if torch.jit.is_scripting() or use_self_attn:
+        if (torch.jit.is_scripting() or use_self_attn) and self.attention_squeeze is not None:
             src = src + self.attention_squeeze(src, first_attn_weights[1:2])
 
         if torch.jit.is_scripting() or use_self_attn:
@@ -598,6 +604,10 @@ class ZipformerEncoder(nn.Module):
             cur_begin = cur_end
             if i % attention_share_layers != 0:
                 self.layers[i].remove_attention_weights()
+            if i % attention_share_layers == 0:
+                self.layers[i].remove_nonlin_attention()
+            else:
+                self.layers[i].remove_attention_squeeze()
 
     def forward(
         self,
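
Note (not part of the patch): a minimal Python sketch of the resulting alternation,
assuming attention_share_layers=2 and using a hypothetical Layer stand-in rather than
the real ZipformerEncoderLayer. Under the patch's logic, even-indexed layers keep
attention_squeeze (dropping nonlin_attention) and odd-indexed layers keep
nonlin_attention (dropping attention_squeeze), so each layer pays for only one of the
two modules while both still appear in every pair of layers.

    class Layer:
        def __init__(self):
            # stand-ins for the real submodules
            self.self_attn_weights = object()
            self.nonlin_attention_module = object()
            self.attention_squeeze = object()

    attention_share_layers = 2
    layers = [Layer() for _ in range(4)]
    for i, layer in enumerate(layers):
        if i % attention_share_layers != 0:
            layer.self_attn_weights = None        # weights shared from layer i-1
        if i % attention_share_layers == 0:
            layer.nonlin_attention_module = None  # even layers: attention_squeeze only
        else:
            layer.attention_squeeze = None        # odd layers: nonlin_attention only

    for i, layer in enumerate(layers):
        print(i, layer.nonlin_attention_module is not None,
              layer.attention_squeeze is not None)
    # 0 False True
    # 1 True False
    # 2 False True
    # 3 True False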