diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py index 625651d3c..10527a7a5 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py @@ -263,7 +263,7 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=dropout, + d_model, nhead, dropout=0.0, ) self.feed_forward1 = FeedforwardModule(d_model, @@ -436,16 +436,14 @@ class ConformerEncoder(nn.Module): def get_layers_to_drop(self, rnd_seed: int, warmup_count: float): num_layers = len(self.layers) - warmup_begin = self.warmup_begin - warmup_end = self.warmup_end def get_layerdrop_prob(layer: int) -> float: - layer_warmup_delta = (warmup_end - warmup_begin) / num_layers - layer_warmup_begin = warmup_begin + layer * layer_warmup_delta + layer_warmup_begin = self.layers[layer].warmup_begin + layer_warmup_end = self.layers[layer].warmup_end + initial_layerdrop_prob = 0.5 final_layerdrop_prob = 0.05 - layer_warmup_end = layer_warmup_begin + layer_warmup_delta if warmup_count < layer_warmup_begin: return initial_layerdrop_prob elif warmup_count > layer_warmup_end: @@ -483,7 +481,7 @@ class ConformerEncoder(nn.Module): if len(ans) == num_to_drop: break if shared_rng.random() < 0.005 or __name__ == "__main__": - logging.info(f"warmup_begin={warmup_begin:.1f}, warmup_end={warmup_end:.1f}, warmup_count={warmup_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}") + logging.info(f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, warmup_count={warmup_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}") return ans @@ -839,9 +837,6 @@ class RelPositionMultiheadAttention(nn.Module): channel_dim=-1, max_abs=5.0) self.in_max_eig = MaxEig(3 * embed_dim // 2, channel_dim=-1) - self.proj_balancer = ActivationBalancer(embed_dim // 2, - channel_dim=-1, max_abs=10.0, - min_positive=0.0, max_positive=1.0) self.out_proj = ScaledLinear( embed_dim // 2, embed_dim, bias=True, initial_scale=0.05 ) @@ -850,18 +845,9 @@ class RelPositionMultiheadAttention(nn.Module): self.out_proj2 = ScaledLinear(embed_dim // 2, embed_dim, bias=True, initial_scale=0.05) - - # linear transformation for positional encoding. - self.linear_pos = nn.Linear(embed_dim, embed_dim // 2, bias=False) - # these two learnable bias are used in matrix c and matrix d - # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 - self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) - self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) - self._reset_parameters() - - def _reset_parameters(self) -> None: - nn.init.uniform_(self.pos_bias_u, -0.05, 0.05) - nn.init.uniform_(self.pos_bias_v, -0.05, 0.05) + # linear transformation for positional encoding (projects to a scalar per head, + # which will be added to the score). + self.linear_pos = ScaledLinear(embed_dim, num_heads, initial_scale=0.05) def forward( self, @@ -909,7 +895,7 @@ class RelPositionMultiheadAttention(nn.Module): """ x, weights = self.multi_head_attention_forward( self.in_max_eig(self.in_balancer(self.in_proj(x))), - pos_emb, + self.linear_pos(pos_emb), self.embed_dim, self.num_heads, self.in_proj.weight, @@ -923,35 +909,44 @@ class RelPositionMultiheadAttention(nn.Module): ) return x, weights - def rel_shift(self, x: Tensor) -> Tensor: - """Compute relative positional encoding. + def rel_shift(self, pos_bias: Tensor) -> Tensor: + """Convert relative positional bias from linear to matrix format. Args: - x: Input tensor (batch, head, time1, 2*time1-1). - time1 means the length of query vector. + pos_bias: Input tensor (1, 2*T-1, num_heads), where T is the number of frames. Returns: - Tensor: tensor of shape (batch, head, time1, time2) - (note: time2 has the same value as time1, but it is for - the key, while time1 is for the query). + Tensor of shape (1, num_heads, time1, time2) + (note: time2 has the same value as time1, but it is for + the key, while time1 is for the query). """ - (batch_size, num_heads, time1, n) = x.shape - assert n == 2 * time1 - 1 + (batch_size, n, num_heads) = pos_bias.shape + assert batch_size == 1 + T = (n + 1) // 2 + assert n == 2 * T - 1 + # The leading T dimension behaves like a batch dimension. + # It is only needed because PyTorch does not currently support + # negative strides. + pos_bias = pos_bias.expand(T, n, num_heads).contiguous() + # Note: TorchScript requires explicit arg for stride() - batch_stride = x.stride(0) - head_stride = x.stride(1) - time1_stride = x.stride(2) - n_stride = x.stride(3) - return x.as_strided( - (batch_size, num_heads, time1, time1), - (batch_stride, head_stride, time1_stride - n_stride, n_stride), - storage_offset=n_stride * (time1 - 1), + batch_stride = pos_bias.stride(0) + time_stride = pos_bias.stride(1) + head_stride = pos_bias.stride(2) + + # We could have left the batch dim as 1, and used '-time_stride' below + # where we use 'batch_stride - time_stride', but PyTorch does not support negative + # strides. + return pos_bias.as_strided( + (1, num_heads, T, T), + (0, head_stride, batch_stride - time_stride, time_stride), + storage_offset=time_stride * (T - 1), ) def multi_head_attention_forward( self, x: Tensor, - pos_emb: Tensor, + pos: Tensor, embed_dim: int, num_heads: int, in_proj_weight: Tensor, @@ -965,8 +960,8 @@ class RelPositionMultiheadAttention(nn.Module): ) -> Tuple[Tensor, Optional[Tensor]]: r""" Args: - query, key, value: map a query and a set of key-value pairs to an output. - pos_emb: Positional embedding tensor + x_proj: the projected input, to be split into query, key, value. + pos: head-specific biases arising from the positional embeddings. embed_dim: total dimension of the model. num_heads: parallel attention heads. in_proj_weight, in_proj_bias: input projection weight and bias. @@ -981,14 +976,10 @@ class RelPositionMultiheadAttention(nn.Module): Shape: Inputs: - - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence - length, N is the batch size, E is the embedding dimension. + - x: :math:`(L, N, 3 * E//2)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. Will be split into (query, key, value). + - pos: :math:`(N, 2*L-1, H)` or :math:`(1, 2*L-1, H)` where L is the sequence + length, N is the batch size, and H is the number of heads. - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions will be unchanged. If a BoolTensor is provided, the positions with the @@ -1008,7 +999,7 @@ class RelPositionMultiheadAttention(nn.Module): H is the num-heads, S is the sequence length. """ - tgt_len, bsz, _ = x.size() + seq_len, bsz, _ = x.size() head_dim = embed_dim // (num_heads * 2) assert ( @@ -1040,15 +1031,15 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) - if list(attn_mask.size()) != [1, tgt_len, tgt_len]: + if list(attn_mask.size()) != [1, seq_len, seq_len]: raise RuntimeError( "The size of the 2D attn_mask is not correct." ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, - tgt_len, - tgt_len, + seq_len, + seq_len, ]: raise RuntimeError( "The size of the 3D attn_mask is not correct." @@ -1071,63 +1062,37 @@ class RelPositionMultiheadAttention(nn.Module): ) key_padding_mask = key_padding_mask.to(torch.bool) - q = (q * scaling).contiguous().view(tgt_len, bsz, num_heads, head_dim) + q = (q * scaling).contiguous().view(seq_len, bsz, num_heads, head_dim) k = k.contiguous().view(-1, bsz, num_heads, head_dim) v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) - src_len = k.size(0) if key_padding_mask is not None: assert key_padding_mask.size(0) == bsz, "{} == {}".format( key_padding_mask.size(0), bsz ) - assert key_padding_mask.size(1) == src_len, "{} == {}".format( - key_padding_mask.size(1), src_len + assert key_padding_mask.size(1) == seq_len, "{} == {}".format( + key_padding_mask.size(1), seq_len ) - q = q.transpose(0, 1) # (batch, time1, head, d_k) - - pos_emb_bsz = pos_emb.size(0) - assert pos_emb_bsz in (1, bsz) # actually it is 1 - p = self.proj_balancer(self.linear_pos(pos_emb)).view(pos_emb_bsz, -1, num_heads, head_dim) - p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) - - q_with_bias_u = (q + self.pos_bias_u).transpose( - 1, 2 - ) # (batch, head, time1, d_k) - - q_with_bias_v = (q + self.pos_bias_v).transpose( - 1, 2 - ) # (batch, head, time1, d_k) - + q = q.permute(1, 2, 0, 3) # (batch head, time1, head_dim) # compute attention score - # first compute matrix a and matrix c - # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) - - # compute matrix b and matrix d - matrix_bd = torch.matmul( - q_with_bias_v, p.transpose(-2, -1) - ) # (batch, head, time1, 2*time1-1) - matrix_bd = self.rel_shift(matrix_bd) - - attn_output_weights = ( - matrix_ac + matrix_bd - ) # (batch, head, time1, time2) + # pos_bias: (batch, head, time1, time2) + pos_bias = self.rel_shift(pos) + attn_output_weights = torch.matmul(q, k) + pos_bias + # attn_output_weights: (batch, head, time1, time2) attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 + bsz * num_heads, seq_len, seq_len ) assert list(attn_output_weights.size()) == [ bsz * num_heads, - tgt_len, - src_len, + seq_len, + seq_len, ] if attn_mask is not None: @@ -1138,14 +1103,14 @@ class RelPositionMultiheadAttention(nn.Module): if key_padding_mask is not None: attn_output_weights = attn_output_weights.view( - bsz, num_heads, tgt_len, src_len + bsz, num_heads, seq_len, seq_len ) attn_output_weights = attn_output_weights.masked_fill( key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"), ) attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, src_len + bsz * num_heads, seq_len, seq_len ) attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1) @@ -1154,11 +1119,11 @@ class RelPositionMultiheadAttention(nn.Module): ) attn_output = torch.bmm(attn_output_weights, v) - assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] + assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim] attn_output = ( attn_output.transpose(0, 1) .contiguous() - .view(tgt_len, bsz, embed_dim // 2) + .view(seq_len, bsz, embed_dim // 2) ) attn_output = nn.functional.linear( attn_output, out_proj_weight, out_proj_bias