commit 14a2603ada (parent e5666628bd)
Author: Daniel Povey
Date:   2022-09-28 20:59:24 +08:00


@@ -509,12 +509,12 @@ class AttentionDownsample(torch.nn.Module):
         weights = scores.softmax(dim=1)
         # ans1 is the first `in_channels` channels of the output
-        ans1 = (src * weights).sum(dim=1)
+        ans = (src * weights).sum(dim=1)
         src = src.permute(0, 2, 1, 3).reshape(d_seq_len, batch_size, ds * in_channels)
         if self.extra_proj is not None:
             ans2 = self.extra_proj(src)
-            ans = torch.cat((ans1, ans2), dim=2)
+            ans = torch.cat((ans, ans2), dim=2)
         return ans
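
The change above only renames `ans1` to `ans`, so the concatenation with `ans2` reuses the same variable. For context, below is a minimal sketch of the attention-based downsampling pattern these lines belong to. It is a reconstruction from the visible lines, not the repository's exact `AttentionDownsample` class; names such as `query` and `extra_proj`, and the zero-padding of the last partial window, are assumptions.

import torch


class AttentionDownsampleSketch(torch.nn.Module):
    """Sketch: downsample the time axis by `downsample` using attention
    weights computed from a learned query vector (assumed design)."""

    def __init__(self, in_channels: int, out_channels: int, downsample: int):
        super().__init__()
        self.ds = downsample
        # learned query used to score each frame within a window (assumption)
        self.query = torch.nn.Parameter(
            torch.randn(in_channels) * in_channels ** -0.5
        )
        if out_channels > in_channels:
            # extra output channels come from a projection of the flattened
            # window (assumption based on the `extra_proj` branch in the diff)
            self.extra_proj = torch.nn.Linear(
                downsample * in_channels, out_channels - in_channels, bias=False
            )
        else:
            self.extra_proj = None

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        # src: (seq_len, batch_size, in_channels)
        seq_len, batch_size, in_channels = src.shape
        ds = self.ds
        d_seq_len = (seq_len + ds - 1) // ds
        # zero-pad so the sequence length is a multiple of ds (simplification)
        pad = d_seq_len * ds - seq_len
        src = torch.nn.functional.pad(src, (0, 0, 0, 0, 0, pad))
        # group frames into windows of size ds: (d_seq_len, ds, batch, in_channels)
        src = src.reshape(d_seq_len, ds, batch_size, in_channels)

        # score each frame against the query, normalize within each window
        scores = (src * self.query).sum(dim=-1, keepdim=True)
        weights = scores.softmax(dim=1)

        # ans is the first `in_channels` channels of the output
        ans = (src * weights).sum(dim=1)
        # flatten each window so extra_proj can produce the remaining channels
        src = src.permute(0, 2, 1, 3).reshape(d_seq_len, batch_size, ds * in_channels)
        if self.extra_proj is not None:
            ans2 = self.extra_proj(src)
            ans = torch.cat((ans, ans2), dim=2)
        return ans

Under these assumptions, AttentionDownsampleSketch(256, 384, downsample=2) maps a (seq_len, batch, 256) input to a (ceil(seq_len / 2), batch, 384) output, with the extra 128 channels supplied by `extra_proj`.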