diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index d20be6145..1514ac678 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -480,13 +480,14 @@ class ZipformerEncoderLayer(nn.Module): key_padding_mask=src_key_padding_mask, ) - first_attn_weights = attn_weights[0:3] + first_attn_weights = attn_weights[0:1] if random.random() < float(self.const_attention_rate): # Make attention weights constant. The intention is to # encourage these modules to do something similar to an # averaging-over-time operation. first_attn_weights = (first_attn_weights > 0.0).to(first_attn_weights.dtype) first_attn_weights = first_attn_weights * (1.0 / first_attn_weights.sum(dim=-1, keepdim=True)) + first_attn_weights = first_attn_weights.expand(3, -1, -1, -1) if torch.jit.is_scripting() or use_self_attn: src = src + self.nonlin_attention_module(src, @@ -1300,7 +1301,7 @@ class AttentionSqueeze(nn.Module): bottleneck_dim, channel_dim=-1, min_positive=0.05, max_positive=0.95, min_abs=0.05, - max_abs=2.0, + max_abs=10.0, max_factor=0.02, min_prob=0.1, )