mirror of https://github.com/k2-fsa/icefall.git
Make pos_emb be dropped out independently across batch
This commit is contained in:
parent 4988c815c9
commit e67d4ca40d
@@ -1052,7 +1052,8 @@ class RelPositionalEncoding(torch.nn.Module):
             + x.size(0),
             :
         ]
-        pos_emb = pos_emb.unsqueeze(0)  # now: (1, 2*time-1, embed_dim)
+        batch_size = x.size(1)
+        pos_emb = pos_emb.unsqueeze(0).expand(batch_size, -1, -1)  # now: (batch_size, 2*time-1, embed_dim)
         return self.dropout(pos_emb)
 
 
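Why this helps, as a minimal sketch (standard PyTorch semantics; the sizes and variable names below are illustrative, not taken from the commit): with the old shape (1, 2*time-1, embed_dim), dropout draws a single mask that is then effectively shared by every utterance in the batch once pos_emb is broadcast downstream; expanding to (batch_size, 2*time-1, embed_dim) before self.dropout lets dropout sample an independent mask per batch element.

import torch

# Illustrative sizes only (hypothetical, not from zipformer.py):
# T2 = 2*time-1, D = embed_dim.
T2, D, batch_size = 5, 4, 3
pos_emb = torch.randn(T2, D)

dropout = torch.nn.Dropout(p=0.5)  # modules start in training mode, so dropout is active

# Before this commit: shape (1, T2, D) -> one mask, shared by the whole batch
# once the result is broadcast against per-utterance tensors downstream.
shared = dropout(pos_emb.unsqueeze(0))

# After this commit: expand to (batch_size, T2, D) first, so dropout samples
# an independent mask for every batch element.
independent = dropout(pos_emb.unsqueeze(0).expand(batch_size, -1, -1))

print(shared.shape)       # torch.Size([1, 5, 4])
print(independent.shape)  # torch.Size([3, 5, 4])
# Rows of `independent` almost surely differ, because their masks were drawn separately:
print(torch.equal(independent[0], independent[1]))  # False (with overwhelming probability)

Note that expand() only creates a stride-0 view rather than copying the embedding, so the extra memory cost is limited to the dropout output itself.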
@@ -1176,8 +1177,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
 
 
         seq_len2 = 2 * seq_len - 1
-        pos_emb = pos_emb.reshape(1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1)
-        # pos shape now: (head, 1, pos_dim, seq_len2)
+        pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1)
+        # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
 
         # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
         # [where seq_len2 represents relative position.]
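Shape check for the attention side, as a standalone sketch (random tensors stand in for the projected query and positional embedding; the names and sizes are hypothetical): reshape(-1, ...) accepts a pos_emb whose leading dimension is either 1 (old behaviour) or batch_size (new behaviour), and torch.matmul broadcasts over that dimension when it is 1, so the positional scores come out as (head, batch, time1, seq_len2) in both cases.

import torch

# Hypothetical sizes; `q` and `pos_emb` are random stand-ins for the projected
# query and positional embedding, not the real module's tensors.
num_heads, pos_head_dim = 4, 2
batch_size, seq_len = 3, 6
seq_len2 = 2 * seq_len - 1

q = torch.randn(num_heads, batch_size, seq_len, pos_head_dim)  # (head, batch, time1, pos_dim)

for leading in (1, batch_size):  # leading dim of pos_emb before vs. after this commit
    pos_emb = torch.randn(leading, seq_len2, num_heads * pos_head_dim)
    pos = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1)
    # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
    pos_scores = torch.matmul(q, pos)  # matmul broadcasts dim 1 of `pos` when it is 1
    print(leading, pos_scores.shape)   # both: torch.Size([4, 3, 6, 11])

With the expanded (and independently dropped-out) pos_emb, each utterance gets its own positional scores, which is the point of the change; using -1 in the reshape keeps the same code path working for both the old and new pos_emb shapes.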