Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-26 18:24:18 +00:00
Modified attention.

commit dd2acd89fd
parent 8e6fd97c6b
@@ -449,6 +449,12 @@ class RelPositionMultiheadAttention(nn.Module):
         self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
         self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
 
+        # Before applying the softmax to the matrix from the dot-product,
+        # we multiply that matrix with this one
+        self.pre_softmax_param = nn.Parameter(
+            torch.Tensor(num_heads, num_heads)
+        )
+
         self._reset_parameters()
 
     def _reset_parameters(self) -> None:
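The new pre_softmax_param is a learned (num_heads, num_heads) matrix that mixes
the pre-softmax attention scores across heads: at every (query, key) position,
output head j receives sum_i score_i * W[i, j], so each head's logits can draw
on what the other heads attend to. A tiny illustration of that contraction
(the names below are mine, not from the commit):

    import torch

    num_heads = 4
    W = torch.randn(num_heads, num_heads)  # plays the role of pre_softmax_param
    scores = torch.randn(num_heads)        # one (query, key) position, all heads

    mixed = scores @ W                     # mixed[j] = sum_i scores[i] * W[i, j]
    assert torch.allclose(mixed[0], (scores * W[:, 0]).sum())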
@@ -459,6 +465,9 @@ class RelPositionMultiheadAttention(nn.Module):
         nn.init.xavier_uniform_(self.pos_bias_u)
         nn.init.xavier_uniform_(self.pos_bias_v)
 
+        stdv = 1.0 / math.sqrt(self.num_heads)
+        nn.init.normal_(self.pre_softmax_param, mean=0, std=stdv)
+
     def forward(
         self,
         query: Tensor,
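A std of 1/sqrt(num_heads) is the scale at which this mixing roughly preserves
the variance of the scores: each mixed logit is a sum of num_heads terms whose
weights have variance 1/num_heads. A quick numerical check of that claim (my
own sketch, not part of the commit):

    import math

    import torch

    num_heads = 8
    w = torch.empty(num_heads, num_heads)
    torch.nn.init.normal_(w, mean=0.0, std=1.0 / math.sqrt(num_heads))

    scores = torch.randn(100_000, num_heads)  # unit-variance stand-in logits
    mixed = scores @ w

    # Both variances should come out close to 1.0.
    print(scores.var().item(), mixed.var().item())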
@@ -780,6 +789,19 @@ class RelPositionMultiheadAttention(nn.Module):
             bsz * num_heads, tgt_len, -1
         )
 
+        attn_output_weights = attn_output_weights.view(
+            bsz, num_heads, tgt_len, -1
+        ).permute(0, 2, 3, 1)
+        # now attn_output_weights is of shape (bsz, tgt_len, src_len, num_heads)
+
+        attn_output_weights = torch.matmul(
+            attn_output_weights, self.pre_softmax_param
+        )
+
+        attn_output_weights = attn_output_weights.permute(0, 3, 1, 2).reshape(
+            bsz * num_heads, tgt_len, -1
+        )
+
         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
             tgt_len,