Reduce attention_dim to 192; cherry-pick scaled_adam_exp130, in which linear_pos interacts with the query

Daniel Povey 2022-10-17 19:28:28 +08:00
parent 03fe1ed200
commit 3f495cd197
2 changed files with 52 additions and 59 deletions


@@ -836,11 +836,15 @@ class RelPositionMultiheadAttention(nn.Module):
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = attention_dim // num_heads
assert self.head_dim % 2 == 0, self.head_dim
assert (
self.head_dim * num_heads == attention_dim
), "embed_dim//2 must be divisible by num_heads"
)
self.in_proj = nn.Linear(embed_dim, 3 * attention_dim, bias=True)
# the initial_scale is supposed to take over the "scaling" factor of
# head_dim ** -0.5, dividing it between the query and key.
self.in_proj = ScaledLinear(embed_dim, 7 * attention_dim // 2, bias=True,
initial_scale=self.head_dim**-0.25)
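To spell out the comment above: with initial_scale = head_dim ** -0.25 applied to the shared q/k/v/p projection, the query and key slices are each shrunk by that factor, so their dot products already carry the usual attention scaling,
    (q * head_dim ** -0.25) . (k * head_dim ** -0.25) == (q . k) * head_dim ** -0.5,
which is why the explicit q * scaling multiplication in multi_head_attention_forward is dropped further down in this diff.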
# self.whiten_values is applied on the values in forward()
self.whiten_values = Whiten(num_groups=num_heads,
@@ -854,7 +858,12 @@ class RelPositionMultiheadAttention(nn.Module):
grad_scale=0.025)
self.in_balancer = ActivationBalancer(3 * attention_dim,
# linear transformation for positional encoding.
self.linear_pos = ScaledLinear(embed_dim, attention_dim // 2, bias=False,
initial_scale=0.05)
self.in_balancer = ActivationBalancer(7 * attention_dim // 2,
channel_dim=-1, max_abs=5.0)
self.out_proj = ScaledLinear(
attention_dim, embed_dim, bias=True, initial_scale=0.05
@@ -869,12 +878,6 @@ class RelPositionMultiheadAttention(nn.Module):
prob=(0.025, 0.25),
grad_scale=0.025)
# linear transformation for positional encoding (projects to a scalar per head,
# which will be added to the score).
self.linear_pos = ScaledLinear(embed_dim, num_heads, initial_scale=0.05)
# linear transformation for positional encoding.
self.linear_pos = nn.Linear(embed_dim, num_heads, bias=False)
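Putting the removed and added definitions side by side: the old linear_pos mapped the positional encoding to a single scalar per head, added directly to the attention score as a bias, while the new linear_pos (first hunk above) maps it to head_dim // 2 channels per head that are combined with the extra p slice of the input projection, roughly
    old:  score[b, h, i, j] += bias[h, j - i]
    new:  score[b, h, i, j] += sum_d p[b, h, i, d] * pos[h, d, j - i]
which is the sense in which linear_pos now interacts with the query, as the commit message says.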
def forward(
self,
@@ -936,39 +939,6 @@ class RelPositionMultiheadAttention(nn.Module):
)
return x, weights
def rel_shift(self, pos_bias: Tensor) -> Tensor:
"""Convert relative positional bias from linear to matrix format.
Args:
pos_bias: Input tensor (1, 2*T-1, num_heads), where T is the number of frames.
Returns:
Tensor of shape (1, num_heads, time1, time2)
(note: time2 has the same value as time1, but it is for
the key, while time1 is for the query).
"""
(batch_size, n, num_heads) = pos_bias.shape
assert batch_size == 1
T = (n + 1) // 2
assert n == 2 * T - 1
# The leading T dimension behaves like a batch dimension.
# It is only needed because PyTorch does not currently support
# negative strides.
pos_bias = pos_bias.expand(T, n, num_heads).contiguous()
# Note: TorchScript requires explicit arg for stride()
batch_stride = pos_bias.stride(0)
time_stride = pos_bias.stride(1)
head_stride = pos_bias.stride(2)
# We could have left the batch dim as 1, and used '-time_stride' below
# where we use 'batch_stride - time_stride', but PyTorch does not support negative
# strides.
return pos_bias.as_strided(
(1, num_heads, T, T),
(0, head_stride, batch_stride - time_stride, time_stride),
storage_offset=time_stride * (T - 1),
)
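In explicit indexing terms, the rel_shift removed here computed
    out[0, h, i, j] == pos_bias[0, (j - i) + (T - 1), h],
i.e. it re-indexed the relative-offset axis into a (query time, key time) pair; the as_strided expression added later in this diff performs the same relative-to-absolute re-indexing on the new pos_weights tensor.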
def multi_head_attention_forward(
self,
@@ -1003,10 +973,10 @@ class RelPositionMultiheadAttention(nn.Module):
Shape:
Inputs:
- x: :math:`(L, N, 3 * A)` where L is the target sequence length, N is the batch size, A is
the attention dimension. Will be split into (query, key, value).
- pos: :math:`(N, 2*L-1, H)` or :math:`(1, 2*L-1, H)` where L is the sequence
length, N is the batch size, and H is the number of heads.
- x: :math:`(L, N, 7 * A // 2)` where L is the target sequence length, N is the batch size, A is
the attention dimension. Will be split into (query, key, value, pos).
- pos: :math:`(N, 2*L-1, A//2)` or :math:`(1, 2*L-1, A//2)` where L is the sequence
length, N is the batch size, and A is the attention dim.
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
will be unchanged. If a BoolTensor is provided, the positions with the
@@ -1033,11 +1003,13 @@ class RelPositionMultiheadAttention(nn.Module):
head_dim * num_heads == attention_dim
), "attention_dim must be divisible by num_heads"
scaling = float(head_dim) ** -0.5
# self-attention
q, k, v = x.chunk(3, dim=-1)
q = x[...,:attention_dim]
k = x[...,attention_dim:2*attention_dim]
v = x[...,2*attention_dim:3*attention_dim]
p = x[...,3*attention_dim:]
k = self.whiten_keys(k) # does nothing in the forward pass.
v = self.whiten_values(v) # does nothing in the forward pass.
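For reference, the four slices above account for the whole 7 * attention_dim // 2 projection: q, k and v take attention_dim channels each and p takes the remaining attention_dim // 2, since 3 * attention_dim + attention_dim // 2 == 7 * attention_dim // 2; with the new default attention_dim = 192 that is 192 + 192 + 192 + 96 = 672 channels.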
@@ -1091,9 +1063,11 @@ class RelPositionMultiheadAttention(nn.Module):
)
key_padding_mask = key_padding_mask.to(torch.bool)
q = (q * scaling).contiguous().view(seq_len, bsz, num_heads, head_dim)
k = k.contiguous().view(-1, bsz, num_heads, head_dim)
v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
q = q.reshape(seq_len, bsz, num_heads, head_dim)
p = p.reshape(seq_len, bsz, num_heads, head_dim // 2)
k = k.reshape(seq_len, bsz, num_heads, head_dim)
v = v.reshape(seq_len, bsz * num_heads, head_dim).transpose(0, 1)
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz, "{} == {}".format(
@@ -1103,14 +1077,33 @@ class RelPositionMultiheadAttention(nn.Module):
key_padding_mask.size(1), seq_len
)
q = q.permute(1, 2, 0, 3) # (batch head, time1, head_dim)
q = q.permute(1, 2, 0, 3) # (batch, head, time1, head_dim)
p = p.permute(1, 2, 0, 3) # (batch, head, time1, head_dim // 2)
# compute attention score
k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
# pos_bias: (batch, head, time1, time2)
pos_bias = self.rel_shift(pos)
attn_output_weights = torch.matmul(q, k) + pos_bias
T2 = 2 * seq_len - 1
pos = pos.reshape(1, T2, num_heads, head_dim // 2).permute(0, 2, 3, 1)
# pos shape now: (batch, head, head_dim//2, T2)
# (batch, head, time1, head_dim // 2) x (1, head, head_dim//2, T2) -> (batch, head, time1, T2)
# [where T2 represents relative position.]
pos_weights = torch.matmul(p, pos)
# the following .as_strided() expression converts the last axis of pos_weights from relative
# to absolute position. I don't know whether I might have got the time-offsets backwards or
# not, but let this code define which way round it is supposed to be.
pos_weights = pos_weights.as_strided((bsz, num_heads, seq_len, seq_len),
(pos_weights.stride(0),
pos_weights.stride(1),
pos_weights.stride(2)-pos_weights.stride(3),
pos_weights.stride(3)),
storage_offset=pos_weights.stride(3) * (seq_len - 1))
attn_output_weights = torch.matmul(q, k) + pos_weights
# attn_output_weights: (batch, head, time1, time2)
attn_output_weights = attn_output_weights.view(
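Below is a minimal, self-contained sketch (toy sizes and a random tensor, chosen only for the check) of what the as_strided expression above computes, comparing the strided view against explicit indexing of the relative-position axis:

import torch

bsz, num_heads, seq_len = 2, 4, 5
pos_weights = torch.randn(bsz, num_heads, seq_len, 2 * seq_len - 1)

# same as_strided expression as in the diff above
strided = pos_weights.as_strided(
    (bsz, num_heads, seq_len, seq_len),
    (pos_weights.stride(0),
     pos_weights.stride(1),
     pos_weights.stride(2) - pos_weights.stride(3),
     pos_weights.stride(3)),
    storage_offset=pos_weights.stride(3) * (seq_len - 1))

# explicit version: entry (i, j) reads relative offset (j - i) + (seq_len - 1)
explicit = torch.empty(bsz, num_heads, seq_len, seq_len)
for i in range(seq_len):
    for j in range(seq_len):
        explicit[:, :, i, j] = pos_weights[:, :, i, (j - i) + (seq_len - 1)]

assert torch.equal(strided, explicit)

So entry (i, j) of the converted tensor reads relative offset (j - i) shifted by seq_len - 1, which is the convention the 'time-offsets' comment above says this code is meant to define.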
@@ -1180,7 +1173,7 @@ class RelPositionMultiheadAttention(nn.Module):
# v: (tgt_len, bsz, embed_dim // 2)
v = self.in_proj2(x)
v = self.whiten_values2(v) # does nothing in the forward pass.
v = v.contiguous().view(seq_len, bsz * num_heads, head_dim).transpose(0, 1)
v = v.reshape(seq_len, bsz * num_heads, head_dim).transpose(0, 1)
# now v: (bsz * num_heads, seq_len, head_dim)
attn_output = torch.bmm(attn_weights, v)
@@ -1450,7 +1443,7 @@ class Conv2dSubsampling(nn.Module):
x = self.conv(x)
# Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
x = self.out(x.transpose(1, 2).reshape(b, t, c * f))
# Now x is of shape (N, ((T-1)//2 - 1)//2, odim)
x = self.dropout(x)
return x


@@ -120,7 +120,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--attention-dims",
type=str,
default="256,256",
default="192,192",
help="Attention dimension in the 2 blocks of conformer encoder layers, comma separated"
)
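With this change the default attention dimension is 192 in both encoder blocks; the previous behaviour can still be requested explicitly, e.g. (the script name is assumed here, since the file path is not shown in this view):

./train.py --attention-dims "256,256"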