diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index baa096334..fd9a7cd8b 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -71,6 +71,7 @@ class Zipformer(EncoderInterface):
         num_encoder_layers: Tuple[int] = (12, 12),
         dropout: float = 0.1,
         cnn_module_kernels: Tuple[int] = (31, 31),
+        pos_dim: int = 4,
         warmup_batches: float = 4000.0,
     ) -> None:
         super(Zipformer, self).__init__()
@@ -107,6 +108,7 @@ class Zipformer(EncoderInterface):
                 feedforward_dim[i],
                 dropout,
                 cnn_module_kernels[i],
+                pos_dim,
             )
 
             # For the segment of the warmup period, we let the Conv2dSubsampling
@@ -263,13 +265,14 @@ class ZipformerEncoderLayer(nn.Module):
         feedforward_dim: int = 2048,
         dropout: float = 0.1,
         cnn_module_kernel: int = 31,
+        pos_dim: int = 4,
     ) -> None:
         super(ZipformerEncoderLayer, self).__init__()
 
         self.d_model = d_model
 
         self.self_attn = RelPositionMultiheadAttention(
-            d_model, attention_dim, nhead, dropout=0.0,
+            d_model, attention_dim, nhead, pos_dim, dropout=0.0,
         )
 
         self.feed_forward1 = FeedforwardModule(d_model,
@@ -912,6 +915,7 @@ class RelPositionMultiheadAttention(nn.Module):
         embed_dim: int,
         attention_dim: int,
         num_heads: int,
+        pos_dim: int,
         dropout: float = 0.0,
     ) -> None:
         super(RelPositionMultiheadAttention, self).__init__()
@@ -920,6 +924,7 @@ class RelPositionMultiheadAttention(nn.Module):
         self.num_heads = num_heads
         self.dropout = dropout
         self.head_dim = attention_dim // num_heads
+        self.pos_dim = pos_dim
         assert self.head_dim % 2 == 0, self.head_dim
         assert (
             self.head_dim * num_heads == attention_dim
@@ -927,27 +932,31 @@ class RelPositionMultiheadAttention(nn.Module):
 
         # the initial_scale is supposed to take over the "scaling" factor of
         # head_dim ** -0.5, dividing it between the query and key.
-        self.in_proj = ScaledLinear(embed_dim, 3 * attention_dim, bias=True,
+        in_proj_dim = (2 * attention_dim +   # query, key
+                       attention_dim // 2 +  # value
+                       pos_dim * num_heads)  # positional encoding query
+
+        self.in_proj = ScaledLinear(embed_dim, in_proj_dim, bias=True,
                                     initial_scale=self.head_dim**-0.25)
 
-        # self.whiten_values is applied on the values in forward()
+        # self.whiten_values is applied on the values in forward();
+        # it just copies the keys but prevents low-rank distribution by modifying grads.
         self.whiten_values = Whiten(num_groups=num_heads,
                                     whitening_limit=2.0,
                                     prob=(0.025, 0.25),
                                     grad_scale=0.025)
-        # self.whiten_keys is applied on the keys in forward()
        self.whiten_keys = Whiten(num_groups=num_heads,
                                  whitening_limit=2.0,
                                  prob=(0.025, 0.25),
                                  grad_scale=0.025)
 
-        # linear transformation for positional encoding.
-        self.linear_pos = ScaledLinear(embed_dim, attention_dim // 2, bias=False,
+        self.linear_pos = ScaledLinear(embed_dim, num_heads * pos_dim, bias=False,
                                        initial_scale=0.05)
 
-        # the following are for diagnosics only, see --print-diagnostics option
+        # the following are for diagnostics only, see --print-diagnostics option.
+        # they only copy their inputs.
         self.copy_pos_query = Identity()
         self.copy_query = Identity()
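For reference, a minimal shape-check sketch (not part of the patch) of how the new in_proj output is laid out: the query and key each take attention_dim channels, the value takes attention_dim // 2, and the positional-encoding query takes num_heads * pos_dim, matching in_proj_dim above. The sizes below are illustrative placeholders, and nn.Linear stands in for the repo's ScaledLinear.

import torch
import torch.nn as nn

# Illustrative sizes only; the real model configures these per encoder stack.
embed_dim, attention_dim, num_heads, pos_dim = 384, 192, 8, 4

in_proj_dim = (2 * attention_dim +   # query, key
               attention_dim // 2 +  # value
               pos_dim * num_heads)  # positional-encoding query
in_proj = nn.Linear(embed_dim, in_proj_dim)  # stand-in for ScaledLinear

seq_len, bsz = 10, 2
x_proj = in_proj(torch.randn(seq_len, bsz, embed_dim))

q = x_proj[..., 0:attention_dim]
k = x_proj[..., attention_dim:2 * attention_dim]
value_dim = attention_dim // 2
v = x_proj[..., 2 * attention_dim:2 * attention_dim + value_dim]
p = x_proj[..., 2 * attention_dim + value_dim:]  # positional query

assert p.shape[-1] == num_heads * pos_dim
print(q.shape, k.shape, v.shape, p.shape)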
@@ -1014,8 +1023,6 @@ class RelPositionMultiheadAttention(nn.Module):
             self.linear_pos(pos_emb),
             self.attention_dim,
             self.num_heads,
-            self.in_proj.weight,
-            self.in_proj.bias,
             self.dropout,
             self.out_proj.weight,
             self.out_proj.bias,
@@ -1028,12 +1035,10 @@ class RelPositionMultiheadAttention(nn.Module):
 
     def multi_head_attention_forward(
         self,
-        x: Tensor,
+        x_proj: Tensor,
         pos: Tensor,
         attention_dim: int,
         num_heads: int,
-        in_proj_weight: Tensor,
-        in_proj_bias: Tensor,
         dropout_p: float,
         out_proj_weight: Tensor,
         out_proj_bias: Tensor,
@@ -1047,7 +1052,6 @@ class RelPositionMultiheadAttention(nn.Module):
             pos: head-specific biases arising from the positional embeddings.
             attention_dim: dimension inside attention mechanism
             num_heads: parallel attention heads.
-            in_proj_weight, in_proj_bias: input projection weight and bias.
             dropout_p: probability of an element to be zeroed.
             out_proj_weight, out_proj_bias: the output projection weight and bias.
             training: apply dropout if is ``True``.
@@ -1082,17 +1086,23 @@ class RelPositionMultiheadAttention(nn.Module):
             H is the num-heads, S is the sequence length.
         """
 
-        seq_len, bsz, _ = x.size()
+        seq_len, bsz, _ = x_proj.size()
         head_dim = attention_dim // num_heads
+        pos_dim = self.pos_dim  # positional-encoding dim per head
         assert (
             head_dim * num_heads == attention_dim
         ), "attention_dim must be divisible by num_heads"
 
         # self-attention
-        q, k, pv = x.chunk(3, dim=-1)
-        p, v = pv.chunk(2, dim=-1)
+        q = x_proj[...,0:attention_dim]
+        k = x_proj[...,attention_dim:2*attention_dim]
+        value_dim = attention_dim // 2
+        v = x_proj[...,2*attention_dim:2*attention_dim+value_dim]
+        # p is the position-encoding query; its dimension is num_heads*pos_dim.
+        p = x_proj[...,2*attention_dim+value_dim:]
+
         k = self.whiten_keys(k)  # does nothing in the forward pass.
         v = self.whiten_values(v)  # does nothing in the forward pass.
@@ -1150,7 +1160,7 @@ class RelPositionMultiheadAttention(nn.Module):
                 key_padding_mask = key_padding_mask.to(torch.bool)
 
         q = q.reshape(seq_len, bsz, num_heads, head_dim)
-        p = p.reshape(seq_len, bsz, num_heads, head_dim // 2)
+        p = p.reshape(seq_len, bsz, num_heads, pos_dim)
         k = k.reshape(seq_len, bsz, num_heads, head_dim)
         v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1)
@@ -1166,16 +1176,16 @@ class RelPositionMultiheadAttention(nn.Module):
 
         q = q.permute(1, 2, 0, 3)  # (batch, head, time1, head_dim)
-        p = p.permute(1, 2, 0, 3)  # (batch, head, time1, head_dim // 2)
+        p = p.permute(1, 2, 0, 3)  # (batch, head, time1, pos_dim)
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
 
-        T2 = 2 * seq_len - 1
-        pos = pos.reshape(1, T2, num_heads, head_dim // 2).permute(0, 2, 3, 1)
-        # pos shape now: (batch, head, head_dim//2, T2)
+        seq_len2 = 2 * seq_len - 1
+        pos = pos.reshape(1, seq_len2, num_heads, pos_dim).permute(0, 2, 3, 1)
+        # pos shape now: (batch, head, pos_dim, seq_len2)
 
-        # (batch, head, time1, head_dim // 2) x (1, head, head_dim//2, T2) -> (batch, head, time1, T2)
-        # [where T2 represents relative position.]
+        # (batch, head, time1, pos_dim) x (1, head, pos_dim, seq_len2) -> (batch, head, time1, seq_len2)
+        # [where seq_len2 represents relative position.]
         pos_weights = torch.matmul(p, pos)
         # the following .as_strided() expression converts the last axis of pos_weights from relative
         # to absolute position.  I don't know whether I might have got the time-offsets backwards or
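A quick shape sketch (with made-up sizes, outside the patch) of the positional-weight matmul above: the per-head positional query p of shape (batch, head, time1, pos_dim) is multiplied with the projected relative positional embeddings of shape (1, head, pos_dim, seq_len2), where seq_len2 = 2 * seq_len - 1 covers every relative offset, giving (batch, head, time1, seq_len2). The rel-to-abs .as_strided() re-indexing that follows in the file is not reproduced here.

import torch

# Illustrative sizes only.
bsz, num_heads, seq_len, pos_dim = 2, 8, 10, 4

p = torch.randn(seq_len, bsz, num_heads, pos_dim).permute(1, 2, 0, 3)
# p: (batch, head, time1, pos_dim)

seq_len2 = 2 * seq_len - 1  # number of distinct relative offsets
pos = torch.randn(1, seq_len2, num_heads, pos_dim).permute(0, 2, 3, 1)
# pos: (1, head, pos_dim, seq_len2); broadcasts over the batch dimension

pos_weights = torch.matmul(p, pos)
assert pos_weights.shape == (bsz, num_heads, seq_len, seq_len2)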
@@ -1243,7 +1253,8 @@ class RelPositionMultiheadAttention(nn.Module):
             )
 
         attn_output = torch.bmm(attn_output_weights, v)
-        assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim // 2]
+        assert list(attn_output.size()) == [bsz * num_heads, seq_len,
+                                            head_dim // 2]
         attn_output = (
             attn_output.transpose(0, 1)
             .contiguous()
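As a sanity check on the reflowed assert (not part of the patch): the attention weights have shape (bsz * num_heads, seq_len, seq_len) and the values keep width head_dim // 2, so torch.bmm yields (bsz * num_heads, seq_len, head_dim // 2). The sizes below are arbitrary examples.

import torch

# Arbitrary example sizes.
bsz, num_heads, seq_len, head_dim = 2, 8, 10, 24

attn_output_weights = torch.softmax(
    torch.randn(bsz * num_heads, seq_len, seq_len), dim=-1
)
v = torch.randn(bsz * num_heads, seq_len, head_dim // 2)

attn_output = torch.bmm(attn_output_weights, v)
assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim // 2]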