Mirror of https://github.com/k2-fsa/icefall.git
Make pos_emb be dropped out independently across batch

commit e67d4ca40d
parent 4988c815c9
```diff
@@ -1052,7 +1052,8 @@ class RelPositionalEncoding(torch.nn.Module):
             + x.size(0),
             :
         ]
-        pos_emb = pos_emb.unsqueeze(0)  # now: (1, 2*time-1, embed_dim)
+        batch_size = x.size(1)
+        pos_emb = pos_emb.unsqueeze(0).expand(batch_size, -1, -1)  # now: (batch_size, 2*time-1, embed_dim)
         return self.dropout(pos_emb)
```
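The point of this change: `expand(batch_size, -1, -1)` gives `pos_emb` an explicit batch dimension *before* dropout, so each batch element gets its own dropout mask instead of sharing the single mask implied by the old `(1, 2*time-1, embed_dim)` shape. A minimal sketch of the difference, with made-up sizes (the tensor names and dimensions here are illustrative, not taken from icefall):

```python
import torch

torch.manual_seed(0)

batch_size, seq_len, embed_dim = 4, 5, 8
seq_len2 = 2 * seq_len - 1  # relative offsets -(T-1) .. T-1
dropout = torch.nn.Dropout(p=0.5)  # modules default to training mode

pe = torch.randn(1, seq_len2, embed_dim)  # stand-in for the precomputed table

# Old behaviour: dropout runs on (1, 2*time-1, embed_dim); when the result is
# later broadcast against the batch, every element shares one dropout mask.
old = dropout(pe).expand(batch_size, -1, -1)
assert torch.equal(old[0], old[1])  # identical masks after broadcasting

# New behaviour: expand first, then dropout; the mask matches the expanded
# shape, so each batch element is dropped out independently.
new = dropout(pe.expand(batch_size, -1, -1))
print(torch.equal(new[0], new[1]))  # almost surely False
```

Note that `expand` allocates no new memory; the extra randomness comes only from dropout sampling a mask that matches the expanded shape.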
```diff
@@ -1176,8 +1177,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         seq_len2 = 2 * seq_len - 1
-        pos_emb = pos_emb.reshape(1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1)
-        # pos shape now: (head, 1, pos_dim, seq_len2)
+        pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1)
+        # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)

         # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
         # [where seq_len2 represents relative position.]
```
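Downstream, the attention code can no longer assume a leading dimension of 1, hence `reshape(1, ...)` becomes `reshape(-1, ...)`: the same call now accepts a positional embedding with or without a real batch dimension, and `torch.matmul` broadcasts over that dimension when it is 1. A small shape check under assumed sizes (all dimensions here are hypothetical, chosen only to match the shape comments in the hunk above):

```python
import torch

num_heads, pos_head_dim = 4, 16
batch_size, time1, seq_len = 2, 10, 10
seq_len2 = 2 * seq_len - 1

# Query-derived positional term: (head, batch, time1, pos_dim)
p = torch.randn(num_heads, batch_size, time1, pos_head_dim)

for leading in (1, batch_size):  # pos_emb with or without a batch dimension
    pos_emb = torch.randn(leading, seq_len2, num_heads * pos_head_dim)
    # reshape(-1, ...) preserves whatever leading dimension pos_emb carries
    pe = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1)
    # pe: (head, {1 or batch_size}, pos_dim, seq_len2); matmul broadcasts dim 1
    scores = torch.matmul(p, pe)
    print(pe.shape, "->", scores.shape)  # scores: (4, 2, 10, 19) in both cases
```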