From 9138695dfecb3e27c19cdb2f15135491bd62b190 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Wed, 23 Nov 2022 17:04:17 +0800
Subject: [PATCH] Fix bug RE attn_weights

---
 egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 1514ac678..c8b2c9154 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -480,11 +480,13 @@ class ZipformerEncoderLayer(nn.Module):
             key_padding_mask=src_key_padding_mask,
         )
 
-        first_attn_weights = attn_weights[0:1]
+        first_attn_weights = attn_weights[0:3]
         if random.random() < float(self.const_attention_rate):
             # Make attention weights constant. The intention is to
             # encourage these modules to do something similar to an
             # averaging-over-time operation.
+            # only need the mask, can just use the 1st one and expand later
+            first_attn_weights = first_attn_weights[0:1]
             first_attn_weights = (first_attn_weights > 0.0).to(first_attn_weights.dtype)
             first_attn_weights = first_attn_weights * (1.0 / first_attn_weights.sum(dim=-1, keepdim=True))
             first_attn_weights = first_attn_weights.expand(3, -1, -1, -1)
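
Note (not part of the patch): below is a minimal standalone sketch of what the fixed code computes, assuming attn_weights has shape (num_heads, batch_size, seq_len, seq_len) with exact zeros at masked positions (which softmax produces when masked scores are set to -inf). The const_attention_rate value and the toy padding mask are made up for illustration; in zipformer.py the rate comes from float(self.const_attention_rate).

import random

import torch

# Build toy attention weights: 8 heads, batch of 2, sequences of length 5,
# with the last two frames of batch item 0 treated as padding.
scores = torch.randn(8, 2, 5, 5)
key_padding_mask = torch.zeros(2, 5, dtype=torch.bool)
key_padding_mask[0, 3:] = True
scores = scores.masked_fill(key_padding_mask[None, :, None, :], float("-inf"))
attn_weights = torch.softmax(scores, dim=-1)  # masked positions become exactly 0.0

const_attention_rate = 0.25  # hypothetical value, for illustration only

first_attn_weights = attn_weights[0:3]  # the fix: keep 3 heads, not 1
if random.random() < const_attention_rate:
    # Only the zero/nonzero pattern (the mask) matters here, and it is the
    # same for every head, so keep just one head and expand at the end.
    first_attn_weights = first_attn_weights[0:1]
    # Binarize: 1.0 where attention was allowed, 0.0 where it was masked out.
    first_attn_weights = (first_attn_weights > 0.0).to(first_attn_weights.dtype)
    # Renormalize each row so every query attends uniformly over the unmasked
    # positions, i.e. an averaging-over-time operation.
    first_attn_weights = first_attn_weights * (
        1.0 / first_attn_weights.sum(dim=-1, keepdim=True)
    )
    # Broadcast the single constant head back to 3 heads (a view, no copy).
    first_attn_weights = first_attn_weights.expand(3, -1, -1, -1)

assert first_attn_weights.shape == (3, 2, 5, 5)  # both branches now yield 3 heads

Collapsing to one head before binarizing, then using expand (which returns a broadcast view rather than copying memory), makes the constant-attention branch slightly cheaper than binarizing and renormalizing all three heads separately.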