From 3f495cd1972ac63a4919ecd64f329591a1f4a099 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Mon, 17 Oct 2022 19:28:28 +0800
Subject: [PATCH] Reduce attention_dim to 192; cherry-pick scaled_adam_exp130,
 which is linear_pos interacting with query

---
 .../pruned_transducer_stateless7/conformer.py | 109 ++++++++----------
 .../ASR/pruned_transducer_stateless7/train.py |   2 +-
 2 files changed, 52 insertions(+), 59 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index 61e1edc5a..951fb863d 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -836,11 +836,15 @@ class RelPositionMultiheadAttention(nn.Module):
         self.num_heads = num_heads
         self.dropout = dropout
         self.head_dim = attention_dim // num_heads
+        assert self.head_dim % 2 == 0, self.head_dim
         assert (
             self.head_dim * num_heads == attention_dim
-        ), "embed_dim//2 must be divisible by num_heads"
+        )
 
-        self.in_proj = nn.Linear(embed_dim, 3 * attention_dim, bias=True)
+        # the initial_scale is supposed to take over the "scaling" factor of
+        # head_dim ** -0.5, dividing it between the query and key.
+        self.in_proj = ScaledLinear(embed_dim, 7 * attention_dim // 2, bias=True,
+                                    initial_scale=self.head_dim**-0.25)
 
         # self.whiten_values is applied on the values in forward()
         self.whiten_values = Whiten(num_groups=num_heads,
@@ -854,7 +858,12 @@ class RelPositionMultiheadAttention(nn.Module):
                                   grad_scale=0.025)
 
-        self.in_balancer = ActivationBalancer(3 * attention_dim,
+
+        # linear transformation for positional encoding.
+        self.linear_pos = ScaledLinear(embed_dim, attention_dim // 2, bias=False,
+                                       initial_scale=0.05)
+
+        self.in_balancer = ActivationBalancer(7 * attention_dim // 2,
                                               channel_dim=-1, max_abs=5.0)
         self.out_proj = ScaledLinear(
             attention_dim, embed_dim, bias=True, initial_scale=0.05
@@ -869,12 +878,6 @@ class RelPositionMultiheadAttention(nn.Module):
                                    prob=(0.025, 0.25),
                                    grad_scale=0.025)
 
-        # linear transformation for positional encoding (projects to a scalar per head,
-        # which will be added to the score).
-        self.linear_pos = ScaledLinear(embed_dim, num_heads, initial_scale=0.05)
-
-        # linear transformation for positional encoding.
-        self.linear_pos = nn.Linear(embed_dim, num_heads, bias=False)
 
     def forward(
         self,
@@ -936,39 +939,6 @@ class RelPositionMultiheadAttention(nn.Module):
         )
         return x, weights
 
-    def rel_shift(self, pos_bias: Tensor) -> Tensor:
-        """Convert relative positional bias from linear to matrix format.
-
-        Args:
-            pos_bias: Input tensor (1, 2*T-1, num_heads), where T is the number of frames.
-
-        Returns:
-            Tensor of shape (1, num_heads, time1, time2)
-            (note: time2 has the same value as time1, but it is for
-            the key, while time1 is for the query).
-        """
-        (batch_size, n, num_heads) = pos_bias.shape
-        assert batch_size == 1
-        T = (n + 1) // 2
-        assert n == 2 * T - 1
-        # The leading T dimension behaves like a batch dimension.
-        # It is only needed because PyTorch does not currently support
-        # negative strides.
-        pos_bias = pos_bias.expand(T, n, num_heads).contiguous()
-
-        # Note: TorchScript requires explicit arg for stride()
-        batch_stride = pos_bias.stride(0)
-        time_stride = pos_bias.stride(1)
-        head_stride = pos_bias.stride(2)
-
-        # We could have left the batch dim as 1, and used '-time_stride' below
-        # where we use 'batch_stride - time_stride', but PyTorch does not support negative
-        # strides.
-        return pos_bias.as_strided(
-            (1, num_heads, T, T),
-            (0, head_stride, batch_stride - time_stride, time_stride),
-            storage_offset=time_stride * (T - 1),
-        )
 
     def multi_head_attention_forward(
         self,
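
Note on the in_proj change above: with the explicit "q * scaling" removed (see the
next hunk), the usual head_dim ** -0.5 attention scaling is instead baked into
in_proj's initial_scale, split evenly between query and key. A standalone sketch of
why that is equivalent at initialization, assuming initial_scale simply multiplies
the projected output at init (names and shapes here are illustrative, not from the
repo):

    import torch

    torch.manual_seed(0)
    head_dim = 24
    q = torch.randn(10, head_dim)
    k = torch.randn(10, head_dim)

    # old scheme: only the query is scaled, by head_dim ** -0.5
    old_scores = torch.matmul(q * head_dim ** -0.5, k.t())
    # new scheme: query and key each carry head_dim ** -0.25 from initial_scale
    new_scores = torch.matmul(q * head_dim ** -0.25, (k * head_dim ** -0.25).t())
    assert torch.allclose(old_scores, new_scores, atol=1e-5)

Splitting the factor this way presumably keeps query and key activations at
comparable magnitudes, which matters once the in_balancer constrains both through
the same max_abs bound.
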
@@ -1003,10 +973,10 @@ class RelPositionMultiheadAttention(nn.Module):
 
         Shape:
             Inputs:
-            - x: :math:`(L, N, 3 * A)` where L is the target sequence length, N is the batch size, A is
-              the attention dimension.  Will be split into (query, key, value).
-            - pos: :math:`(N, 2*L-1, H)` or :math:`(1, 2*L-1, H)` where L is the sequence
-              length, N is the batch size, and H is the number of heads.
+            - x: :math:`(L, N, 7 * A // 2)` where L is the target sequence length, N is the batch size, A is
+              the attention dimension.  Will be split into (query, key, value, pos).
+            - pos: :math:`(N, 2*L-1, A//2)` or :math:`(1, 2*L-1, A//2)` where L is the sequence
+              length, N is the batch size, and A is the attention dim.
             - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
               If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
               will be unchanged. If a BoolTensor is provided, the positions with the
@@ -1033,11 +1003,13 @@ class RelPositionMultiheadAttention(nn.Module):
                 head_dim * num_heads == attention_dim
             ), "attention_dim must be divisible by num_heads"
 
-        scaling = float(head_dim) ** -0.5
-
         # self-attention
-        q, k, v = x.chunk(3, dim=-1)
+        q = x[...,:attention_dim]
+        k = x[...,attention_dim:2*attention_dim]
+        v = x[...,2*attention_dim:3*attention_dim]
+        p = x[...,3*attention_dim:]
+
         k = self.whiten_keys(k)  # does nothing in the forward pass.
         v = self.whiten_values(v)  # does nothing in the forward pass.
 
@@ -1091,9 +1063,11 @@ class RelPositionMultiheadAttention(nn.Module):
             )
             key_padding_mask = key_padding_mask.to(torch.bool)
 
-        q = (q * scaling).contiguous().view(seq_len, bsz, num_heads, head_dim)
-        k = k.contiguous().view(-1, bsz, num_heads, head_dim)
-        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
+        q = q.reshape(seq_len, bsz, num_heads, head_dim)
+        p = p.reshape(seq_len, bsz, num_heads, head_dim // 2)
+        k = k.reshape(seq_len, bsz, num_heads, head_dim)
+        v = v.reshape(seq_len, bsz * num_heads, head_dim).transpose(0, 1)
+
 
         if key_padding_mask is not None:
             assert key_padding_mask.size(0) == bsz, "{} == {}".format(
@@ -1103,14 +1077,33 @@ class RelPositionMultiheadAttention(nn.Module):
                 key_padding_mask.size(1), seq_len
             )
 
-        q = q.permute(1, 2, 0, 3)  # (batch head, time1, head_dim)
+
+
+        q = q.permute(1, 2, 0, 3)  # (batch, head, time1, head_dim)
+        p = p.permute(1, 2, 0, 3)  # (batch, head, time1, head_dim // 2)
 
         # compute attention score
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
 
-        # pos_bias: (batch, head, time1, time2)
-        pos_bias = self.rel_shift(pos)
-        attn_output_weights = torch.matmul(q, k) + pos_bias
+        T2 = 2 * seq_len - 1
+        pos = pos.reshape(1, T2, num_heads, head_dim // 2).permute(0, 2, 3, 1)
+        # pos shape now: (batch, head, head_dim//2, T2)
+
+        # (batch, head, time1, head_dim // 2) x (1, head, head_dim//2, T2) -> (batch, head, time1, T2)
+        #  [where T2 represents relative position.]
+        pos_weights = torch.matmul(p, pos)
+        # the following .as_strided() expression converts the last axis of pos_weights from relative
+        # to absolute position.  I don't know whether I might have got the time-offsets backwards or
+        # not, but let this code define which way round it is supposed to be.
+        pos_weights = pos_weights.as_strided((bsz, num_heads, seq_len, seq_len),
+                                             (pos_weights.stride(0),
+                                              pos_weights.stride(1),
+                                              pos_weights.stride(2)-pos_weights.stride(3),
+                                              pos_weights.stride(3)),
+                                             storage_offset=pos_weights.stride(3) * (seq_len - 1))
+
+
+        attn_output_weights = torch.matmul(q, k) + pos_weights
 
         # attn_output_weights: (batch, head, time1, time2)
         attn_output_weights = attn_output_weights.view(
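
Note on the as_strided expression above, which replaces the removed rel_shift():
it reinterprets the last axis of pos_weights, a relative offset of length
2*seq_len-1, as absolute key position. A self-contained check of what that
indexing does, under the offset convention the patch's own comment leaves to the
code (shapes illustrative; pos_weights is contiguous here, as a matmul output is):

    import torch

    B, H, T = 2, 3, 5
    pw = torch.randn(B, H, T, 2 * T - 1)  # last axis r encodes relative offset r - (T - 1)
    abs_pw = pw.as_strided(
        (B, H, T, T),
        (pw.stride(0), pw.stride(1), pw.stride(2) - pw.stride(3), pw.stride(3)),
        storage_offset=pw.stride(3) * (T - 1),
    )
    for i in range(T):      # query position
        for j in range(T):  # key position
            assert abs_pw[..., i, j].equal(pw[..., i, (j - i) + (T - 1)])

So entry (i, j) of the strided view reads relative index (j - i) + (T - 1): keys
ahead of the query land above the midpoint of the relative axis, which is the
time-offset direction this construction defines.
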
@@ -1180,7 +1173,7 @@ class RelPositionMultiheadAttention(nn.Module):
         # v: (tgt_len, bsz, embed_dim // 2)
         v = self.in_proj2(x)
         v = self.whiten_values2(v)  # does nothing in the forward pass.
-        v = v.contiguous().view(seq_len, bsz * num_heads, head_dim).transpose(0, 1)
+        v = v.reshape(seq_len, bsz * num_heads, head_dim).transpose(0, 1)
 
         # now v: (bsz * num_heads, seq_len, head_dim)
         attn_output = torch.bmm(attn_weights, v)
@@ -1450,7 +1443,7 @@ class Conv2dSubsampling(nn.Module):
         x = self.conv(x)
         # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
         b, c, t, f = x.size()
-        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+        x = self.out(x.transpose(1, 2).reshape(b, t, c * f))
         # Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
         x = self.dropout(x)
         return x
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index 8bdf9e40d..18d854279 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -120,7 +120,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--attention-dims",
         type=str,
-        default="256,256",
+        default="192,192",
         help="Attention dimension in the 2 blocks of conformer encoder layers, comma separated"
     )
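
Note on the new default: each per-block attention dimension must both divide by its
head count and yield an even head_dim, since in_proj now packs (query, key, value,
pos) into 7 * attention_dim // 2 channels, with pos taking head_dim // 2 per head.
A quick sketch of the constraint on values passed via --attention-dims (the
num_heads value here is illustrative; the real head counts come from the model
configuration, not this patch):

    num_heads = 8  # illustrative
    attention_dims = [int(s) for s in "192,192".split(",")]
    for attention_dim in attention_dims:
        head_dim = attention_dim // num_heads
        assert head_dim * num_heads == attention_dim
        # pos gets head_dim // 2 channels per head, so head_dim must be even,
        # matching the new assert in RelPositionMultiheadAttention.__init__
        assert head_dim % 2 == 0, head_dim

With attention_dim 192 and 8 heads, head_dim is 24 and in_proj produces
7 * 192 // 2 = 672 output channels per layer, down from 3 * 256 = 768 before.
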