From f7aff4f507af67dbc9dc4c38035639ba290d1e1c Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Fri, 11 Nov 2022 21:36:36 +0800
Subject: [PATCH] Revert "Make sub-module dropped out independently."

This reverts commit 3ff3f440ee6d2a367cc3cc45e40f8eb69d122861.
---
 .../pruned_transducer_stateless7/zipformer.py | 30 +++++++++++--------
 1 file changed, 18 insertions(+), 12 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index c512ad6e2..4b6cab2d8 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -487,30 +487,35 @@ class ZipformerEncoderLayer(nn.Module):
         # dropout rate for submodules that interact with time.
         dynamic_dropout = self.get_dynamic_dropout_rate()
 
-        # attn_weights: (num_heads, batch_size, seq_len, seq_len)
-        attn_weights = self.self_attn_weights(
-            src,
-            pos_emb=pos_emb,
-            attn_mask=src_mask,
-            key_padding_mask=src_key_padding_mask,
-        )
+        # multi-headed self-attention module
+        # TODO: make the various attention-using models be dropped
+        # out independently.
+        use_self_attn = (random.random() > dynamic_dropout)
+        if torch.jit.is_scripting() or use_self_attn:
+            # attn_weights: (num_heads, batch_size, seq_len, seq_len)
+            attn_weights = self.self_attn_weights(
+                src,
+                pos_emb=pos_emb,
+                attn_mask=src_mask,
+                key_padding_mask=src_key_padding_mask,
+            )
 
-        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
+        if torch.jit.is_scripting() or use_self_attn:
             src = src + self.self_attn1(
                 src, attn_weights)
 
         # convolution module
-        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
+        if torch.jit.is_scripting() or use_self_attn:
             src = src + self.nonlin_attention_module(src, attn_weights[0:1])
 
         src = src + self.feed_forward2(src)
 
         # pooling module
-        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
+        if torch.jit.is_scripting() or use_self_attn:
             src = src + self.attention_squeeze1(src, attn_weights[1:2])
 
 
-        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
+        if torch.jit.is_scripting() or use_self_attn:
             src = src + self.self_attn2(
                 src, attn_weights)
 
@@ -520,9 +525,10 @@ class ZipformerEncoderLayer(nn.Module):
         src = src + self.feed_forward3(src)
 
         # pooling module
-        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
+        if torch.jit.is_scripting() or use_self_attn:
             src = src + self.attention_squeeze2(src, attn_weights[2:3])
 
+
         src = self.norm_final(self.balancer(src))
 
         delta = src - src_orig
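
Note (not part of the commit message): the change being reverted gated each attention-consuming sub-module on its own call to random.random(), which also required computing attn_weights unconditionally; this patch restores a single shared use_self_attn decision per forward pass that gates the weight computation and all of its consumers together. Below is a minimal toy sketch of that control flow, assuming a hypothetical ToyLayer rather than the actual Zipformer code; only the gating pattern mirrors the diff above.

import random

import torch
import torch.nn as nn


class ToyLayer(nn.Module):
    """Hypothetical layer; only the gating logic mirrors the patch above."""

    def __init__(self, d: int, dropout_rate: float = 0.2):
        super().__init__()
        self.attn = nn.MultiheadAttention(d, num_heads=2, batch_first=True)
        self.ff = nn.Linear(d, d)
        self.dropout_rate = dropout_rate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # One shared draw per forward pass: the attention output and every
        # sub-module that consumes it are kept or skipped together.
        # (The reverted change instead called random.random() in each `if`,
        # so the sub-modules dropped out independently and the attention
        # had to be computed unconditionally.)
        use_self_attn = random.random() > self.dropout_rate
        if torch.jit.is_scripting() or use_self_attn:
            attn_out, _ = self.attn(x, x, x)
        if torch.jit.is_scripting() or use_self_attn:
            x = x + attn_out
        if torch.jit.is_scripting() or use_self_attn:
            x = x + self.ff(attn_out)
        return x


if __name__ == "__main__":
    x = torch.randn(4, 10, 16)
    print(ToyLayer(16)(x).shape)  # torch.Size([4, 10, 16])

With the shared flag, the attention computation itself can be skipped whenever the draw fails; with independent gating, it must always run because any one consumer might survive its own draw.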