diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py
index 0ca0fcaa4..032262e76 100644
--- a/egs/librispeech/ASR/zipformer/zipformer.py
+++ b/egs/librispeech/ASR/zipformer/zipformer.py
@@ -1598,6 +1598,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
 
         assert attn_scores.shape == (num_heads, new_batch_size, time1, time2)
 
+        assert attn_mask is None
         if attn_mask is not None:
             # TODO: assert attn_mask.dtype == torch.bool
@@ -1607,9 +1608,12 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
             # compares the final weights with zero.
             attn_scores = attn_scores.masked_fill(attn_mask, -1000)
 
-        assert key_padding_mask is not None
-        assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape
-        attn_offsets = (~key_padding_mask).float()  # 0 at padding positions
+        # Used to mask out the padding positions
+        attn_offsets = torch.ones(batch_size, seq_len, device=x.device)
+
+        if key_padding_mask is not None:
+            assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape
+            attn_offsets = attn_offsets.masked_fill(key_padding_mask, 0.0)  # 0 at padding positions
 
         # (seq_len, batch, 1)
         attn_offsets = attn_offsets.transpose(0, 1).unsqueeze(-1)
@@ -1619,14 +1623,12 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
             kernel=block_size * 3, stride=block_size, padding=block_size,
         ).squeeze(-1)
 
-        # For the blocks are all padding
+        # Used for the blocks that are all padding
         all_pad_mask = (attn_offsets.sum(dim=0, keepdim=True) == 0)  # (1, new_batch_size)
         all_pad_mask = all_pad_mask.unsqueeze(-1).unsqueeze(-1)  # (1, new_batch_size, 1, 1)
 
         attn_offsets = 1 - attn_offsets  # 1 at padding positions
-        # attn_offsets[attn_offsets != 0] = float("-inf")
         attn_offsets[attn_offsets != 0] = -1000
-        # attn_offsets = attn_offsets.masked_fill((attn_offsets != 0), -1000)
 
         # (1, new_batch_size, 1, time2)
         attn_offsets = attn_offsets.transpose(0, 1).unsqueeze(1).unsqueeze(0)
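
The diff above makes key_padding_mask optional: attn_offsets starts as all ones, padding positions are zeroed, blocks whose offsets sum to zero are flagged as all-padding, and the remaining offsets are converted into an additive -1000 bias on the attention scores. The standalone sketch below illustrates that padding-offset idea under simplified assumptions: the helper name padding_attn_offsets is hypothetical, shapes are plain (batch, seq_len) with no head dimension, and the block pooling step from the diff is omitted. It is not the zipformer code itself.

    import torch

    def padding_attn_offsets(key_padding_mask, batch_size, seq_len, device=None):
        # Start from all ones, i.e. "every position is valid".
        attn_offsets = torch.ones(batch_size, seq_len, device=device)

        if key_padding_mask is not None:
            assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape
            # 0.0 at padding positions, 1.0 at valid positions.
            attn_offsets = attn_offsets.masked_fill(key_padding_mask, 0.0)

        # Rows that are entirely padding: their offsets sum to 0.
        all_pad_mask = attn_offsets.sum(dim=1, keepdim=True) == 0  # (batch_size, 1)

        # Convert to an additive bias: 0 at valid positions, -1000 at padding.
        attn_offsets = 1 - attn_offsets  # 1 at padding positions
        attn_offsets[attn_offsets != 0] = -1000

        return attn_offsets, all_pad_mask

    if __name__ == "__main__":
        batch_size, seq_len = 2, 6
        # Hypothetical padding mask: True marks padded positions.
        key_padding_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
        key_padding_mask[1, 3:] = True

        offsets, all_pad = padding_attn_offsets(key_padding_mask, batch_size, seq_len)
        scores = torch.randn(batch_size, seq_len)
        masked_scores = scores + offsets  # padded keys end up near -1000
        print(masked_scores)
        print(all_pad)

A finite -1000 is used rather than float("-inf"), as in the patch, so that a row that is entirely padding still goes through softmax without producing nan's; exp(-1000) underflows to exactly zero, so the masked positions contribute nothing.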