Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Reduce attention_dim to 192; cherry-pick scaled_adam_exp130 which is linear_pos interacting with query
commit 3f495cd197
parent 03fe1ed200
@@ -836,11 +836,15 @@ class RelPositionMultiheadAttention(nn.Module):
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = attention_dim // num_heads
        assert self.head_dim % 2 == 0, self.head_dim
        assert (
            self.head_dim * num_heads == attention_dim
        ), "embed_dim//2 must be divisible by num_heads"
        )

        self.in_proj = nn.Linear(embed_dim, 3 * attention_dim, bias=True)
        # the initial_scale is supposed to take over the "scaling" factor of
        # head_dim ** -0.5, dividing it between the query and key.
        self.in_proj = ScaledLinear(embed_dim, 7 * attention_dim // 2, bias=True,
                                    initial_scale=self.head_dim**-0.25)

        # self.whiten_values is applied on the values in forward()
        self.whiten_values = Whiten(num_groups=num_heads,
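A minimal standalone check (not code from this repo) of the scaling claim in the comment above: applying head_dim ** -0.25 to both the query and key projections at initialization reproduces the usual head_dim ** -0.5 factor in the attention scores.

    import torch

    # Illustrative only: d**-0.25 on q and on k multiplies the scores by d**-0.5.
    head_dim = 24
    q = torch.randn(4, head_dim)
    k = torch.randn(4, head_dim)

    scores_ref = (q * head_dim ** -0.5) @ k.t()
    scores_split = (q * head_dim ** -0.25) @ (k * head_dim ** -0.25).t()
    assert torch.allclose(scores_ref, scores_split, atol=1e-5)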
@@ -854,7 +858,12 @@ class RelPositionMultiheadAttention(nn.Module):
                                    grad_scale=0.025)


        self.in_balancer = ActivationBalancer(3 * attention_dim,

        # linear transformation for positional encoding.
        self.linear_pos = ScaledLinear(embed_dim, attention_dim // 2, bias=False,
                                       initial_scale=0.05)

        self.in_balancer = ActivationBalancer(7 * attention_dim // 2,
                                              channel_dim=-1, max_abs=5.0)
        self.out_proj = ScaledLinear(
            attention_dim, embed_dim, bias=True, initial_scale=0.05
@@ -869,12 +878,6 @@ class RelPositionMultiheadAttention(nn.Module):
                                    prob=(0.025, 0.25),
                                    grad_scale=0.025)

        # linear transformation for positional encoding (projects to a scalar per head,
        # which will be added to the score).
        self.linear_pos = ScaledLinear(embed_dim, num_heads, initial_scale=0.05)

        # linear transformation for positional encoding.
        self.linear_pos = nn.Linear(embed_dim, num_heads, bias=False)

    def forward(
        self,
@@ -936,39 +939,6 @@ class RelPositionMultiheadAttention(nn.Module):
        )
        return x, weights

    def rel_shift(self, pos_bias: Tensor) -> Tensor:
        """Convert relative positional bias from linear to matrix format.

        Args:
            pos_bias: Input tensor (1, 2*T-1, num_heads), where T is the number of frames.

        Returns:
            Tensor of shape (1, num_heads, time1, time2)
            (note: time2 has the same value as time1, but it is for
            the key, while time1 is for the query).
        """
        (batch_size, n, num_heads) = pos_bias.shape
        assert batch_size == 1
        T = (n + 1) // 2
        assert n == 2 * T - 1
        # The leading T dimension behaves like a batch dimension.
        # It is only needed because PyTorch does not currently support
        # negative strides.
        pos_bias = pos_bias.expand(T, n, num_heads).contiguous()

        # Note: TorchScript requires explicit arg for stride()
        batch_stride = pos_bias.stride(0)
        time_stride = pos_bias.stride(1)
        head_stride = pos_bias.stride(2)

        # We could have left the batch dim as 1, and used '-time_stride' below
        # where we use 'batch_stride - time_stride', but PyTorch does not support negative
        # strides.
        return pos_bias.as_strided(
            (1, num_heads, T, T),
            (0, head_stride, batch_stride - time_stride, time_stride),
            storage_offset=time_stride * (T - 1),
        )

    def multi_head_attention_forward(
        self,
@@ -1003,10 +973,10 @@ class RelPositionMultiheadAttention(nn.Module):

        Shape:
            Inputs:
            - x: :math:`(L, N, 3 * A)` where L is the target sequence length, N is the batch size, A is
              the attention dimension. Will be split into (query, key, value).
            - pos: :math:`(N, 2*L-1, H)` or :math:`(1, 2*L-1, H)` where L is the sequence
              length, N is the batch size, and H is the number of heads.
            - x: :math:`(L, N, 7 * A // 2)` where L is the target sequence length, N is the batch size, A is
              the attention dimension. Will be split into (query, key, value, pos).
            - pos: :math:`(N, 2*L-1, A//2)` or :math:`(1, 2*L-1, A//2)` where L is the sequence
              length, N is the batch size, and A is the attention dim.
            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
              If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
              will be unchanged. If a BoolTensor is provided, the positions with the
@@ -1033,11 +1003,13 @@ class RelPositionMultiheadAttention(nn.Module):
            head_dim * num_heads == attention_dim
        ), "attention_dim must be divisible by num_heads"

        scaling = float(head_dim) ** -0.5


        # self-attention
        q, k, v = x.chunk(3, dim=-1)
        q = x[...,:attention_dim]
        k = x[...,attention_dim:2*attention_dim]
        v = x[...,2*attention_dim:3*attention_dim]
        p = x[...,3*attention_dim:]


        k = self.whiten_keys(k) # does nothing in the forward pass.
        v = self.whiten_values(v) # does nothing in the forward pass.
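For reference, a small shape-only sketch (made-up sizes, not repo code) of how the packed projection of width 7 * attention_dim // 2 is sliced into query, key, value and the positional query p, matching the indexing above:

    import torch

    attention_dim, seq_len, bsz = 192, 10, 2
    x = torch.randn(seq_len, bsz, 7 * attention_dim // 2)  # packed in_proj output

    q = x[..., :attention_dim]
    k = x[..., attention_dim:2 * attention_dim]
    v = x[..., 2 * attention_dim:3 * attention_dim]
    p = x[..., 3 * attention_dim:]  # positional query, half the width of q/k/v

    assert q.shape[-1] == k.shape[-1] == v.shape[-1] == attention_dim
    assert p.shape[-1] == attention_dim // 2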
@@ -1091,9 +1063,11 @@ class RelPositionMultiheadAttention(nn.Module):
            )
            key_padding_mask = key_padding_mask.to(torch.bool)

        q = (q * scaling).contiguous().view(seq_len, bsz, num_heads, head_dim)
        k = k.contiguous().view(-1, bsz, num_heads, head_dim)
        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
        q = q.reshape(seq_len, bsz, num_heads, head_dim)
        p = p.reshape(seq_len, bsz, num_heads, head_dim // 2)
        k = k.reshape(seq_len, bsz, num_heads, head_dim)
        v = v.reshape(seq_len, bsz * num_heads, head_dim).transpose(0, 1)


        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz, "{} == {}".format(
@@ -1103,14 +1077,33 @@ class RelPositionMultiheadAttention(nn.Module):
                key_padding_mask.size(1), seq_len
            )

        q = q.permute(1, 2, 0, 3) # (batch head, time1, head_dim)


        q = q.permute(1, 2, 0, 3)  # (batch, head, time1, head_dim)
        p = p.permute(1, 2, 0, 3)  # (batch, head, time1, head_dim // 2)
        # compute attention score
        k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)

        # pos_bias: (batch, head, time1, time2)
        pos_bias = self.rel_shift(pos)

        attn_output_weights = torch.matmul(q, k) + pos_bias
        T2 = 2 * seq_len - 1
        pos = pos.reshape(1, T2, num_heads, head_dim // 2).permute(0, 2, 3, 1)
        # pos shape now: (batch, head, head_dim//2, T2)

        # (batch, head, time1, head_dim // 2) x (1, head, head_dim//2, T2) -> (batch, head, time1, T2)
        # [where T2 represents relative position.]
        pos_weights = torch.matmul(p, pos)
        # the following .as_strided() expression converts the last axis of pos_weights from relative
        # to absolute position.  I don't know whether I might have got the time-offsets backwards or
        # not, but let this code define which way round it is supposed to be.
        pos_weights = pos_weights.as_strided((bsz, num_heads, seq_len, seq_len),
                                             (pos_weights.stride(0),
                                              pos_weights.stride(1),
                                              pos_weights.stride(2)-pos_weights.stride(3),
                                              pos_weights.stride(3)),
                                             storage_offset=pos_weights.stride(3) * (seq_len - 1))


        attn_output_weights = torch.matmul(q, k) + pos_weights
        # attn_output_weights: (batch, head, time1, time2)

        attn_output_weights = attn_output_weights.view(
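The .as_strided() expression above can be checked in isolation. A self-contained sketch (made-up sizes, not repo code) comparing the strided view against a naive per-row slice, under the convention that the last axis of pos_weights indexes relative offset (j - i) + (T - 1):

    import torch

    B, H, T = 1, 2, 5
    pw = torch.randn(B, H, T, 2 * T - 1)  # (batch, head, time1, relative position)

    s0, s1, s2, s3 = pw.stride()
    out = pw.as_strided((B, H, T, T), (s0, s1, s2 - s3, s3),
                        storage_offset=s3 * (T - 1))

    # naive reference: out[b, h, i, j] == pw[b, h, i, (T - 1) + j - i]
    ref = torch.stack(
        [pw[:, :, i, (T - 1) - i:(2 * T - 1) - i] for i in range(T)], dim=2
    )
    assert torch.equal(out, ref)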
@@ -1180,7 +1173,7 @@ class RelPositionMultiheadAttention(nn.Module):
        # v: (tgt_len, bsz, embed_dim // 2)
        v = self.in_proj2(x)
        v = self.whiten_values2(v) # does nothing in the forward pass.
        v = v.contiguous().view(seq_len, bsz * num_heads, head_dim).transpose(0, 1)
        v = v.reshape(seq_len, bsz * num_heads, head_dim).transpose(0, 1)

        # now v: (bsz * num_heads, seq_len, head_dim)
        attn_output = torch.bmm(attn_weights, v)
@@ -1450,7 +1443,7 @@ class Conv2dSubsampling(nn.Module):
        x = self.conv(x)
        # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        x = self.out(x.transpose(1, 2).reshape(b, t, c * f))
        # Now x is of shape (N, ((T-1)//2 - 1)//2, odim)
        x = self.dropout(x)
        return x

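As a quick sanity check of the shape comments above (my arithmetic, not from the diff): assuming T = 100 input frames and idim = 80 features, the two stride-2 convolutions give ((100 - 1)//2 - 1)//2 = 24 output frames and ((80 - 1)//2 - 1)//2 = 19 frequency bins, so the tensor entering self.out has shape (N, 24, odim * 19) after the transpose and reshape.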
@@ -120,7 +120,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--attention-dims",
        type=str,
        default="256,256",
        default="192,192",
        help="Attention dimension in the 2 blocks of conformer encoder layers, comma separated"
    )

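Since --attention-dims is passed as a string, the per-block values are presumably recovered by splitting on commas; an illustrative one-liner (not the repo's actual parsing code):

    attention_dims = [int(d) for d in "192,192".split(",")]
    assert attention_dims == [192, 192]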