mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Bug fix
This commit is contained in:
parent
e5666628bd
commit
14a2603ada
@ -509,12 +509,12 @@ class AttentionDownsample(torch.nn.Module):
|
||||
weights = scores.softmax(dim=1)
|
||||
|
||||
# ans1 is the first `in_channels` channels of the output
|
||||
ans1 = (src * weights).sum(dim=1)
|
||||
ans = (src * weights).sum(dim=1)
|
||||
src = src.permute(0, 2, 1, 3).reshape(d_seq_len, batch_size, ds * in_channels)
|
||||
|
||||
if self.extra_proj is not None:
|
||||
ans2 = self.extra_proj(src)
|
||||
ans = torch.cat((ans1, ans2), dim=2)
|
||||
ans = torch.cat((ans, ans2), dim=2)
|
||||
return ans
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user