From 3ff3f440ee6d2a367cc3cc45e40f8eb69d122861 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Wed, 9 Nov 2022 14:15:56 +0800
Subject: [PATCH] Make sub-module dropped out independently.

---
 .../pruned_transducer_stateless7/zipformer.py | 30 ++++++++-----------
 1 file changed, 12 insertions(+), 18 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index a194c814f..1db062bd4 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -485,35 +485,30 @@ class ZipformerEncoderLayer(nn.Module):
         # dropout rate for submodules that interact with time.
         dynamic_dropout = self.get_dynamic_dropout_rate()
 
-        # multi-headed self-attention module
-        # TODO: make the various attention-using models be dropped
-        # out independently.
-        use_self_attn = (random.random() > dynamic_dropout)
-        if torch.jit.is_scripting() or use_self_attn:
-            # attn_weights: (num_heads, batch_size, seq_len, seq_len)
-            attn_weights = self.self_attn_weights(
-                src,
-                pos_emb=pos_emb,
-                attn_mask=src_mask,
-                key_padding_mask=src_key_padding_mask,
-            )
+        # attn_weights: (num_heads, batch_size, seq_len, seq_len)
+        attn_weights = self.self_attn_weights(
+            src,
+            pos_emb=pos_emb,
+            attn_mask=src_mask,
+            key_padding_mask=src_key_padding_mask,
+        )
 
-        if torch.jit.is_scripting() or use_self_attn:
+        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
             src = src + self.self_attn1(
                 src, attn_weights)
 
         # convolution module
-        if torch.jit.is_scripting() or use_self_attn:
+        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
             src = src + self.nonlin_attention_module(src,
                                                      attn_weights[0:1])
 
         src = src + self.feed_forward2(src)
 
         # pooling module
-        if torch.jit.is_scripting() or use_self_attn:
+        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
             src = src + self.attention_squeeze1(src, attn_weights[1:2])
 
-        if torch.jit.is_scripting() or use_self_attn:
+        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
             src = src + self.self_attn2(
                 src, attn_weights)
 
@@ -523,10 +518,9 @@ class ZipformerEncoderLayer(nn.Module):
         src = src + self.feed_forward3(src)
 
         # pooling module
-        if torch.jit.is_scripting() or use_self_attn:
+        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
             src = src + self.attention_squeeze2(src, attn_weights[2:3])
 
-
         src = self.norm_final(self.balancer(src))
 
         delta = src - src_orig
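
Illustrative sketch (not part of the patch; the helper name and the `branches` argument below are hypothetical, not the ZipformerEncoderLayer API). It shows the pattern the commit switches to: each attention-consuming residual branch samples its own keep/drop decision per forward pass, instead of reusing the single use_self_attn draw for all of them.

import random

import torch


def independently_dropped(branches, src, dynamic_dropout: float):
    # `branches`: callables mapping a tensor to a tensor (hypothetical
    # stand-ins for branches such as self_attn1 or nonlin_attention_module).
    # During training each branch is skipped with probability
    # `dynamic_dropout`, independently of the others; under torch.jit
    # scripting every branch always runs.
    for branch in branches:
        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
            src = src + branch(src)
    return src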