Merging 109: linear positional encoding

Daniel Povey 2022-10-15 12:58:59 +08:00
commit a0ef291f95

@@ -263,7 +263,7 @@ class ConformerEncoderLayer(nn.Module):
self.d_model = d_model
self.self_attn = RelPositionMultiheadAttention(
d_model, nhead, dropout=dropout,
d_model, nhead, dropout=0.0,
)
self.feed_forward1 = FeedforwardModule(d_model,
@@ -436,16 +436,14 @@ class ConformerEncoder(nn.Module):
def get_layers_to_drop(self, rnd_seed: int, warmup_count: float):
num_layers = len(self.layers)
warmup_begin = self.warmup_begin
warmup_end = self.warmup_end
def get_layerdrop_prob(layer: int) -> float:
layer_warmup_delta = (warmup_end - warmup_begin) / num_layers
layer_warmup_begin = warmup_begin + layer * layer_warmup_delta
layer_warmup_begin = self.layers[layer].warmup_begin
layer_warmup_end = self.layers[layer].warmup_end
initial_layerdrop_prob = 0.5
final_layerdrop_prob = 0.05
layer_warmup_end = layer_warmup_begin + layer_warmup_delta
if warmup_count < layer_warmup_begin:
return initial_layerdrop_prob
elif warmup_count > layer_warmup_end:
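
The hunk above moves the layerdrop warmup bounds onto the layers themselves (self.layers[layer].warmup_begin / warmup_end) instead of deriving them from one encoder-wide window, and it is truncated before the remaining branches. The standalone sketch below fills in the behaviour past layer_warmup_end purely as an assumption, to illustrate the shape of the schedule; it is not the repository's code.

# Standalone sketch; the last two branches are assumed, not taken from the diff.
def layerdrop_prob(warmup_count: float,
                   layer_warmup_begin: float,
                   layer_warmup_end: float,
                   initial_layerdrop_prob: float = 0.5,
                   final_layerdrop_prob: float = 0.05) -> float:
    if warmup_count < layer_warmup_begin:
        return initial_layerdrop_prob   # layer not warming up yet: drop it often
    elif warmup_count > layer_warmup_end:
        return final_layerdrop_prob     # fully warmed up: drop it rarely (assumed)
    else:
        # assumed: linear interpolation inside the layer's own warmup window
        t = (warmup_count - layer_warmup_begin) / (layer_warmup_end - layer_warmup_begin)
        return initial_layerdrop_prob + t * (final_layerdrop_prob - initial_layerdrop_prob)
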
@@ -483,7 +481,7 @@ class ConformerEncoder(nn.Module):
if len(ans) == num_to_drop:
break
if shared_rng.random() < 0.005 or __name__ == "__main__":
logging.info(f"warmup_begin={warmup_begin:.1f}, warmup_end={warmup_end:.1f}, warmup_count={warmup_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}")
logging.info(f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, warmup_count={warmup_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}")
return ans
@@ -839,9 +837,6 @@ class RelPositionMultiheadAttention(nn.Module):
channel_dim=-1, max_abs=5.0)
self.in_max_eig = MaxEig(3 * embed_dim // 2,
channel_dim=-1)
self.proj_balancer = ActivationBalancer(embed_dim // 2,
channel_dim=-1, max_abs=10.0,
min_positive=0.0, max_positive=1.0)
self.out_proj = ScaledLinear(
embed_dim // 2, embed_dim, bias=True, initial_scale=0.05
)
@@ -850,18 +845,9 @@ class RelPositionMultiheadAttention(nn.Module):
self.out_proj2 = ScaledLinear(embed_dim // 2, embed_dim, bias=True,
initial_scale=0.05)
# linear transformation for positional encoding.
self.linear_pos = nn.Linear(embed_dim, embed_dim // 2, bias=False)
# these two learnable biases are used in matrix c and matrix d
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
self._reset_parameters()
def _reset_parameters(self) -> None:
nn.init.uniform_(self.pos_bias_u, -0.05, 0.05)
nn.init.uniform_(self.pos_bias_v, -0.05, 0.05)
# linear transformation for positional encoding (projects to a scalar per head,
# which will be added to the score).
self.linear_pos = ScaledLinear(embed_dim, num_heads, initial_scale=0.05)
def forward(
self,
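
In this hunk the Transformer-XL style machinery (linear_pos projecting to embed_dim // 2 per position, plus the pos_bias_u / pos_bias_v parameters) is replaced by a single projection of the positional embedding to one scalar per head. A rough shape sketch, with a plain nn.Linear standing in for the repository's ScaledLinear and made-up sizes:

import torch
import torch.nn as nn

# Shape sketch only; nn.Linear stands in for ScaledLinear, and
# embed_dim / num_heads / T are illustrative sizes.
embed_dim, num_heads, T = 512, 8, 100
pos_emb = torch.randn(1, 2 * T - 1, embed_dim)   # relative offsets -(T-1) .. (T-1)
linear_pos = nn.Linear(embed_dim, num_heads)
pos = linear_pos(pos_emb)                        # (1, 2*T-1, num_heads)
print(pos.shape)                                 # torch.Size([1, 199, 8])
# Each relative offset now carries one learned scalar per head; rel_shift()
# below spreads these into a (1, num_heads, T, T) additive bias on the scores.
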
@@ -909,7 +895,7 @@ class RelPositionMultiheadAttention(nn.Module):
"""
x, weights = self.multi_head_attention_forward(
self.in_max_eig(self.in_balancer(self.in_proj(x))),
pos_emb,
self.linear_pos(pos_emb),
self.embed_dim,
self.num_heads,
self.in_proj.weight,
@@ -923,35 +909,44 @@ class RelPositionMultiheadAttention(nn.Module):
)
return x, weights
def rel_shift(self, x: Tensor) -> Tensor:
"""Compute relative positional encoding.
def rel_shift(self, pos_bias: Tensor) -> Tensor:
"""Convert relative positional bias from linear to matrix format.
Args:
x: Input tensor (batch, head, time1, 2*time1-1).
time1 means the length of query vector.
pos_bias: Input tensor (1, 2*T-1, num_heads), where T is the number of frames.
Returns:
Tensor: tensor of shape (batch, head, time1, time2)
(note: time2 has the same value as time1, but it is for
the key, while time1 is for the query).
Tensor of shape (1, num_heads, time1, time2)
(note: time2 has the same value as time1, but it is for
the key, while time1 is for the query).
"""
(batch_size, num_heads, time1, n) = x.shape
assert n == 2 * time1 - 1
(batch_size, n, num_heads) = pos_bias.shape
assert batch_size == 1
T = (n + 1) // 2
assert n == 2 * T - 1
# The leading T dimension behaves like a batch dimension.
# It is only needed because PyTorch does not currently support
# negative strides.
pos_bias = pos_bias.expand(T, n, num_heads).contiguous()
# Note: TorchScript requires explicit arg for stride()
batch_stride = x.stride(0)
head_stride = x.stride(1)
time1_stride = x.stride(2)
n_stride = x.stride(3)
return x.as_strided(
(batch_size, num_heads, time1, time1),
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
storage_offset=n_stride * (time1 - 1),
batch_stride = pos_bias.stride(0)
time_stride = pos_bias.stride(1)
head_stride = pos_bias.stride(2)
# We could have left the batch dim as 1, and used '-time_stride' below
# where we use 'batch_stride - time_stride', but PyTorch does not support negative
# strides.
return pos_bias.as_strided(
(1, num_heads, T, T),
(0, head_stride, batch_stride - time_stride, time_stride),
storage_offset=time_stride * (T - 1),
)
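
The expand + as_strided combination above is a strided view that places, at output position (i, j), the bias stored at relative index T - 1 + (j - i). A standalone check of that equivalence (reproducing the method outside the class, with a brute-force gather as the reference):

import torch

def rel_shift(pos_bias: torch.Tensor) -> torch.Tensor:
    # Reproduction of the method above, outside the class, for the check below.
    (batch_size, n, num_heads) = pos_bias.shape
    assert batch_size == 1
    T = (n + 1) // 2
    pos_bias = pos_bias.expand(T, n, num_heads).contiguous()
    batch_stride, time_stride, head_stride = pos_bias.stride()
    return pos_bias.as_strided(
        (1, num_heads, T, T),
        (0, head_stride, batch_stride - time_stride, time_stride),
        storage_offset=time_stride * (T - 1),
    )

T, H = 5, 2
pos = torch.arange(2 * T - 1, dtype=torch.float32).view(1, -1, 1).expand(1, -1, H)
out = rel_shift(pos)                                   # (1, H, T, T)
# Brute-force reference: entry (i, j) holds the bias stored at offset j - i.
ref = torch.stack([torch.stack([pos[0, T - 1 + j - i] for j in range(T)])
                   for i in range(T)])                 # (T, T, H)
assert torch.equal(out, ref.permute(2, 0, 1).unsqueeze(0))
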
def multi_head_attention_forward(
self,
x: Tensor,
pos_emb: Tensor,
pos: Tensor,
embed_dim: int,
num_heads: int,
in_proj_weight: Tensor,
@@ -965,8 +960,8 @@ class RelPositionMultiheadAttention(nn.Module):
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
query, key, value: map a query and a set of key-value pairs to an output.
pos_emb: Positional embedding tensor
x: the projected input, to be split into query, key, value.
pos: head-specific biases arising from the positional embeddings.
embed_dim: total dimension of the model.
num_heads: parallel attention heads.
in_proj_weight, in_proj_bias: input projection weight and bias.
@@ -981,14 +976,10 @@ class RelPositionMultiheadAttention(nn.Module):
Shape:
Inputs:
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension.
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence
length, N is the batch size, E is the embedding dimension.
- x: :math:`(L, N, 3 * E//2)` where L is the target sequence length, N is the batch size, E is
the embedding dimension. Will be split into (query, key, value).
- pos: :math:`(N, 2*L-1, H)` or :math:`(1, 2*L-1, H)` where L is the sequence
length, N is the batch size, and H is the number of heads.
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
will be unchanged. If a BoolTensor is provided, the positions with the
@@ -1008,7 +999,7 @@ class RelPositionMultiheadAttention(nn.Module):
H is the num-heads, S is the sequence length.
"""
tgt_len, bsz, _ = x.size()
seq_len, bsz, _ = x.size()
head_dim = embed_dim // (num_heads * 2)
assert (
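
The factor of 2 in head_dim reflects that this attention module runs at half the model dimension: in_proj produces 3 * embed_dim // 2 features, split into query, key and value of embed_dim // 2 each, which are then divided over the heads. With illustrative sizes (not necessarily the recipe's defaults):

# Dimension bookkeeping only; embed_dim = 512, num_heads = 8 are example values.
embed_dim, num_heads = 512, 8
in_proj_features = 3 * embed_dim // 2        # 768: q, k, v of 256 each
attention_dim = embed_dim // 2               # 256
head_dim = embed_dim // (num_heads * 2)      # 32
assert attention_dim == num_heads * head_dim
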
@@ -1040,15 +1031,15 @@ class RelPositionMultiheadAttention(nn.Module):
if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(0)
if list(attn_mask.size()) != [1, tgt_len, tgt_len]:
if list(attn_mask.size()) != [1, seq_len, seq_len]:
raise RuntimeError(
"The size of the 2D attn_mask is not correct."
)
elif attn_mask.dim() == 3:
if list(attn_mask.size()) != [
bsz * num_heads,
tgt_len,
tgt_len,
seq_len,
seq_len,
]:
raise RuntimeError(
"The size of the 3D attn_mask is not correct."
@@ -1071,63 +1062,37 @@ class RelPositionMultiheadAttention(nn.Module):
)
key_padding_mask = key_padding_mask.to(torch.bool)
q = (q * scaling).contiguous().view(tgt_len, bsz, num_heads, head_dim)
q = (q * scaling).contiguous().view(seq_len, bsz, num_heads, head_dim)
k = k.contiguous().view(-1, bsz, num_heads, head_dim)
v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
src_len = k.size(0)
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz, "{} == {}".format(
key_padding_mask.size(0), bsz
)
assert key_padding_mask.size(1) == src_len, "{} == {}".format(
key_padding_mask.size(1), src_len
assert key_padding_mask.size(1) == seq_len, "{} == {}".format(
key_padding_mask.size(1), seq_len
)
q = q.transpose(0, 1) # (batch, time1, head, d_k)
pos_emb_bsz = pos_emb.size(0)
assert pos_emb_bsz in (1, bsz) # actually it is 1
p = self.proj_balancer(self.linear_pos(pos_emb)).view(pos_emb_bsz, -1, num_heads, head_dim)
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
q_with_bias_u = (q + self.pos_bias_u).transpose(
1, 2
) # (batch, head, time1, d_k)
q_with_bias_v = (q + self.pos_bias_v).transpose(
1, 2
) # (batch, head, time1, d_k)
q = q.permute(1, 2, 0, 3) # (batch, head, time1, head_dim)
# compute attention score
# first compute matrix a and matrix c
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
matrix_ac = torch.matmul(
q_with_bias_u, k
) # (batch, head, time1, time2)
# compute matrix b and matrix d
matrix_bd = torch.matmul(
q_with_bias_v, p.transpose(-2, -1)
) # (batch, head, time1, 2*time1-1)
matrix_bd = self.rel_shift(matrix_bd)
attn_output_weights = (
matrix_ac + matrix_bd
) # (batch, head, time1, time2)
# pos_bias: (batch, head, time1, time2)
pos_bias = self.rel_shift(pos)
attn_output_weights = torch.matmul(q, k) + pos_bias
# attn_output_weights: (batch, head, time1, time2)
attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, -1
bsz * num_heads, seq_len, seq_len
)
assert list(attn_output_weights.size()) == [
bsz * num_heads,
tgt_len,
src_len,
seq_len,
seq_len,
]
if attn_mask is not None:
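
Compared with the Transformer-XL formulation it replaces (matrix_ac using pos_bias_u plus matrix_bd from a per-dimension positional projection), the score is now the content term q·k plus a per-head scalar bias for each relative offset. A condensed, self-contained sketch of the new path, using made-up sizes, a plain nn.Linear in place of ScaledLinear, and an explicit gather in place of rel_shift():

import torch
import torch.nn as nn

# Illustrative sizes; not the repository's defaults.
embed_dim, num_heads, T, bsz = 512, 8, 50, 4
head_dim = embed_dim // (num_heads * 2)

q = torch.randn(bsz, num_heads, T, head_dim)        # scaled queries
k = torch.randn(bsz, num_heads, head_dim, T)        # keys, ready for matmul
pos_emb = torch.randn(1, 2 * T - 1, embed_dim)      # relative positional embedding

linear_pos = nn.Linear(embed_dim, num_heads)        # one scalar per head per offset
pos = linear_pos(pos_emb)[0]                        # (2*T-1, num_heads)

# Gather form of rel_shift(): entry (i, j) takes the bias at offset j - i.
idx = (T - 1) - torch.arange(T).unsqueeze(1) + torch.arange(T).unsqueeze(0)
pos_bias = pos[idx].permute(2, 0, 1).unsqueeze(0)   # (1, num_heads, T, T)

attn_scores = torch.matmul(q, k) + pos_bias         # (bsz, num_heads, T, T)
attn_weights = attn_scores.softmax(dim=-1)
print(attn_weights.shape)                           # torch.Size([4, 8, 50, 50])
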
@@ -1138,14 +1103,14 @@ class RelPositionMultiheadAttention(nn.Module):
if key_padding_mask is not None:
attn_output_weights = attn_output_weights.view(
bsz, num_heads, tgt_len, src_len
bsz, num_heads, seq_len, seq_len
)
attn_output_weights = attn_output_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
float("-inf"),
)
attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, src_len
bsz * num_heads, seq_len, seq_len
)
attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
@@ -1154,11 +1119,11 @@ class RelPositionMultiheadAttention(nn.Module):
)
attn_output = torch.bmm(attn_output_weights, v)
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim]
attn_output = (
attn_output.transpose(0, 1)
.contiguous()
.view(tgt_len, bsz, embed_dim // 2)
.view(seq_len, bsz, embed_dim // 2)
)
attn_output = nn.functional.linear(
attn_output, out_proj_weight, out_proj_bias