Fix bug RE attn_weights

Daniel Povey 2022-11-23 17:04:17 +08:00
parent 36e49a8d61
commit 9138695dfe


@ -480,11 +480,13 @@ class ZipformerEncoderLayer(nn.Module):
             key_padding_mask=src_key_padding_mask,
         )
-        first_attn_weights = attn_weights[0:1]
+        first_attn_weights = attn_weights[0:3]
         if random.random() < float(self.const_attention_rate):
             # Make attention weights constant. The intention is to
             # encourage these modules to do something similar to an
             # averaging-over-time operation.
+            # only need the mask, can just use the 1st one and expand later
+            first_attn_weights = first_attn_weights[0:1]
             first_attn_weights = (first_attn_weights > 0.0).to(first_attn_weights.dtype)
             first_attn_weights = first_attn_weights * (1.0 / first_attn_weights.sum(dim=-1, keepdim=True))
+            first_attn_weights = first_attn_weights.expand(3, -1, -1, -1)
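
For readers outside the codebase, a minimal standalone sketch of what the fixed branch computes is given below. The shapes, and the assumption that downstream modules consume exactly 3 copies of the weights, are illustrative guesses inferred from the slice attn_weights[0:3] and the 4-D expand(3, -1, -1, -1) in the diff, not details stated in the commit:

    import torch

    # Illustrative shapes only: assume attn_weights is stacked as
    # (num_heads, batch, seq_len, seq_len), matching the 4-D expand() above.
    num_heads, batch, seq_len = 8, 2, 5
    attn_weights = torch.softmax(
        torch.randn(num_heads, batch, seq_len, seq_len), dim=-1
    )

    # Keep 3 heads' worth of weights (the bug was keeping only 1, attn_weights[0:1]).
    first_attn_weights = attn_weights[0:3]

    # "Constant attention" branch: only the zero/non-zero pattern (the mask)
    # matters, so a single copy suffices and is re-expanded to 3 at the end.
    first_attn_weights = first_attn_weights[0:1]
    first_attn_weights = (first_attn_weights > 0.0).to(first_attn_weights.dtype)
    first_attn_weights = first_attn_weights * (
        1.0 / first_attn_weights.sum(dim=-1, keepdim=True)
    )
    first_attn_weights = first_attn_weights.expand(3, -1, -1, -1)

    print(first_attn_weights.shape)     # torch.Size([3, 2, 5, 5])
    print(first_attn_weights[0, 0, 0])  # uniform 1/seq_len here, since nothing is masked

The mask trick relies on masked-out positions (e.g. those excluded by key_padding_mask) having an attention weight of exactly 0 after the softmax, as happens when they are set to -inf beforehand; thresholding at > 0.0 then recovers the mask, and the renormalization turns each row into a uniform average over the surviving positions.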