mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Merging 109: linear positional encoding
This commit is contained in:
commit
a0ef291f95
@ -263,7 +263,7 @@ class ConformerEncoderLayer(nn.Module):
|
||||
self.d_model = d_model
|
||||
|
||||
self.self_attn = RelPositionMultiheadAttention(
|
||||
d_model, nhead, dropout=dropout,
|
||||
d_model, nhead, dropout=0.0,
|
||||
)
|
||||
|
||||
self.feed_forward1 = FeedforwardModule(d_model,
|
||||
@ -436,16 +436,14 @@ class ConformerEncoder(nn.Module):
|
||||
def get_layers_to_drop(self, rnd_seed: int, warmup_count: float):
|
||||
|
||||
num_layers = len(self.layers)
|
||||
warmup_begin = self.warmup_begin
|
||||
warmup_end = self.warmup_end
|
||||
|
||||
def get_layerdrop_prob(layer: int) -> float:
|
||||
layer_warmup_delta = (warmup_end - warmup_begin) / num_layers
|
||||
layer_warmup_begin = warmup_begin + layer * layer_warmup_delta
|
||||
layer_warmup_begin = self.layers[layer].warmup_begin
|
||||
layer_warmup_end = self.layers[layer].warmup_end
|
||||
|
||||
initial_layerdrop_prob = 0.5
|
||||
final_layerdrop_prob = 0.05
|
||||
|
||||
layer_warmup_end = layer_warmup_begin + layer_warmup_delta
|
||||
if warmup_count < layer_warmup_begin:
|
||||
return initial_layerdrop_prob
|
||||
elif warmup_count > layer_warmup_end:
|
||||
@ -483,7 +481,7 @@ class ConformerEncoder(nn.Module):
|
||||
if len(ans) == num_to_drop:
|
||||
break
|
||||
if shared_rng.random() < 0.005 or __name__ == "__main__":
|
||||
logging.info(f"warmup_begin={warmup_begin:.1f}, warmup_end={warmup_end:.1f}, warmup_count={warmup_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}")
|
||||
logging.info(f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, warmup_count={warmup_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}")
|
||||
return ans
|
||||
|
||||
|
||||
@ -839,9 +837,6 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
channel_dim=-1, max_abs=5.0)
|
||||
self.in_max_eig = MaxEig(3 * embed_dim // 2,
|
||||
channel_dim=-1)
|
||||
self.proj_balancer = ActivationBalancer(embed_dim // 2,
|
||||
channel_dim=-1, max_abs=10.0,
|
||||
min_positive=0.0, max_positive=1.0)
|
||||
self.out_proj = ScaledLinear(
|
||||
embed_dim // 2, embed_dim, bias=True, initial_scale=0.05
|
||||
)
|
||||
@ -850,18 +845,9 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
self.out_proj2 = ScaledLinear(embed_dim // 2, embed_dim, bias=True,
|
||||
initial_scale=0.05)
|
||||
|
||||
|
||||
# linear transformation for positional encoding.
|
||||
self.linear_pos = nn.Linear(embed_dim, embed_dim // 2, bias=False)
|
||||
# these two learnable bias are used in matrix c and matrix d
|
||||
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
|
||||
self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
|
||||
self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
|
||||
self._reset_parameters()
|
||||
|
||||
def _reset_parameters(self) -> None:
|
||||
nn.init.uniform_(self.pos_bias_u, -0.05, 0.05)
|
||||
nn.init.uniform_(self.pos_bias_v, -0.05, 0.05)
|
||||
# linear transformation for positional encoding (projects to a scalar per head,
|
||||
# which will be added to the score).
|
||||
self.linear_pos = ScaledLinear(embed_dim, num_heads, initial_scale=0.05)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -909,7 +895,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
"""
|
||||
x, weights = self.multi_head_attention_forward(
|
||||
self.in_max_eig(self.in_balancer(self.in_proj(x))),
|
||||
pos_emb,
|
||||
self.linear_pos(pos_emb),
|
||||
self.embed_dim,
|
||||
self.num_heads,
|
||||
self.in_proj.weight,
|
||||
@ -923,35 +909,44 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
)
|
||||
return x, weights
|
||||
|
||||
def rel_shift(self, x: Tensor) -> Tensor:
|
||||
"""Compute relative positional encoding.
|
||||
def rel_shift(self, pos_bias: Tensor) -> Tensor:
|
||||
"""Convert relative positional bias from linear to matrix format.
|
||||
|
||||
Args:
|
||||
x: Input tensor (batch, head, time1, 2*time1-1).
|
||||
time1 means the length of query vector.
|
||||
pos_bias: Input tensor (1, 2*T-1, num_heads), where T is the number of frames.
|
||||
|
||||
Returns:
|
||||
Tensor: tensor of shape (batch, head, time1, time2)
|
||||
(note: time2 has the same value as time1, but it is for
|
||||
the key, while time1 is for the query).
|
||||
Tensor of shape (1, num_heads, time1, time2)
|
||||
(note: time2 has the same value as time1, but it is for
|
||||
the key, while time1 is for the query).
|
||||
"""
|
||||
(batch_size, num_heads, time1, n) = x.shape
|
||||
assert n == 2 * time1 - 1
|
||||
(batch_size, n, num_heads) = pos_bias.shape
|
||||
assert batch_size == 1
|
||||
T = (n + 1) // 2
|
||||
assert n == 2 * T - 1
|
||||
# The leading T dimension behaves like a batch dimension.
|
||||
# It is only needed because PyTorch does not currently support
|
||||
# negative strides.
|
||||
pos_bias = pos_bias.expand(T, n, num_heads).contiguous()
|
||||
|
||||
# Note: TorchScript requires explicit arg for stride()
|
||||
batch_stride = x.stride(0)
|
||||
head_stride = x.stride(1)
|
||||
time1_stride = x.stride(2)
|
||||
n_stride = x.stride(3)
|
||||
return x.as_strided(
|
||||
(batch_size, num_heads, time1, time1),
|
||||
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
|
||||
storage_offset=n_stride * (time1 - 1),
|
||||
batch_stride = pos_bias.stride(0)
|
||||
time_stride = pos_bias.stride(1)
|
||||
head_stride = pos_bias.stride(2)
|
||||
|
||||
# We could have left the batch dim as 1, and used '-time_stride' below
|
||||
# where we use 'batch_stride - time_stride', but PyTorch does not support negative
|
||||
# strides.
|
||||
return pos_bias.as_strided(
|
||||
(1, num_heads, T, T),
|
||||
(0, head_stride, batch_stride - time_stride, time_stride),
|
||||
storage_offset=time_stride * (T - 1),
|
||||
)
|
||||
|
||||
def multi_head_attention_forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
pos_emb: Tensor,
|
||||
pos: Tensor,
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
in_proj_weight: Tensor,
|
||||
@ -965,8 +960,8 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
r"""
|
||||
Args:
|
||||
query, key, value: map a query and a set of key-value pairs to an output.
|
||||
pos_emb: Positional embedding tensor
|
||||
x_proj: the projected input, to be split into query, key, value.
|
||||
pos: head-specific biases arising from the positional embeddings.
|
||||
embed_dim: total dimension of the model.
|
||||
num_heads: parallel attention heads.
|
||||
in_proj_weight, in_proj_bias: input projection weight and bias.
|
||||
@ -981,14 +976,10 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
|
||||
Shape:
|
||||
Inputs:
|
||||
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
|
||||
the embedding dimension.
|
||||
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
|
||||
the embedding dimension.
|
||||
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
|
||||
the embedding dimension.
|
||||
- pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence
|
||||
length, N is the batch size, E is the embedding dimension.
|
||||
- x: :math:`(L, N, 3 * E//2)` where L is the target sequence length, N is the batch size, E is
|
||||
the embedding dimension. Will be split into (query, key, value).
|
||||
- pos: :math:`(N, 2*L-1, H)` or :math:`(1, 2*L-1, H)` where L is the sequence
|
||||
length, N is the batch size, and H is the number of heads.
|
||||
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
|
||||
If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
|
||||
will be unchanged. If a BoolTensor is provided, the positions with the
|
||||
@ -1008,7 +999,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
H is the num-heads, S is the sequence length.
|
||||
"""
|
||||
|
||||
tgt_len, bsz, _ = x.size()
|
||||
seq_len, bsz, _ = x.size()
|
||||
|
||||
head_dim = embed_dim // (num_heads * 2)
|
||||
assert (
|
||||
@ -1040,15 +1031,15 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
|
||||
if attn_mask.dim() == 2:
|
||||
attn_mask = attn_mask.unsqueeze(0)
|
||||
if list(attn_mask.size()) != [1, tgt_len, tgt_len]:
|
||||
if list(attn_mask.size()) != [1, seq_len, seq_len]:
|
||||
raise RuntimeError(
|
||||
"The size of the 2D attn_mask is not correct."
|
||||
)
|
||||
elif attn_mask.dim() == 3:
|
||||
if list(attn_mask.size()) != [
|
||||
bsz * num_heads,
|
||||
tgt_len,
|
||||
tgt_len,
|
||||
seq_len,
|
||||
seq_len,
|
||||
]:
|
||||
raise RuntimeError(
|
||||
"The size of the 3D attn_mask is not correct."
|
||||
@ -1071,63 +1062,37 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
)
|
||||
key_padding_mask = key_padding_mask.to(torch.bool)
|
||||
|
||||
q = (q * scaling).contiguous().view(tgt_len, bsz, num_heads, head_dim)
|
||||
q = (q * scaling).contiguous().view(seq_len, bsz, num_heads, head_dim)
|
||||
k = k.contiguous().view(-1, bsz, num_heads, head_dim)
|
||||
v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
|
||||
|
||||
src_len = k.size(0)
|
||||
|
||||
if key_padding_mask is not None:
|
||||
assert key_padding_mask.size(0) == bsz, "{} == {}".format(
|
||||
key_padding_mask.size(0), bsz
|
||||
)
|
||||
assert key_padding_mask.size(1) == src_len, "{} == {}".format(
|
||||
key_padding_mask.size(1), src_len
|
||||
assert key_padding_mask.size(1) == seq_len, "{} == {}".format(
|
||||
key_padding_mask.size(1), seq_len
|
||||
)
|
||||
|
||||
q = q.transpose(0, 1) # (batch, time1, head, d_k)
|
||||
|
||||
pos_emb_bsz = pos_emb.size(0)
|
||||
assert pos_emb_bsz in (1, bsz) # actually it is 1
|
||||
p = self.proj_balancer(self.linear_pos(pos_emb)).view(pos_emb_bsz, -1, num_heads, head_dim)
|
||||
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
|
||||
|
||||
q_with_bias_u = (q + self.pos_bias_u).transpose(
|
||||
1, 2
|
||||
) # (batch, head, time1, d_k)
|
||||
|
||||
q_with_bias_v = (q + self.pos_bias_v).transpose(
|
||||
1, 2
|
||||
) # (batch, head, time1, d_k)
|
||||
|
||||
q = q.permute(1, 2, 0, 3) # (batch head, time1, head_dim)
|
||||
# compute attention score
|
||||
# first compute matrix a and matrix c
|
||||
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
|
||||
k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
|
||||
matrix_ac = torch.matmul(
|
||||
q_with_bias_u, k
|
||||
) # (batch, head, time1, time2)
|
||||
|
||||
# compute matrix b and matrix d
|
||||
matrix_bd = torch.matmul(
|
||||
q_with_bias_v, p.transpose(-2, -1)
|
||||
) # (batch, head, time1, 2*time1-1)
|
||||
matrix_bd = self.rel_shift(matrix_bd)
|
||||
|
||||
attn_output_weights = (
|
||||
matrix_ac + matrix_bd
|
||||
) # (batch, head, time1, time2)
|
||||
|
||||
# pos_bias: (batch, head, time1, time2)
|
||||
pos_bias = self.rel_shift(pos)
|
||||
|
||||
attn_output_weights = torch.matmul(q, k) + pos_bias
|
||||
# attn_output_weights: (batch, head, time1, time2)
|
||||
|
||||
attn_output_weights = attn_output_weights.view(
|
||||
bsz * num_heads, tgt_len, -1
|
||||
bsz * num_heads, seq_len, seq_len
|
||||
)
|
||||
|
||||
assert list(attn_output_weights.size()) == [
|
||||
bsz * num_heads,
|
||||
tgt_len,
|
||||
src_len,
|
||||
seq_len,
|
||||
seq_len,
|
||||
]
|
||||
|
||||
if attn_mask is not None:
|
||||
@ -1138,14 +1103,14 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
|
||||
if key_padding_mask is not None:
|
||||
attn_output_weights = attn_output_weights.view(
|
||||
bsz, num_heads, tgt_len, src_len
|
||||
bsz, num_heads, seq_len, seq_len
|
||||
)
|
||||
attn_output_weights = attn_output_weights.masked_fill(
|
||||
key_padding_mask.unsqueeze(1).unsqueeze(2),
|
||||
float("-inf"),
|
||||
)
|
||||
attn_output_weights = attn_output_weights.view(
|
||||
bsz * num_heads, tgt_len, src_len
|
||||
bsz * num_heads, seq_len, seq_len
|
||||
)
|
||||
|
||||
attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
|
||||
@ -1154,11 +1119,11 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
)
|
||||
|
||||
attn_output = torch.bmm(attn_output_weights, v)
|
||||
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
|
||||
assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim]
|
||||
attn_output = (
|
||||
attn_output.transpose(0, 1)
|
||||
.contiguous()
|
||||
.view(tgt_len, bsz, embed_dim // 2)
|
||||
.view(seq_len, bsz, embed_dim // 2)
|
||||
)
|
||||
attn_output = nn.functional.linear(
|
||||
attn_output, out_proj_weight, out_proj_bias
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user