Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-19 05:54:20 +00:00)
simplify relative position encoding
This commit is contained in:
  parent  b338471917
  commit  5313ce00d6
@@ -859,7 +859,7 @@ class RelPositionalEncoding(torch.nn.Module):
 
 
 class RelPositionMultiheadAttention(nn.Module):
-    r"""Multi-Head Attention layer with relative position encoding
+    r"""Multi-Head Attention layer with simplified relative position encoding
 
     See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
 
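Note on the scoring rule this docstring change refers to: in the Transformer-XL formulation (Section 3.3), the score for query position i and key position j is (q_i + u) . k_j + (q_i + v) . p_(i-j), with learnable per-head vectors u, v (pos_bias_u / pos_bias_v, each multiplied by the exp of a learned log-scale) and a projected positional embedding p. The simplified version drops u, v and the per-dimension positional projection, keeping q_i . k_j plus a single learned scalar per head and relative offset. A toy single-head comparison, illustration only; the names u, v, p, b and the offset indexing below are mine, not from this commit, and the scale factors are omitted:

import torch

d_k, T = 4, 3
q = torch.randn(T, d_k); k = torch.randn(T, d_k)
u, v = torch.randn(d_k), torch.randn(d_k)   # stand-ins for the old pos_bias_u / pos_bias_v
p = torch.randn(2 * T - 1, d_k)             # stand-in for the old projected positional embeddings
b = torch.randn(2 * T - 1)                  # stand-in for the new per-head scalar scores

def offset_index(i: int, j: int) -> int:
    # assume the last axis stores offsets +(T-1) .. -(T-1); offset (i - j) sits at (T-1) - (i - j)
    return (T - 1) - (i - j)

old = torch.empty(T, T); new = torch.empty(T, T)
for i in range(T):
    for j in range(T):
        old[i, j] = (q[i] + u) @ k[j] + (q[i] + v) @ p[offset_index(i, j)]
        new[i, j] = q[i] @ k[j] + b[offset_index(i, j)]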
@@ -895,24 +895,7 @@ class RelPositionMultiheadAttention(nn.Module):
         )
 
         # linear transformation for positional encoding.
-        self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)
-        # these two learnable bias are used in matrix c and matrix d
-        # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
-        self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
-        self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
-        self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach())
-        self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach())
-        self._reset_parameters()
-
-    def _pos_bias_u(self):
-        return self.pos_bias_u * self.pos_bias_u_scale.exp()
-
-    def _pos_bias_v(self):
-        return self.pos_bias_v * self.pos_bias_v_scale.exp()
-
-    def _reset_parameters(self) -> None:
-        nn.init.normal_(self.pos_bias_u, std=0.01)
-        nn.init.normal_(self.pos_bias_v, std=0.01)
+        self.linear_pos = ScaledLinear(embed_dim, num_heads, bias=True)
 
     def forward(
         self,
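Note on this hunk: the old module learned a full embed_dim -> embed_dim projection of the positional embedding plus the pos_bias_u / pos_bias_v vectors and their log-scale parameters; the new module learns only an embed_dim -> num_heads projection, i.e. one scalar score per head for each relative offset. A minimal shape sketch, using nn.Linear as a stand-in for ScaledLinear and made-up sizes:

import torch
import torch.nn as nn

embed_dim, num_heads, time1 = 512, 8, 100                 # made-up sizes
linear_pos = nn.Linear(embed_dim, num_heads, bias=True)   # stand-in for ScaledLinear

pos_emb = torch.randn(1, 2 * time1 - 1, embed_dim)        # one embedding per relative offset
scores = linear_pos(pos_emb)                              # (1, 2*time1-1, num_heads): one scalar per head per offset
print(scores.shape)                                       # torch.Size([1, 199, 8])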
@@ -1217,35 +1200,25 @@ class RelPositionMultiheadAttention(nn.Module):
                 key_padding_mask.size(1), src_len
             )
 
-        q = q.transpose(0, 1)  # (batch, time1, head, d_k)
+        q = q.permute(1, 2, 0, 3)  # (batch, head, time1, d_k)
 
         pos_emb_bsz = pos_emb.size(0)
         assert pos_emb_bsz in (1, bsz)  # actually it is 1
-        p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
-        # (batch, 2*time1, head, d_k) --> (batch, head, d_k, 2*time -1)
-        p = p.permute(0, 2, 3, 1)
-
-        q_with_bias_u = (q + self._pos_bias_u()).transpose(
-            1, 2
-        )  # (batch, head, time1, d_k)
-
-        q_with_bias_v = (q + self._pos_bias_v()).transpose(
-            1, 2
-        )  # (batch, head, time1, d_k)
 
         # compute attention score
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(
-            q_with_bias_u, k
-        )  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q, k)  # (batch, head, time1, time2)
 
         # compute matrix b and matrix d
-        matrix_bd = torch.matmul(
-            q_with_bias_v, p
-        )  # (batch, head, time1, 2*time1-1)
-        matrix_bd = self.rel_shift(matrix_bd, left_context)
+        pos_emb = self.linear_pos(pos_emb)  # (1, 2*time1-1, head)
+        matrix_bd = (
+            pos_emb.transpose(1, 2).unsqueeze(2).repeat(1, 1, tgt_len, 1)
+        )  # (1, head, time1, 2*time1-1)
+        matrix_bd = self.rel_shift(
+            matrix_bd, left_context
+        )  # (1, head, time1, time2)
 
         attn_output_weights = (
             matrix_ac + matrix_bd
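Note on this hunk: matrix_ac no longer adds pos_bias_u to the query, and matrix_bd no longer depends on the query at all. The per-head scores from linear_pos are broadcast over the time1 query positions and then re-indexed by rel_shift so that entry (i, j) holds the score for offset i - j. A toy sketch of that broadcast and re-indexing; toy_rel_shift is my own gather-based stand-in for the module's rel_shift (left_context ignored), and it assumes the last axis runs from offset +(time1-1) down to -(time1-1):

import torch

def toy_rel_shift(x: torch.Tensor) -> torch.Tensor:
    # x: (batch, head, time1, 2*time1-1) -> (batch, head, time1, time1)
    b, h, t, n = x.shape
    assert n == 2 * t - 1
    # the score for offset (i - j) is assumed to sit at last-axis index (t - 1) - (i - j)
    idx = (t - 1) - (torch.arange(t).unsqueeze(1) - torch.arange(t).unsqueeze(0))
    idx = idx.view(1, 1, t, t).expand(b, h, t, t)
    return torch.gather(x, dim=3, index=idx)

num_heads, time1 = 8, 100
pos_scores = torch.randn(1, 2 * time1 - 1, num_heads)   # shaped like the output of linear_pos
matrix_bd = pos_scores.transpose(1, 2).unsqueeze(2)      # (1, head, 1, 2*time1-1)
matrix_bd = matrix_bd.repeat(1, 1, time1, 1)             # (1, head, time1, 2*time1-1)
matrix_bd = toy_rel_shift(matrix_bd)                     # (1, head, time1, time1)
print(matrix_bd.shape)                                   # torch.Size([1, 8, 100, 100])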