diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index aa6e3d1ae..05a6ea933 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -383,6 +383,21 @@ class ZipformerEncoderLayer(nn.Module):
                          (self.batch_count / warmup_period) * (initial_clamp_min - final_clamp_min))
         return self.bypass_scale.clamp(min=clamp_min, max=1.0)
 
+    def get_dynamic_dropout_rate(self):
+        # return dropout rate for the dynamic modules (self_attn, pooling, convolution); this
+        # starts at 0.2 and rapidly decreases to 0.  Its purpose is to keep the training stable
+        # at the beginning, by making the network focus on the feedforward modules.
+        if torch.jit.is_scripting() or not self.training:
+            return 0.0
+        warmup_period = 2000.0
+        initial_dropout_rate = 0.2
+        final_dropout_rate = 0.0
+        if self.batch_count > warmup_period:
+            return final_dropout_rate
+        else:
+            return (initial_dropout_rate -
+                    (initial_dropout_rate - final_dropout_rate) * (self.batch_count / warmup_period))
+
     def forward(
         self,
         src: Tensor,
@@ -412,28 +427,37 @@ class ZipformerEncoderLayer(nn.Module):
         # macaron style feed forward module
         src = src + self.feed_forward1(src)
 
+        # dropout rate for submodules that interact with time.
+        dynamic_dropout = self.get_dynamic_dropout_rate()
+
         # pooling module
-        src = src + self.pooling(src,
-                                 key_padding_mask=src_key_padding_mask)
+        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
+            src = src + self.pooling(src,
+                                     key_padding_mask=src_key_padding_mask)
 
         # multi-headed self-attention module
-        src_att, attn_weights = self.self_attn(
-            src,
-            pos_emb=pos_emb,
-            attn_mask=src_mask,
-            key_padding_mask=src_key_padding_mask,
-        )
-        src = src + src_att
+        use_self_attn = (random.random() > dynamic_dropout)
+        if torch.jit.is_scripting() or use_self_attn:
+            src_att, attn_weights = self.self_attn(
+                src,
+                pos_emb=pos_emb,
+                attn_mask=src_mask,
+                key_padding_mask=src_key_padding_mask,
+            )
+            src = src + src_att
 
         # convolution module
-        src = src + self.conv_module1(src, src_key_padding_mask=src_key_padding_mask)
+        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
+            src = src + self.conv_module1(src, src_key_padding_mask=src_key_padding_mask)
 
         src = src + self.feed_forward2(src)
-        src = src + self.self_attn.forward2(src, attn_weights)
+        if torch.jit.is_scripting() or use_self_attn:
+            src = src + self.self_attn.forward2(src, attn_weights)
 
-        src = src + self.conv_module2(src, src_key_padding_mask=src_key_padding_mask)
+        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
+            src = src + self.conv_module2(src, src_key_padding_mask=src_key_padding_mask)
 
         src = src + self.feed_forward3(src)
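
For reference, a minimal standalone sketch of the two mechanisms added above: a dropout rate that decays from 0.2 to 0.0 over the first 2000 training batches, and coin-flip skipping of a residual branch with that probability. The helper names (dynamic_dropout_rate, apply_with_dropout) and the demo at the bottom are illustrative assumptions, not identifiers from zipformer.py; the sketch also omits the torch.jit.is_scripting() / self.training guards that the real layer uses.

import random


def dynamic_dropout_rate(batch_count: float,
                         warmup_period: float = 2000.0,
                         initial_rate: float = 0.2,
                         final_rate: float = 0.0) -> float:
    # Linear decay from initial_rate to final_rate over the warmup period,
    # after which the rate stays at final_rate (0.0 by default).
    if batch_count > warmup_period:
        return final_rate
    return initial_rate - (initial_rate - final_rate) * (batch_count / warmup_period)


def apply_with_dropout(module, src, rate: float):
    # Residual branch that is skipped with probability `rate`; when skipped,
    # the input passes through unchanged, so output shapes are unaffected.
    if random.random() > rate:
        return src + module(src)
    return src


if __name__ == "__main__":
    for batch_count in (0, 500, 1000, 2000, 5000):
        print(batch_count, dynamic_dropout_rate(batch_count))
    # -> 0.2, 0.15, 0.1, 0.0, 0.0 (up to float rounding)

    # gate a toy residual branch at the initial rate
    out = apply_with_dropout(lambda x: 2.0 * x, 1.0, dynamic_dropout_rate(0))
    print(out)  # 3.0 with probability 0.8, otherwise 1.0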