Reduce dim of linear positional encoding in attention layers.

Daniel Povey 2022-10-29 15:31:34 +08:00
parent 96ea4cf1be
commit 435d0dec71

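For orientation before the diff: the net effect on the size of the attention module's single input projection (in_proj) can be sketched as below. attention_dim = 192 and num_heads = 8 are assumed example values; only pos_dim = 4 is a default actually added in this commit.

attention_dim, num_heads, pos_dim = 192, 8, 4   # attention_dim/num_heads assumed; pos_dim is the new default

# before: q, k and a combined (p + v) chunk, i.e. p and v each attention_dim // 2 wide
old_in_proj_dim = 3 * attention_dim                     # 576 with these values

# after: the positional-encoding query p shrinks from attention_dim // 2
# (head_dim // 2 per head) down to pos_dim per head
new_in_proj_dim = (2 * attention_dim       # query, key
                   + attention_dim // 2    # value
                   + pos_dim * num_heads)  # positional-encoding query
print(old_in_proj_dim, new_in_proj_dim)                 # 576 512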

@@ -71,6 +71,7 @@ class Zipformer(EncoderInterface):
num_encoder_layers: Tuple[int] = (12, 12),
dropout: float = 0.1,
cnn_module_kernels: Tuple[int] = (31, 31),
pos_dim: int = 4,
warmup_batches: float = 4000.0,
) -> None:
super(Zipformer, self).__init__()
@@ -107,6 +108,7 @@ class Zipformer(EncoderInterface):
feedforward_dim[i],
dropout,
cnn_module_kernels[i],
pos_dim,
)
# For the segment of the warmup period, we let the Conv2dSubsampling
@@ -263,13 +265,14 @@ class ZipformerEncoderLayer(nn.Module):
feedforward_dim: int = 2048,
dropout: float = 0.1,
cnn_module_kernel: int = 31,
pos_dim: int = 4,
) -> None:
super(ZipformerEncoderLayer, self).__init__()
self.d_model = d_model
self.self_attn = RelPositionMultiheadAttention(
d_model, attention_dim, nhead, dropout=0.0,
d_model, attention_dim, nhead, pos_dim, dropout=0.0,
)
self.feed_forward1 = FeedforwardModule(d_model,
@@ -912,6 +915,7 @@ class RelPositionMultiheadAttention(nn.Module):
embed_dim: int,
attention_dim: int,
num_heads: int,
pos_dim: int,
dropout: float = 0.0,
) -> None:
super(RelPositionMultiheadAttention, self).__init__()
@@ -920,6 +924,7 @@ class RelPositionMultiheadAttention(nn.Module):
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = attention_dim // num_heads
self.pos_dim = pos_dim
assert self.head_dim % 2 == 0, self.head_dim
assert (
self.head_dim * num_heads == attention_dim
@@ -927,27 +932,31 @@ class RelPositionMultiheadAttention(nn.Module):
# the initial_scale is supposed to take over the "scaling" factor of
# head_dim ** -0.5, dividing it between the query and key.
self.in_proj = ScaledLinear(embed_dim, 3 * attention_dim, bias=True,
in_proj_dim = (2 * attention_dim + # query, key
attention_dim // 2 + # value
pos_dim * num_heads) # positional encoding query
self.in_proj = ScaledLinear(embed_dim, in_proj_dim, bias=True,
initial_scale=self.head_dim**-0.25)
# self.whiten_values is applied on the values in forward()
# self.whiten_values is applied on the values in forward();
# it just copies the values but prevents low-rank distribution by modifying grads.
self.whiten_values = Whiten(num_groups=num_heads,
whitening_limit=2.0,
prob=(0.025, 0.25),
grad_scale=0.025)
# self.whiten_keys is applied on the keys in forward()
self.whiten_keys = Whiten(num_groups=num_heads,
whitening_limit=2.0,
prob=(0.025, 0.25),
grad_scale=0.025)
# linear transformation for positional encoding.
self.linear_pos = ScaledLinear(embed_dim, attention_dim // 2, bias=False,
self.linear_pos = ScaledLinear(embed_dim, num_heads * pos_dim, bias=False,
initial_scale=0.05)
# the following are for diagnostics only, see --print-diagnostics option
# the following are for diagnostics only, see --print-diagnostics option.
# they only copy their inputs.
self.copy_pos_query = Identity()
self.copy_query = Identity()
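The same reduction applies to linear_pos, which now projects the positional embeddings to num_heads * pos_dim features instead of attention_dim // 2. A minimal shape sketch, using nn.Linear as a stand-in for ScaledLinear and assuming embed_dim = 384, attention_dim = 192, num_heads = 8 (only pos_dim = 4 is taken from this diff):

import torch
import torch.nn as nn

embed_dim, attention_dim, num_heads, pos_dim = 384, 192, 8, 4   # first three assumed
seq_len = 100
pos_emb = torch.randn(1, 2 * seq_len - 1, embed_dim)   # (1, 2*seq_len - 1, embed_dim)

old_linear_pos = nn.Linear(embed_dim, attention_dim // 2, bias=False)   # stand-in for ScaledLinear
new_linear_pos = nn.Linear(embed_dim, num_heads * pos_dim, bias=False)

print(old_linear_pos(pos_emb).shape)   # torch.Size([1, 199, 96])
print(new_linear_pos(pos_emb).shape)   # torch.Size([1, 199, 32])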
@@ -1014,8 +1023,6 @@ class RelPositionMultiheadAttention(nn.Module):
self.linear_pos(pos_emb),
self.attention_dim,
self.num_heads,
self.in_proj.weight,
self.in_proj.bias,
self.dropout,
self.out_proj.weight,
self.out_proj.bias,
@@ -1028,12 +1035,10 @@ class RelPositionMultiheadAttention(nn.Module):
def multi_head_attention_forward(
self,
x: Tensor,
x_proj: Tensor,
pos: Tensor,
attention_dim: int,
num_heads: int,
in_proj_weight: Tensor,
in_proj_bias: Tensor,
dropout_p: float,
out_proj_weight: Tensor,
out_proj_bias: Tensor,
@@ -1047,7 +1052,6 @@ class RelPositionMultiheadAttention(nn.Module):
pos: head-specific biases arising from the positional embeddings.
attention_dim: dimension inside attention mechanism
num_heads: parallel attention heads.
in_proj_weight, in_proj_bias: input projection weight and bias.
dropout_p: probability of an element to be zeroed.
out_proj_weight, out_proj_bias: the output projection weight and bias.
training: apply dropout if is ``True``.
@@ -1082,17 +1086,23 @@ class RelPositionMultiheadAttention(nn.Module):
H is the num-heads, S is the sequence length.
"""
seq_len, bsz, _ = x.size()
seq_len, bsz, _ = x_proj.size()
head_dim = attention_dim // num_heads
pos_dim = self.pos_dim # positional-encoding dim per head
assert (
head_dim * num_heads == attention_dim
), "attention_dim must be divisible by num_heads"
# self-attention
q, k, pv = x.chunk(3, dim=-1)
p, v = pv.chunk(2, dim=-1)
q = x_proj[...,0:attention_dim]
k = x_proj[...,attention_dim:2*attention_dim]
value_dim = attention_dim // 2
v = x_proj[...,2*attention_dim:2*attention_dim+value_dim]
# p is the position-encoding query; its dimension is num_heads * pos_dim.
p = x_proj[...,2*attention_dim+value_dim:]
k = self.whiten_keys(k) # does nothing in the forward pass.
v = self.whiten_values(v) # does nothing in the forward pass.
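The slicing above partitions x_proj exactly; an equivalent sketch using torch.split, with assumed example sizes (only pos_dim = 4 comes from this diff):

import torch

attention_dim, num_heads, pos_dim = 192, 8, 4     # attention_dim/num_heads assumed
seq_len, bsz = 100, 2
in_proj_dim = 2 * attention_dim + attention_dim // 2 + pos_dim * num_heads

x_proj = torch.randn(seq_len, bsz, in_proj_dim)
q, k, v, p = torch.split(
    x_proj,
    [attention_dim, attention_dim, attention_dim // 2, pos_dim * num_heads],
    dim=-1,
)
print(q.shape[-1], k.shape[-1], v.shape[-1], p.shape[-1])   # 192 192 96 32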
@@ -1150,7 +1160,7 @@ class RelPositionMultiheadAttention(nn.Module):
key_padding_mask = key_padding_mask.to(torch.bool)
q = q.reshape(seq_len, bsz, num_heads, head_dim)
p = p.reshape(seq_len, bsz, num_heads, head_dim // 2)
p = p.reshape(seq_len, bsz, num_heads, pos_dim)
k = k.reshape(seq_len, bsz, num_heads, head_dim)
v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1)
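The per-head reshape of p now uses pos_dim instead of head_dim // 2; a quick check with assumed sizes:

import torch

seq_len, bsz, num_heads, pos_dim = 100, 2, 8, 4       # example values
p = torch.randn(seq_len, bsz, num_heads * pos_dim)    # flat positional query from in_proj
p = p.reshape(seq_len, bsz, num_heads, pos_dim)
print(p.shape)    # torch.Size([100, 2, 8, 4])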
@@ -1166,16 +1176,16 @@ class RelPositionMultiheadAttention(nn.Module):
q = q.permute(1, 2, 0, 3) # (batch, head, time1, head_dim)
p = p.permute(1, 2, 0, 3) # (batch, head, time1, head_dim // 2)
p = p.permute(1, 2, 0, 3) # (batch, head, time1, pos_dim)
k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
T2 = 2 * seq_len - 1
pos = pos.reshape(1, T2, num_heads, head_dim // 2).permute(0, 2, 3, 1)
# pos shape now: (batch, head, head_dim//2, T2)
seq_len2 = 2 * seq_len - 1
pos = pos.reshape(1, seq_len2, num_heads, pos_dim).permute(0, 2, 3, 1)
# pos shape now: (batch, head, pos_dim, seq_len2)
# (batch, head, time1, head_dim // 2) x (1, head, head_dim//2, T2) -> (batch, head, time1, T2)
# [where T2 represents relative position.]
# (batch, head, time1, pos_dim) x (1, head, pos_dim, seq_len2) -> (batch, head, time1, seq_len2)
# [where seq_len2 represents relative position.]
pos_weights = torch.matmul(p, pos)
# the following .as_strided() expression converts the last axis of pos_weights from relative
# to absolute position. I don't know whether I might have got the time-offsets backwards or
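The shape bookkeeping for pos_weights can be checked in isolation; a sketch with assumed sizes (only pos_dim = 4 comes from this diff), where torch.matmul broadcasts the leading batch dimension:

import torch

bsz, num_heads, seq_len, pos_dim = 2, 8, 100, 4    # example values
seq_len2 = 2 * seq_len - 1                         # number of relative positions

p = torch.randn(bsz, num_heads, seq_len, pos_dim)        # per-head positional query
pos = torch.randn(1, num_heads, pos_dim, seq_len2)       # projected positional embeddings

pos_weights = torch.matmul(p, pos)    # (batch, head, time1, seq_len2)
print(pos_weights.shape)              # torch.Size([2, 8, 100, 199])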
@@ -1243,7 +1253,8 @@ class RelPositionMultiheadAttention(nn.Module):
)
attn_output = torch.bmm(attn_output_weights, v)
assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim // 2]
assert list(attn_output.size()) == [bsz * num_heads, seq_len,
head_dim // 2]
attn_output = (
attn_output.transpose(0, 1)
.contiguous()
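The value path is unchanged here apart from re-wrapping the assert: the attention output per head stays at head_dim // 2. A small shape check with assumed sizes (head_dim = attention_dim // num_heads = 24 under the example values used above):

import torch

bsz, num_heads, seq_len, head_dim = 2, 8, 100, 24   # example values
attn_output_weights = torch.randn(bsz * num_heads, seq_len, seq_len).softmax(dim=-1)
v = torch.randn(bsz * num_heads, seq_len, head_dim // 2)

attn_output = torch.bmm(attn_output_weights, v)
print(attn_output.shape)    # torch.Size([16, 100, 12])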