mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Update alignment_attention_module.py
This commit is contained in:
parent
7e5c7e6f77
commit
17ad6c2959
@ -179,14 +179,11 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
|
|||||||
|
|
||||||
if use_pos_scores:
|
if use_pos_scores:
|
||||||
pos_emb = self.linear_pos(pos_emb)
|
pos_emb = self.linear_pos(pos_emb)
|
||||||
print("pos_emb before proj", pos_emb.shape)
|
|
||||||
seq_len2 = 2 * seq_len - 1
|
seq_len2 = 2 * seq_len - 1
|
||||||
pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
|
pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
|
||||||
2, 0, 3, 1
|
2, 0, 3, 1
|
||||||
)
|
)
|
||||||
# pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
|
# pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
|
||||||
print("p", p.shape)
|
|
||||||
print("pos_emb after proj", pos_emb.shape)
|
|
||||||
|
|
||||||
# (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
|
# (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
|
||||||
# [where seq_len2 represents relative position.]
|
# [where seq_len2 represents relative position.]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user