Make pos_emb be dropped out independently across batch

This commit is contained in:
Daniel Povey 2022-11-12 19:21:29 +08:00
parent 4988c815c9
commit e67d4ca40d

View File

@ -1052,7 +1052,8 @@ class RelPositionalEncoding(torch.nn.Module):
+ x.size(0), + x.size(0),
: :
] ]
pos_emb = pos_emb.unsqueeze(0) # now: (1, 2*time-1, embed_dim) batch_size = x.size(1)
pos_emb = pos_emb.unsqueeze(0).expand(batch_size, -1, -1) # now: (batch_size, 2*time-1, embed_dim)
return self.dropout(pos_emb) return self.dropout(pos_emb)
@ -1176,8 +1177,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
seq_len2 = 2 * seq_len - 1 seq_len2 = 2 * seq_len - 1
pos_emb = pos_emb.reshape(1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1) pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1)
# pos shape now: (head, 1, pos_dim, seq_len2) # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
# (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
# [where seq_len2 represents relative position.] # [where seq_len2 represents relative position.]