Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00
Reduce dim of linear positional encoding in attention layers.
parent 96ea4cf1be
commit 435d0dec71
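This change shrinks the per-head positional query in RelPositionMultiheadAttention from head_dim // 2 dimensions to a new pos_dim (default 4), and correspondingly shrinks linear_pos's output from attention_dim // 2 to num_heads * pos_dim. A rough dimension count, assuming attention_dim = 512 and num_heads = 8 purely for illustration (these sizes are not stated in the diff):

    attention_dim, num_heads, pos_dim = 512, 8, 4   # 512 and 8 are assumed
    head_dim = attention_dim // num_heads           # 64
    old_pos_dims = attention_dim // 2               # 256 total, head_dim // 2 = 32 per head
    new_pos_dims = num_heads * pos_dim              # 32 total, 4 per head
    old_in_proj_out = 3 * attention_dim             # 1536: q, k, and a packed (p, v) pair
    new_in_proj_out = (2 * attention_dim            # q, k
                       + attention_dim // 2         # v (half-dim, unchanged)
                       + num_heads * pos_dim)       # p; total 1312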
@@ -71,6 +71,7 @@ class Zipformer(EncoderInterface):
         num_encoder_layers: Tuple[int] = (12, 12),
         dropout: float = 0.1,
         cnn_module_kernels: Tuple[int] = (31, 31),
+        pos_dim: int = 4,
         warmup_batches: float = 4000.0,
     ) -> None:
         super(Zipformer, self).__init__()
@@ -107,6 +108,7 @@ class Zipformer(EncoderInterface):
                 feedforward_dim[i],
                 dropout,
                 cnn_module_kernels[i],
+                pos_dim,
             )

         # For the segment of the warmup period, we let the Conv2dSubsampling
@@ -263,13 +265,14 @@ class ZipformerEncoderLayer(nn.Module):
         feedforward_dim: int = 2048,
         dropout: float = 0.1,
         cnn_module_kernel: int = 31,
+        pos_dim: int = 4,
     ) -> None:
         super(ZipformerEncoderLayer, self).__init__()

         self.d_model = d_model

         self.self_attn = RelPositionMultiheadAttention(
-            d_model, attention_dim, nhead, dropout=0.0,
+            d_model, attention_dim, nhead, pos_dim, dropout=0.0,
         )

         self.feed_forward1 = FeedforwardModule(d_model,
@@ -912,6 +915,7 @@ class RelPositionMultiheadAttention(nn.Module):
         embed_dim: int,
         attention_dim: int,
         num_heads: int,
+        pos_dim: int,
         dropout: float = 0.0,
     ) -> None:
         super(RelPositionMultiheadAttention, self).__init__()
@@ -920,6 +924,7 @@ class RelPositionMultiheadAttention(nn.Module):
         self.num_heads = num_heads
         self.dropout = dropout
         self.head_dim = attention_dim // num_heads
+        self.pos_dim = pos_dim
         assert self.head_dim % 2 == 0, self.head_dim
         assert (
             self.head_dim * num_heads == attention_dim
@@ -927,27 +932,31 @@ class RelPositionMultiheadAttention(nn.Module):

         # the initial_scale is supposed to take over the "scaling" factor of
         # head_dim ** -0.5, dividing it between the query and key.
-        self.in_proj = ScaledLinear(embed_dim, 3 * attention_dim, bias=True,
+        in_proj_dim = (2 * attention_dim +  # query, key
+                       attention_dim // 2 +  # value
+                       pos_dim * num_heads)  # positional encoding query
+
+        self.in_proj = ScaledLinear(embed_dim, in_proj_dim, bias=True,
                                     initial_scale=self.head_dim**-0.25)

-        # self.whiten_values is applied on the values in forward()
+        # self.whiten_values is applied on the values in forward();
+        # it just copies the keys but prevents low-rank distribution by modifying grads.
         self.whiten_values = Whiten(num_groups=num_heads,
                                     whitening_limit=2.0,
                                     prob=(0.025, 0.25),
                                     grad_scale=0.025)
         # self.whiten_keys is applied on the keys in forward()
         self.whiten_keys = Whiten(num_groups=num_heads,
                                   whitening_limit=2.0,
                                   prob=(0.025, 0.25),
                                   grad_scale=0.025)
-
-

         # linear transformation for positional encoding.
-        self.linear_pos = ScaledLinear(embed_dim, attention_dim // 2, bias=False,
+        self.linear_pos = ScaledLinear(embed_dim, num_heads * pos_dim, bias=False,
                                        initial_scale=0.05)

-        # the following are for diagnosics only, see --print-diagnostics option
+        # the following are for diagnosics only, see --print-diagnostics option.
+        # they only copy their inputs.
         self.copy_pos_query = Identity()
         self.copy_query = Identity()

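For a sense of the resulting shapes, a minimal sketch using plain nn.Linear as a stand-in for icefall's ScaledLinear (same input/output geometry); embed_dim = 512 here is assumed, not taken from the diff:

    import torch
    import torch.nn as nn

    embed_dim, attention_dim, num_heads, pos_dim = 512, 512, 8, 4  # assumed sizes
    in_proj_dim = (2 * attention_dim        # query, key
                   + attention_dim // 2     # value
                   + pos_dim * num_heads)   # positional-encoding query
    in_proj = nn.Linear(embed_dim, in_proj_dim, bias=True)              # 512 -> 1312
    linear_pos = nn.Linear(embed_dim, num_heads * pos_dim, bias=False)  # 512 -> 32
    x_proj = in_proj(torch.randn(100, 2, embed_dim))  # (seq_len, batch, 1312)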
@@ -1014,8 +1023,6 @@ class RelPositionMultiheadAttention(nn.Module):
             self.linear_pos(pos_emb),
-            self.attention_dim,
-            self.num_heads,
             self.in_proj.weight,
             self.in_proj.bias,
             self.dropout,
             self.out_proj.weight,
             self.out_proj.bias,
@@ -1028,12 +1035,10 @@ class RelPositionMultiheadAttention(nn.Module):

     def multi_head_attention_forward(
         self,
-        x: Tensor,
+        x_proj: Tensor,
         pos: Tensor,
-        attention_dim: int,
-        num_heads: int,
         in_proj_weight: Tensor,
         in_proj_bias: Tensor,
         dropout_p: float,
         out_proj_weight: Tensor,
         out_proj_bias: Tensor,
@@ -1047,7 +1052,6 @@ class RelPositionMultiheadAttention(nn.Module):
             pos: head-specific biases arising from the positional embeddings.
-            attention_dim: dimension inside attention mechanism
             num_heads: parallel attention heads.
             in_proj_weight, in_proj_bias: input projection weight and bias.
             dropout_p: probability of an element to be zeroed.
             out_proj_weight, out_proj_bias: the output projection weight and bias.
             training: apply dropout if is ``True``.
@@ -1082,17 +1086,23 @@ class RelPositionMultiheadAttention(nn.Module):
         H is the num-heads, S is the sequence length.
         """

-        seq_len, bsz, _ = x.size()
+        seq_len, bsz, _ = x_proj.size()

         head_dim = attention_dim // num_heads
+        pos_dim = self.pos_dim  # positional-encoding dim per head
         assert (
             head_dim * num_heads == attention_dim
         ), "attention_dim must be divisible by num_heads"

+
         # self-attention
-        q, k, pv = x.chunk(3, dim=-1)
-        p, v = pv.chunk(2, dim=-1)
+        q = x_proj[...,0:attention_dim]
+        k = x_proj[...,attention_dim:2*attention_dim]
+        value_dim = attention_dim // 2
+        v = x_proj[...,2*attention_dim:2*attention_dim+value_dim]
+        # p is the position-encoding query, its dimension is num_heads*pos_dim..
+        p = x_proj[...,2*attention_dim+value_dim:]


         k = self.whiten_keys(k)  # does nothing in the forward pass.
         v = self.whiten_values(v)  # does nothing in the forward pass.
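The four slices above partition the last axis of x_proj exactly, in the same order the new in_proj_dim sum is built: full-dimension q and k, half-dimension v, then the small position query p. A quick check with the same assumed sizes:

    import torch

    attention_dim, num_heads, pos_dim = 512, 8, 4  # assumed sizes, as above
    value_dim = attention_dim // 2
    in_proj_dim = 2 * attention_dim + value_dim + pos_dim * num_heads
    x_proj = torch.randn(100, 2, in_proj_dim)       # (seq_len, batch, 1312)
    q = x_proj[..., 0:attention_dim]                                  # (100, 2, 512)
    k = x_proj[..., attention_dim:2 * attention_dim]                  # (100, 2, 512)
    v = x_proj[..., 2 * attention_dim:2 * attention_dim + value_dim]  # (100, 2, 256)
    p = x_proj[..., 2 * attention_dim + value_dim:]                   # (100, 2, 32)
    assert q.shape[-1] + k.shape[-1] + v.shape[-1] + p.shape[-1] == in_proj_dim
    assert p.shape[-1] == num_heads * pos_dim  # later reshaped to (seq, batch, heads, pos_dim)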
@@ -1150,7 +1160,7 @@ class RelPositionMultiheadAttention(nn.Module):
             key_padding_mask = key_padding_mask.to(torch.bool)

         q = q.reshape(seq_len, bsz, num_heads, head_dim)
-        p = p.reshape(seq_len, bsz, num_heads, head_dim // 2)
+        p = p.reshape(seq_len, bsz, num_heads, pos_dim)
         k = k.reshape(seq_len, bsz, num_heads, head_dim)
         v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1)

@@ -1166,16 +1176,16 @@ class RelPositionMultiheadAttention(nn.Module):


         q = q.permute(1, 2, 0, 3)  # (batch, head, time1, head_dim)
-        p = p.permute(1, 2, 0, 3)  # (batch, head, time1, head_dim // 2)
+        p = p.permute(1, 2, 0, 3)  # (batch, head, time1, pos_dim)
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)


-        T2 = 2 * seq_len - 1
-        pos = pos.reshape(1, T2, num_heads, head_dim // 2).permute(0, 2, 3, 1)
-        # pos shape now: (batch, head, head_dim//2, T2)
+        seq_len2 = 2 * seq_len - 1
+        pos = pos.reshape(1, seq_len2, num_heads, pos_dim).permute(0, 2, 3, 1)
+        # pos shape now: (batch, head, pos_dim, seq_len2)

-        # (batch, head, time1, head_dim // 2) x (1, head, head_dim//2, T2) -> (batch, head, time1, T2)
-        # [where T2 represents relative position.]
+        # (batch, head, time1, pos_dim) x (1, head, pos_dim, seq_len2) -> (batch, head, time1, seq_len2)
+        # [where seq_len2 represents relative position.]
         pos_weights = torch.matmul(p, pos)
         # the following .as_strided() expression converts the last axis of pos_weights from relative
         # to absolute position. I don't know whether I might have got the time-offsets backwards or
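The .as_strided() expression itself falls outside this hunk. As a toy illustration of the relative-to-absolute shift the comment describes (my own sketch, not code from the commit), output entry (i, j) reads input slot j - i + seq_len - 1, so each row's relative offsets line up against absolute positions:

    import torch

    T = 3                # seq_len
    T2 = 2 * T - 1       # number of relative positions, as in pos_weights
    # x[0, 0, i, r] stores the relative offset r - (T - 1) held in slot r
    x = (torch.arange(T2) - (T - 1)).float().expand(1, 1, T, T2).contiguous()
    y = x.as_strided(
        (1, 1, T, T),
        (x.stride(0), x.stride(1), x.stride(2) - x.stride(3), x.stride(3)),
        storage_offset=x.stride(3) * (T - 1),
    )
    print(y[0, 0])
    # tensor([[ 0.,  1.,  2.],
    #         [-1.,  0.,  1.],
    #         [-2., -1.,  0.]])
    # i.e. y[..., i, j] == x[..., i, j - i + T - 1]: entry (i, j) now holds offset j - i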
@@ -1243,7 +1253,8 @@ class RelPositionMultiheadAttention(nn.Module):
         )

         attn_output = torch.bmm(attn_output_weights, v)
-        assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim // 2]
+        assert list(attn_output.size()) == [bsz * num_heads, seq_len,
+                                            head_dim // 2]
         attn_output = (
             attn_output.transpose(0, 1)
             .contiguous()
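A quick shape check of the bmm and the rewrapped assert, with illustrative sizes (bsz = 2, num_heads = 8, seq_len = 100, head_dim = 64, all assumed): because v was sliced to head_dim // 2 per head, the attention output is half-dimensional as well, before out_proj maps it back, presumably to the embedding dimension.

    import torch

    bsz, num_heads, seq_len, head_dim = 2, 8, 100, 64  # assumed sizes
    attn_output_weights = torch.softmax(
        torch.randn(bsz * num_heads, seq_len, seq_len), dim=-1
    )
    v = torch.randn(bsz * num_heads, seq_len, head_dim // 2)
    attn_output = torch.bmm(attn_output_weights, v)
    assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim // 2]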