From fe51eea397f126788e09961445468c55c9060cb5 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Thu, 17 Nov 2022 10:54:29 +0800
Subject: [PATCH] Implement a form of dropout for squeeze_weights,
 dropout-to-constant.

---
 .../pruned_transducer_stateless7/zipformer.py | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index b406da0c7..9590de27f 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -366,6 +366,7 @@ class ZipformerEncoderLayer(nn.Module):
         # to work correctly.
         layer_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0),
         dynamic_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0),
+        squeeze_const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0),
         bypass_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.25), default=0),
         bypass_max: FloatLike = 1.0,
     ) -> None:
@@ -380,7 +381,7 @@ class ZipformerEncoderLayer(nn.Module):
         # ever becoming zero.
         self.bypass_min = copy.deepcopy(bypass_min)
         self.bypass_max = copy.deepcopy(bypass_max)
-
+        self.squeeze_const_attention_rate = copy.deepcopy(squeeze_const_attention_rate)
         self.self_attn_weights = RelPositionMultiheadAttentionWeights(
             embed_dim, pos_dim=pos_dim, num_heads=num_heads,
@@ -478,6 +479,17 @@ class ZipformerEncoderLayer(nn.Module):
             key_padding_mask=src_key_padding_mask,
         )
+
+        squeeze_weights = attn_weights[1:2]
+        if random.random() < float(self.squeeze_const_attention_rate):
+            # this form of dropout makes the attention-weights used for the
+            # squeeze-excite modules constant wherever they are not masked.  The intention
+            # is to encourage these modules to do something similar to an averaging-over-time
+            # operation.
+            squeeze_weights = (squeeze_weights > 0.0).to(squeeze_weights.dtype)
+            # make sure they sum to 1 over the last axis.
+            squeeze_weights = squeeze_weights * (1.0 / squeeze_weights.sum(dim=-1, keepdim=True))
+
         if torch.jit.is_scripting() or use_self_attn:
             src = src + self.nonlin_attention_module(src, attn_weights[0:1])

@@ -487,7 +499,7 @@

         # pooling module
         if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.attention_squeeze1(src, attn_weights[1:2])
+            src = src + self.attention_squeeze1(src, squeeze_weights)

         if torch.jit.is_scripting() or use_self_attn:
             src = src + self.self_attn(
@@ -500,8 +512,7 @@

         # pooling module
         if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.attention_squeeze2(src, attn_weights[2:3])
-
+            src = src + self.attention_squeeze2(src, squeeze_weights)
         src = self.norm_final(self.balancer(src))
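
Note (illustration, not part of the patch): below is a minimal standalone sketch of the dropout-to-constant operation that the patch applies to the squeeze-excite attention weights. The helper name dropout_to_constant, its rate argument, and the assumed weight shape (num_heads, batch, seq_len, seq_len) are inventions for this sketch; in the patch itself the rate comes from the scheduled squeeze_const_attention_rate and masked positions are assumed to carry exactly zero weight.

import random

import torch


def dropout_to_constant(squeeze_weights: torch.Tensor, rate: float) -> torch.Tensor:
    # Hypothetical helper, not part of zipformer.py.  Assumed shape:
    # (num_heads, batch, seq_len, seq_len), with masked positions already 0.0.
    if random.random() < rate:
        # Make the weights constant wherever they are not masked ...
        squeeze_weights = (squeeze_weights > 0.0).to(squeeze_weights.dtype)
        # ... then renormalize so each row sums to 1 over the last axis,
        # turning the squeeze weights into an average over unmasked frames.
        squeeze_weights = squeeze_weights * (
            1.0 / squeeze_weights.sum(dim=-1, keepdim=True)
        )
    return squeeze_weights


# Example: one head, batch of 1, 4 query positions over 5 frames, last frame masked.
w = torch.zeros(1, 1, 4, 5)
w[..., :4] = torch.tensor([0.1, 0.2, 0.3, 0.4])
print(dropout_to_constant(w, rate=1.0))  # each row becomes [0.25, 0.25, 0.25, 0.25, 0.0]

Because squeeze_const_attention_rate follows the schedule ScheduledFloat((0.0, 0.5), (4000.0, 0.05)), this substitution happens roughly half the time at the start of training and only about 5% of the time after 4000 steps, so the regularization is strongest early on.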