Update alignment_attention_module.py

zr_jin 2023-07-23 20:24:58 +08:00
parent 7e5c7e6f77
commit 17ad6c2959


@@ -179,14 +179,11 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         if use_pos_scores:
             pos_emb = self.linear_pos(pos_emb)
-            print("pos_emb before proj", pos_emb.shape)
             seq_len2 = 2 * seq_len - 1
             pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
                 2, 0, 3, 1
             )
             # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
-            print("p", p.shape)
-            print("pos_emb after proj", pos_emb.shape)
             # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
             # [where seq_len2 represents relative position.]
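
For context on the hunk above: the surviving comments describe a batched matmul between per-head positional queries and the projected relative positional embeddings. The following is a minimal standalone sketch of that shape transformation, not the module's actual code; the toy dimensions are made up, and self.linear_pos is replaced by a random tensor of the shape it would produce.

import torch

# Hypothetical toy dimensions; in the real module these come from the layer config.
batch_size, seq_len, num_heads, pos_head_dim = 2, 5, 4, 8
seq_len2 = 2 * seq_len - 1  # number of distinct relative positions

# Stand-in for self.linear_pos(pos_emb): (1, seq_len2, num_heads * pos_head_dim)
pos_emb = torch.randn(1, seq_len2, num_heads * pos_head_dim)

# Reshape/permute as in the diff: result is (head, 1, pos_dim, seq_len2)
pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1)

# Per-head positional queries p: (head, batch, time1, pos_dim)
p = torch.randn(num_heads, batch_size, seq_len, pos_head_dim)

# (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2)
#   -> (head, batch, time1, seq_len2), broadcasting over the batch dim
pos_scores = torch.matmul(p, pos_emb)
print(pos_scores.shape)  # torch.Size([4, 2, 5, 9])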