mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Update alignment_attention_module.py
This commit is contained in:
parent
7e5c7e6f77
commit
17ad6c2959
@ -179,14 +179,11 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
|
||||
|
||||
if use_pos_scores:
|
||||
pos_emb = self.linear_pos(pos_emb)
|
||||
print("pos_emb before proj", pos_emb.shape)
|
||||
seq_len2 = 2 * seq_len - 1
|
||||
pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
|
||||
2, 0, 3, 1
|
||||
)
|
||||
# pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
|
||||
print("p", p.shape)
|
||||
print("pos_emb after proj", pos_emb.shape)
|
||||
|
||||
# (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
|
||||
# [where seq_len2 represents relative position.]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user