Revert making the dropout of pos_emb independent across the batch.

Daniel Povey 2022-11-14 15:34:39 +08:00
parent 804917837e
commit ce4b50d094


@@ -984,8 +984,7 @@ class RelPositionalEncoding(torch.nn.Module):
             + x.size(0),
             :
         ]
-        batch_size = x.size(1)
-        pos_emb = pos_emb.unsqueeze(0).expand(batch_size, -1, -1)  # now: (batch_size, 2*time-1, embed_dim)
+        pos_emb = pos_emb.unsqueeze(0)
         return self.dropout(pos_emb)
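
For context, a minimal sketch of why the reverted expand changed dropout behaviour (this is not the icefall code itself; the shapes, dropout rate, and variable names below are illustrative assumptions). torch.nn.Dropout draws its mask per element of its input, so a pos_emb of shape (1, 2*time-1, embed_dim) gets a single mask that is later broadcast across the batch, whereas expanding to (batch_size, 2*time-1, embed_dim) before dropout gives every batch element an independent mask:

    import torch

    dropout = torch.nn.Dropout(p=0.5)
    pos_emb = torch.randn(1, 7, 4)  # (1, 2*time-1, embed_dim) with time = 4

    # Behaviour restored by this revert: dropout sees a leading dim of 1,
    # so one mask is drawn and shared by the whole batch downstream.
    shared_mask = dropout(pos_emb)

    # Behaviour being reverted: expanding to the batch size first means
    # dropout draws a different mask for each of the 3 batch elements.
    independent_mask = dropout(pos_emb.expand(3, -1, -1))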