Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-28 11:14:19 +00:00)

Commit dd2acd89fd (parent 8e6fd97c6b): Modified attention.
@@ -449,6 +449,12 @@ class RelPositionMultiheadAttention(nn.Module):
         self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
         self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
 
+        # Before applying the softmax to the matrix from the dot-product,
+        # we multiply that matrix with this one
+        self.pre_softmax_param = nn.Parameter(
+            torch.Tensor(num_heads, num_heads)
+        )
+
         self._reset_parameters()
 
     def _reset_parameters(self) -> None:
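In effect, this hunk registers one extra learnable tensor: a num_heads x num_heads matrix that, just before the softmax, replaces each head's raw attention scores with a learned linear combination of all heads' scores. A minimal sketch of that semantics, using assumed illustrative sizes (the real tensors live inside RelPositionMultiheadAttention):

import torch

bsz, num_heads, tgt_len, src_len = 2, 4, 5, 6             # assumed sizes for illustration
scores = torch.randn(bsz, num_heads, tgt_len, src_len)    # per-head raw scores
mixing = torch.randn(num_heads, num_heads)                 # stands in for pre_softmax_param

# Output head h is a learned combination of all input heads g.
mixed = torch.einsum("bgts,gh->bhts", scores, mixing)
print(mixed.shape)  # torch.Size([2, 4, 5, 6])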
@@ -459,6 +465,9 @@ class RelPositionMultiheadAttention(nn.Module):
         nn.init.xavier_uniform_(self.pos_bias_u)
         nn.init.xavier_uniform_(self.pos_bias_v)
 
+        stdv = 1.0 / math.sqrt(self.num_heads)
+        nn.init.normal_(self.pre_softmax_param, mean=0, std=stdv)
+
     def forward(
         self,
         query: Tensor,
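The commit does not explain the chosen scale, but a plausible reading is that std = 1/sqrt(num_heads) keeps the mixed scores at roughly the same magnitude as the unmixed ones at initialization: each mixed score sums num_heads terms whose weights have variance 1/num_heads. A quick numerical check of that assumption on synthetic unit-variance scores:

import math
import torch

num_heads = 8
stdv = 1.0 / math.sqrt(num_heads)
W = torch.empty(num_heads, num_heads).normal_(mean=0.0, std=stdv)

scores = torch.randn(100_000, num_heads)   # unit-variance scores, one column per head
mixed = scores @ W                          # same mixing as in the patch

# Each mixed column sums num_heads terms of variance 1/num_heads,
# so its variance stays close to 1.
print(scores.var().item(), mixed.var().item())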
@@ -780,6 +789,19 @@ class RelPositionMultiheadAttention(nn.Module):
             bsz * num_heads, tgt_len, -1
         )
 
+        attn_output_weights = attn_output_weights.view(
+            bsz, num_heads, tgt_len, -1
+        ).permute(0, 2, 3, 1)
+        # now attn_output_weights is of shape (bsz, tgt_len, src_len, num_heads)
+
+        attn_output_weights = torch.matmul(
+            attn_output_weights, self.pre_softmax_param
+        )
+
+        attn_output_weights = attn_output_weights.permute(0, 3, 1, 2).reshape(
+            bsz * num_heads, tgt_len, -1
+        )
+
         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
             tgt_len,
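At the application site, the flattened (bsz * num_heads, tgt_len, src_len) score tensor is reshaped so heads sit on the last axis, multiplied by pre_softmax_param, and restored to its original layout, so the existing masking and softmax code downstream can run unchanged. A self-contained sketch of the same permute/matmul/permute round trip on dummy scores, with assumed sizes:

import torch

bsz, num_heads, tgt_len, src_len = 2, 4, 5, 6     # illustrative sizes
pre_softmax_param = torch.randn(num_heads, num_heads)

# Raw scores in the module's flattened layout: (bsz * num_heads, tgt_len, src_len).
attn_output_weights = torch.randn(bsz * num_heads, tgt_len, src_len)

# Move heads to the last axis, mix them with the learned matrix,
# then restore the flattened layout.
attn_output_weights = attn_output_weights.view(
    bsz, num_heads, tgt_len, -1
).permute(0, 2, 3, 1)
attn_output_weights = torch.matmul(attn_output_weights, pre_softmax_param)
attn_output_weights = attn_output_weights.permute(0, 3, 1, 2).reshape(
    bsz * num_heads, tgt_len, -1
)

# Downstream softmax still normalizes each row over the source positions.
attn = attn_output_weights.softmax(dim=-1)
print(attn.shape)                                               # torch.Size([8, 5, 6])
print(attn.sum(dim=-1).allclose(torch.ones(bsz * num_heads, tgt_len)))  # True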