Let ratio be 8, not 2, for sigmoid in NonlinAttentionModule

commit 258d4f1353
parent 7018c722b5
Author: Daniel Povey
Date: 2022-11-28 21:51:20 +08:00


@@ -1463,6 +1463,7 @@ class NonlinAttentionModule(nn.Module):
     ) -> None:
         super().__init__()
+        self.ratio = ratio
         assert channels % ratio == 0
         self.in_proj = nn.Linear(channels, channels + channels // ratio, bias=True)
@@ -1513,7 +1514,7 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         s = self.balancer(s)
         s = self.sigmoid(s)
-        s = s.unsqueeze(-1).expand(-1, -1, -1, ratio).reshape(seq_len, batch_size, num_channels)
+        s = s.unsqueeze(-1).expand(-1, -1, -1, self.ratio).reshape(seq_len, batch_size, num_channels)
         x = x * s
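
For context, below is a minimal sketch of the sigmoid-gating pattern this diff touches. It is a simplified, hypothetical module (here called SigmoidGate), not the full icefall NonlinAttentionModule, which also applies attention weights and a balancer; those parts are omitted. The point it illustrates is why `self.ratio = ratio` is stored in `__init__`: the gate is computed over `channels // ratio` features, and each sigmoid value is then repeated `ratio` times in `forward()` so that one gate value scales a block of `ratio` channels.

import torch
import torch.nn as nn


class SigmoidGate(nn.Module):
    """Simplified sketch of the sigmoid-gating path (hypothetical module,
    not the real NonlinAttentionModule).  One sigmoid value gates `ratio`
    channels, so the gate is computed over channels // ratio features and
    then expanded.  `ratio` is stored on the module, as this commit does,
    so that forward() can use it."""

    def __init__(self, channels: int, ratio: int = 8) -> None:
        super().__init__()
        self.ratio = ratio
        assert channels % ratio == 0
        # Project to the main features plus the (smaller) gating features.
        self.in_proj = nn.Linear(channels, channels + channels // ratio, bias=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (seq_len, batch_size, channels)
        seq_len, batch_size, num_channels = x.shape
        proj = self.in_proj(x)
        # Split off the channels // ratio gating features.
        x, s = proj.split([num_channels, num_channels // self.ratio], dim=-1)
        s = self.sigmoid(s)
        # Repeat each gate value `ratio` times along the channel dim:
        # (seq_len, batch, channels // ratio) -> (seq_len, batch, channels),
        # so gate i scales channels [i * ratio : (i + 1) * ratio].
        s = (
            s.unsqueeze(-1)
            .expand(-1, -1, -1, self.ratio)
            .reshape(seq_len, batch_size, num_channels)
        )
        return x * s


# Usage: gate a (seq_len=10, batch=2, channels=64) tensor.
m = SigmoidGate(channels=64, ratio=8)
y = m(torch.randn(10, 2, 64))
print(y.shape)  # torch.Size([10, 2, 64])

A larger ratio makes the gating cheaper and coarser: with ratio=8 the module computes one sigmoid per 8 channels instead of one per 2, shrinking the gating branch of the projection while still modulating every channel.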