Merging 109: linear positional encoding

commit a0ef291f95
Author: Daniel Povey
Date:   2022-10-15 12:58:59 +08:00


@@ -263,7 +263,7 @@ class ConformerEncoderLayer(nn.Module):
         self.d_model = d_model
         self.self_attn = RelPositionMultiheadAttention(
-            d_model, nhead, dropout=dropout,
+            d_model, nhead, dropout=0.0,
         )
         self.feed_forward1 = FeedforwardModule(d_model,
@@ -436,16 +436,14 @@ class ConformerEncoder(nn.Module):
     def get_layers_to_drop(self, rnd_seed: int, warmup_count: float):
         num_layers = len(self.layers)
-        warmup_begin = self.warmup_begin
-        warmup_end = self.warmup_end

         def get_layerdrop_prob(layer: int) -> float:
-            layer_warmup_delta = (warmup_end - warmup_begin) / num_layers
-            layer_warmup_begin = warmup_begin + layer * layer_warmup_delta
+            layer_warmup_begin = self.layers[layer].warmup_begin
+            layer_warmup_end = self.layers[layer].warmup_end
             initial_layerdrop_prob = 0.5
             final_layerdrop_prob = 0.05
-            layer_warmup_end = layer_warmup_begin + layer_warmup_delta
             if warmup_count < layer_warmup_begin:
                 return initial_layerdrop_prob
             elif warmup_count > layer_warmup_end:
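Per-layer warmup boundaries now come straight from each layer (self.layers[layer].warmup_begin / warmup_end) instead of being re-derived from the encoder-level window. For intuition, a minimal standalone sketch of the resulting schedule, assuming the elided middle branch ramps linearly between the two endpoints (that branch is outside this hunk):

def layerdrop_prob(warmup_count: float,
                   layer_warmup_begin: float,
                   layer_warmup_end: float,
                   initial_layerdrop_prob: float = 0.5,
                   final_layerdrop_prob: float = 0.05) -> float:
    if warmup_count < layer_warmup_begin:
        return initial_layerdrop_prob
    elif warmup_count > layer_warmup_end:
        return final_layerdrop_prob
    else:
        # Assumed: linear ramp from the initial to the final probability
        # across this layer's warmup window.
        frac = (warmup_count - layer_warmup_begin) / (layer_warmup_end - layer_warmup_begin)
        return initial_layerdrop_prob + frac * (final_layerdrop_prob - initial_layerdrop_prob)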
@@ -483,7 +481,7 @@ class ConformerEncoder(nn.Module):
             if len(ans) == num_to_drop:
                 break
         if shared_rng.random() < 0.005 or __name__ == "__main__":
-            logging.info(f"warmup_begin={warmup_begin:.1f}, warmup_end={warmup_end:.1f}, warmup_count={warmup_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}")
+            logging.info(f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, warmup_count={warmup_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}")
         return ans
@@ -839,9 +837,6 @@ class RelPositionMultiheadAttention(nn.Module):
                                               channel_dim=-1, max_abs=5.0)
         self.in_max_eig = MaxEig(3 * embed_dim // 2,
                                  channel_dim=-1)
-        self.proj_balancer = ActivationBalancer(embed_dim // 2,
-                                                channel_dim=-1, max_abs=10.0,
-                                                min_positive=0.0, max_positive=1.0)
         self.out_proj = ScaledLinear(
             embed_dim // 2, embed_dim, bias=True, initial_scale=0.05
         )
@@ -850,18 +845,9 @@ class RelPositionMultiheadAttention(nn.Module):
         self.out_proj2 = ScaledLinear(embed_dim // 2, embed_dim, bias=True,
                                       initial_scale=0.05)

-        # linear transformation for positional encoding.
-        self.linear_pos = nn.Linear(embed_dim, embed_dim // 2, bias=False)
-        # these two learnable bias are used in matrix c and matrix d
-        # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
-        self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
-        self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
-        self._reset_parameters()
-
-    def _reset_parameters(self) -> None:
-        nn.init.uniform_(self.pos_bias_u, -0.05, 0.05)
-        nn.init.uniform_(self.pos_bias_v, -0.05, 0.05)
+        # linear transformation for positional encoding (projects to a scalar per head,
+        # which will be added to the score).
+        self.linear_pos = ScaledLinear(embed_dim, num_heads, initial_scale=0.05)

     def forward(
         self,
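For shape orientation: the new linear_pos maps the (1, 2*T-1, embed_dim) positional embedding to one scalar per head per relative offset, replacing the Transformer-XL pos_bias_u/pos_bias_v path. A minimal sketch of that projection, using a plain nn.Linear as a stand-in for ScaledLinear (its initial_scale only affects initialization, not shapes):

import torch
import torch.nn as nn

embed_dim, num_heads, T = 256, 8, 10            # illustrative sizes, not from the diff
pos_emb = torch.randn(1, 2 * T - 1, embed_dim)  # relative offsets -(T-1) .. (T-1)
linear_pos = nn.Linear(embed_dim, num_heads)    # stand-in for ScaledLinear(embed_dim, num_heads, initial_scale=0.05)
pos = linear_pos(pos_emb)                       # (1, 2*T-1, num_heads): one additive score bias per head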
@@ -909,7 +895,7 @@ class RelPositionMultiheadAttention(nn.Module):
         """
         x, weights = self.multi_head_attention_forward(
             self.in_max_eig(self.in_balancer(self.in_proj(x))),
-            pos_emb,
+            self.linear_pos(pos_emb),
             self.embed_dim,
             self.num_heads,
             self.in_proj.weight,
@@ -923,35 +909,44 @@ class RelPositionMultiheadAttention(nn.Module):
         )
         return x, weights

-    def rel_shift(self, x: Tensor) -> Tensor:
-        """Compute relative positional encoding.
+    def rel_shift(self, pos_bias: Tensor) -> Tensor:
+        """Convert relative positional bias from linear to matrix format.

         Args:
-            x: Input tensor (batch, head, time1, 2*time1-1).
-                time1 means the length of query vector.
+            pos_bias: Input tensor (1, 2*T-1, num_heads), where T is the number of frames.

         Returns:
-            Tensor: tensor of shape (batch, head, time1, time2)
+            Tensor of shape (1, num_heads, time1, time2)
             (note: time2 has the same value as time1, but it is for
             the key, while time1 is for the query).
         """
-        (batch_size, num_heads, time1, n) = x.shape
-        assert n == 2 * time1 - 1
+        (batch_size, n, num_heads) = pos_bias.shape
+        assert batch_size == 1
+        T = (n + 1) // 2
+        assert n == 2 * T - 1
+        # The leading T dimension behaves like a batch dimension.
+        # It is only needed because PyTorch does not currently support
+        # negative strides.
+        pos_bias = pos_bias.expand(T, n, num_heads).contiguous()
         # Note: TorchScript requires explicit arg for stride()
-        batch_stride = x.stride(0)
-        head_stride = x.stride(1)
-        time1_stride = x.stride(2)
-        n_stride = x.stride(3)
-        return x.as_strided(
-            (batch_size, num_heads, time1, time1),
-            (batch_stride, head_stride, time1_stride - n_stride, n_stride),
-            storage_offset=n_stride * (time1 - 1),
+        batch_stride = pos_bias.stride(0)
+        time_stride = pos_bias.stride(1)
+        head_stride = pos_bias.stride(2)
+        # We could have left the batch dim as 1, and used '-time_stride' below
+        # where we use 'batch_stride - time_stride', but PyTorch does not support negative
+        # strides.
+        return pos_bias.as_strided(
+            (1, num_heads, T, T),
+            (0, head_stride, batch_stride - time_stride, time_stride),
+            storage_offset=time_stride * (T - 1),
         )

     def multi_head_attention_forward(
         self,
         x: Tensor,
-        pos_emb: Tensor,
+        pos: Tensor,
         embed_dim: int,
         num_heads: int,
         in_proj_weight: Tensor,
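The rewritten rel_shift turns the (1, 2*T-1, num_heads) bias into a (1, num_heads, T, T) matrix using nothing but as_strided; the expand over a leading T dimension exists only because as_strided cannot take negative strides. A self-contained sketch of the same trick, checked against naive indexing (not the module itself):

import torch

def rel_shift_sketch(pos_bias: torch.Tensor) -> torch.Tensor:
    # pos_bias: (1, 2*T-1, num_heads) -> (1, num_heads, T, T)
    (batch_size, n, num_heads) = pos_bias.shape
    assert batch_size == 1
    T = (n + 1) // 2
    # The leading T dimension acts as a batch dimension; it is only needed
    # because as_strided does not accept negative strides.
    pos_bias = pos_bias.expand(T, n, num_heads).contiguous()
    (batch_stride, time_stride, head_stride) = pos_bias.stride()
    return pos_bias.as_strided(
        (1, num_heads, T, T),
        (0, head_stride, batch_stride - time_stride, time_stride),
        storage_offset=time_stride * (T - 1),
    )

T, H = 4, 2
pos = torch.randn(1, 2 * T - 1, H)
out = rel_shift_sketch(pos)
# Naive reference for the same mapping: entry (i, j) reads offset index (T-1) + j - i.
ref = torch.stack(
    [torch.stack([pos[0, (T - 1) + j - i] for j in range(T)]) for i in range(T)]
)  # (T, T, H)
assert torch.allclose(out[0], ref.permute(2, 0, 1))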
@@ -965,8 +960,8 @@ class RelPositionMultiheadAttention(nn.Module):
     ) -> Tuple[Tensor, Optional[Tensor]]:
         r"""
         Args:
-            query, key, value: map a query and a set of key-value pairs to an output.
-            pos_emb: Positional embedding tensor
+            x_proj: the projected input, to be split into query, key, value.
+            pos: head-specific biases arising from the positional embeddings.
             embed_dim: total dimension of the model.
             num_heads: parallel attention heads.
             in_proj_weight, in_proj_bias: input projection weight and bias.
@@ -981,14 +976,10 @@ class RelPositionMultiheadAttention(nn.Module):
         Shape:
             Inputs:
-            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
-              the embedding dimension.
-            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
-              the embedding dimension.
-            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
-              the embedding dimension.
-            - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence
-              length, N is the batch size, E is the embedding dimension.
+            - x: :math:`(L, N, 3 * E//2)` where L is the target sequence length, N is the batch size, E is
+              the embedding dimension. Will be split into (query, key, value).
+            - pos: :math:`(N, 2*L-1, H)` or :math:`(1, 2*L-1, H)` where L is the sequence
+              length, N is the batch size, and H is the number of heads.
             - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
               If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
               will be unchanged. If a BoolTensor is provided, the positions with the
@@ -1008,7 +999,7 @@ class RelPositionMultiheadAttention(nn.Module):
               H is the num-heads, S is the sequence length.
         """

-        tgt_len, bsz, _ = x.size()
+        seq_len, bsz, _ = x.size()

         head_dim = embed_dim // (num_heads * 2)
         assert (
@@ -1040,15 +1031,15 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
-                if list(attn_mask.size()) != [1, tgt_len, tgt_len]:
+                if list(attn_mask.size()) != [1, seq_len, seq_len]:
                     raise RuntimeError(
                         "The size of the 2D attn_mask is not correct."
                     )
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
-                    tgt_len,
-                    tgt_len,
+                    seq_len,
+                    seq_len,
                 ]:
                     raise RuntimeError(
                         "The size of the 3D attn_mask is not correct."
@@ -1071,63 +1062,37 @@ class RelPositionMultiheadAttention(nn.Module):
             )
             key_padding_mask = key_padding_mask.to(torch.bool)

-        q = (q * scaling).contiguous().view(tgt_len, bsz, num_heads, head_dim)
+        q = (q * scaling).contiguous().view(seq_len, bsz, num_heads, head_dim)
         k = k.contiguous().view(-1, bsz, num_heads, head_dim)
         v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)

-        src_len = k.size(0)
-
         if key_padding_mask is not None:
             assert key_padding_mask.size(0) == bsz, "{} == {}".format(
                 key_padding_mask.size(0), bsz
             )
-            assert key_padding_mask.size(1) == src_len, "{} == {}".format(
-                key_padding_mask.size(1), src_len
+            assert key_padding_mask.size(1) == seq_len, "{} == {}".format(
+                key_padding_mask.size(1), seq_len
             )

-        q = q.transpose(0, 1)  # (batch, time1, head, d_k)
-
-        pos_emb_bsz = pos_emb.size(0)
-        assert pos_emb_bsz in (1, bsz)  # actually it is 1
-        p = self.proj_balancer(self.linear_pos(pos_emb)).view(pos_emb_bsz, -1, num_heads, head_dim)
-        p = p.transpose(1, 2)  # (batch, head, 2*time1-1, d_k)
-
-        q_with_bias_u = (q + self.pos_bias_u).transpose(
-            1, 2
-        )  # (batch, head, time1, d_k)
-
-        q_with_bias_v = (q + self.pos_bias_v).transpose(
-            1, 2
-        )  # (batch, head, time1, d_k)
+        q = q.permute(1, 2, 0, 3)  # (batch, head, time1, head_dim)

         # compute attention score
-        # first compute matrix a and matrix c
-        # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(
-            q_with_bias_u, k
-        )  # (batch, head, time1, time2)
-
-        # compute matrix b and matrix d
-        matrix_bd = torch.matmul(
-            q_with_bias_v, p.transpose(-2, -1)
-        )  # (batch, head, time1, 2*time1-1)
-        matrix_bd = self.rel_shift(matrix_bd)
-
-        attn_output_weights = (
-            matrix_ac + matrix_bd
-        )  # (batch, head, time1, time2)
+        # pos_bias: (batch, head, time1, time2)
+        pos_bias = self.rel_shift(pos)
+        attn_output_weights = torch.matmul(q, k) + pos_bias
+        # attn_output_weights: (batch, head, time1, time2)

         attn_output_weights = attn_output_weights.view(
-            bsz * num_heads, tgt_len, -1
+            bsz * num_heads, seq_len, seq_len
         )

         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
-            tgt_len,
-            src_len,
+            seq_len,
+            seq_len,
         ]

         if attn_mask is not None:
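With this change the two Transformer-XL terms (matrix_ac and matrix_bd) collapse into a single content score plus an additive per-head positional bias. A condensed, self-contained shape sketch of the new scoring path (illustrative sizes; a naive index-based stand-in replaces rel_shift):

import torch

B, T, num_heads, head_dim = 2, 6, 4, 16      # illustrative sizes, not from the diff

q = torch.randn(B, num_heads, T, head_dim)   # (batch, head, time1, head_dim)
k = torch.randn(B, num_heads, head_dim, T)   # (batch, head, head_dim, time2)
pos = torch.randn(1, 2 * T - 1, num_heads)   # what linear_pos(pos_emb) would produce

# Naive stand-in for rel_shift: entry (i, j) reads relative-offset index (T-1) + j - i.
idx = torch.arange(T).view(1, T) - torch.arange(T).view(T, 1) + (T - 1)   # (T, T)
pos_bias = pos[0, idx].permute(2, 0, 1).unsqueeze(0)                      # (1, head, time1, time2)

attn_output_weights = torch.matmul(q, k) + pos_bias   # (batch, head, time1, time2)
attn_output_weights = attn_output_weights.view(B * num_heads, T, T)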
@@ -1138,14 +1103,14 @@ class RelPositionMultiheadAttention(nn.Module):
         if key_padding_mask is not None:
             attn_output_weights = attn_output_weights.view(
-                bsz, num_heads, tgt_len, src_len
+                bsz, num_heads, seq_len, seq_len
             )
             attn_output_weights = attn_output_weights.masked_fill(
                 key_padding_mask.unsqueeze(1).unsqueeze(2),
                 float("-inf"),
            )
             attn_output_weights = attn_output_weights.view(
-                bsz * num_heads, tgt_len, src_len
+                bsz * num_heads, seq_len, seq_len
             )

         attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
@@ -1154,11 +1119,11 @@ class RelPositionMultiheadAttention(nn.Module):
         )

         attn_output = torch.bmm(attn_output_weights, v)
-        assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
+        assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim]
         attn_output = (
             attn_output.transpose(0, 1)
             .contiguous()
-            .view(tgt_len, bsz, embed_dim // 2)
+            .view(seq_len, bsz, embed_dim // 2)
         )
         attn_output = nn.functional.linear(
             attn_output, out_proj_weight, out_proj_bias