From 7e5c7e6f77fe27e7a73888b8dc3c0fa20e0a92a1 Mon Sep 17 00:00:00 2001
From: zr_jin <60612200+JinZr@users.noreply.github.com>
Date: Sun, 23 Jul 2023 20:19:19 +0800
Subject: [PATCH] fixes on feat dim

---
 .../alignment_attention_module.py            | 80 ++++++++++---------
 1 file changed, 43 insertions(+), 37 deletions(-)

diff --git a/egs/librispeech/ASR/zipformer_label_level_algn/alignment_attention_module.py b/egs/librispeech/ASR/zipformer_label_level_algn/alignment_attention_module.py
index 93490e923..e287c363c 100644
--- a/egs/librispeech/ASR/zipformer_label_level_algn/alignment_attention_module.py
+++ b/egs/librispeech/ASR/zipformer_label_level_algn/alignment_attention_module.py
@@ -117,12 +117,12 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
     ) -> Tensor:
         r"""
         Args:
-            lm_pruned: input of shape (batch_size * prune_range, seq_len, decoder_embed_dim)
-            am_pruned: input of shape (batch_size * prune_range, seq_len, encoder_embed_dim)
-            pos_emb: Positional embedding tensor, of shape (1, 2 * batch_size * prune_range - 1, pos_dim)
-            key_padding_mask: a bool tensor of shape (seq_len, batch_size * prune_range). Positions that
-               are True in this mask will be ignored as sources in the attention weighting.
-            attn_mask: mask of shape (batch_size * prune_range, batch_size * prune_range)
+            lm_pruned: input of shape (seq_len, batch_size * prune_range, decoder_embed_dim)
+            am_pruned: input of shape (seq_len, batch_size * prune_range, encoder_embed_dim)
+            pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim)
+            key_padding_mask: a bool tensor of shape (batch_size * prune_range, seq_len). Positions
+               that are True in this mask will be ignored as sources in the attention weighting.
+            attn_mask: mask of shape (seq_len, seq_len) or (seq_len, batch_size * prune_range, batch_size * prune_range),
               interpreted as ([seq_len,] batch_size * prune_range, batch_size * prune_range)
               saying which positions are allowed to attend to which other positions.
@@ -137,36 +137,36 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         num_heads = self.num_heads
 
         (
-            b_p_dim,
             seq_len,
+            b_p_dim,
             _,
-        ) = lm_pruned.shape  # actual dim: (batch * prune_range, seq_len, _)
+        ) = lm_pruned.shape  # actual dim: (seq_len, batch * prune_range, _)
 
         query_dim = query_head_dim * num_heads
 
         # self-attention
-        q = lm_pruned[..., 0:query_dim]  # (batch * prune_range, seq_len, query_dim)
-        k = am_pruned  # (batch * prune_range, seq_len, query_dim)
+        q = lm_pruned[..., 0:query_dim]  # (seq_len, batch * prune_range, query_dim)
+        k = am_pruned  # (seq_len, batch * prune_range, query_dim)
         # p is the position-encoding query
         p = lm_pruned[
             ..., query_dim:
-        ]  # (batch * prune_range, seq_len, pos_head_dim * num_heads)
+        ]  # (seq_len, batch * prune_range, pos_head_dim * num_heads)
         assert p.shape[-1] == num_heads * pos_head_dim
 
         q = self.copy_query(q)  # for diagnostics only, does nothing.
         k = self.whiten_keys(self.balance_keys(k))  # does nothing in the forward pass.
         p = self.copy_pos_query(p)  # for diagnostics only, does nothing.
 
-        q = q.reshape(b_p_dim, seq_len, num_heads, query_head_dim)
-        p = p.reshape(b_p_dim, seq_len, num_heads, pos_head_dim)
-        k = k.reshape(b_p_dim, seq_len, num_heads, query_head_dim)
+        q = q.reshape(seq_len, b_p_dim, num_heads, query_head_dim)
+        p = p.reshape(seq_len, b_p_dim, num_heads, pos_head_dim)
+        k = k.reshape(seq_len, b_p_dim, num_heads, query_head_dim)
 
         # time1 refers to target, time2 refers to source.
         q = q.permute(
             2, 1, 0, 3
-        )  # (head, seq_len, batch * prune_range, query_head_dim)
-        p = p.permute(2, 1, 0, 3)  # (head, seq_len, batch * prune_range, pos_head_dim)
-        k = k.permute(2, 1, 3, 0)  # (head, seq_len, d_k, batch * prune_range)
+        )  # (head, batch * prune_range, seq_len, query_head_dim)
+        p = p.permute(2, 1, 0, 3)  # (head, batch * prune_range, seq_len, pos_head_dim)
+        k = k.permute(2, 1, 3, 0)  # (head, batch * prune_range, d_k, seq_len)
 
         attn_scores = torch.matmul(q, k)
 
@@ -179,12 +179,14 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
 
         if use_pos_scores:
             pos_emb = self.linear_pos(pos_emb)
-
-            seq_len2 = 2 * b_p_dim - 1
+            print("pos_emb before proj", pos_emb.shape)
+            seq_len2 = 2 * seq_len - 1
             pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
                 2, 0, 3, 1
             )
             # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
+            print("p", p.shape)
+            print("pos_emb after proj", pos_emb.shape)
 
             # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
             # [where seq_len2 represents relative position.]
@@ -193,24 +195,24 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
             # to absolute position.  I don't know whether I might have got the time-offsets backwards or
             # not, but let this code define which way round it is supposed to be.
             if torch.jit.is_tracing():
-                (num_heads, seq_len, time1, n) = pos_scores.shape
+                (num_heads, b_p_dim, time1, n) = pos_scores.shape
                 rows = torch.arange(start=time1 - 1, end=-1, step=-1)
-                cols = torch.arange(b_p_dim)
-                rows = rows.repeat(seq_len * num_heads).unsqueeze(-1)
+                cols = torch.arange(seq_len)
+                rows = rows.repeat(b_p_dim * num_heads).unsqueeze(-1)
                 indexes = rows + cols
                 pos_scores = pos_scores.reshape(-1, n)
                 pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
-                pos_scores = pos_scores.reshape(num_heads, seq_len, time1, b_p_dim)
+                pos_scores = pos_scores.reshape(num_heads, b_p_dim, time1, seq_len)
             else:
                 pos_scores = pos_scores.as_strided(
-                    (num_heads, seq_len, b_p_dim, b_p_dim),
+                    (num_heads, b_p_dim, seq_len, seq_len),
                     (
                         pos_scores.stride(0),
                         pos_scores.stride(1),
                         pos_scores.stride(2) - pos_scores.stride(3),
                         pos_scores.stride(3),
                     ),
-                    storage_offset=pos_scores.stride(3) * (b_p_dim - 1),
+                    storage_offset=pos_scores.stride(3) * (seq_len - 1),
                 )
 
             attn_scores = attn_scores + pos_scores
@@ -234,7 +236,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
                 attn_scores, limit=25.0, penalty=1.0e-04, name=self.name
             )
 
-        assert attn_scores.shape == (num_heads, seq_len, b_p_dim, b_p_dim)
+        assert attn_scores.shape == (num_heads, b_p_dim, seq_len, seq_len)
 
         if attn_mask is not None:
             assert attn_mask.dtype == torch.bool
@@ -246,8 +248,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
 
         if key_padding_mask is not None:
             assert key_padding_mask.shape == (
-                seq_len,
                 b_p_dim,
+                seq_len,
             ), key_padding_mask.shape
             attn_scores = attn_scores.masked_fill(
                 key_padding_mask.unsqueeze(1),
@@ -323,20 +325,24 @@ class AlignmentAttentionModule(nn.Module):
         (batch_size, T, prune_range, encoder_dim) = am_pruned.shape
         (batch_size, T, prune_range, decoder_dim) = lm_pruned.shape
 
-        # am_pruned : [B * prune_range, T, encoder_dim]
-        # lm_pruned : [B * prune_range, T, decoder_dim]
-        am_pruned = am_pruned.transpose(1, 0).reshape(
-            batch_size * prune_range, T, encoder_dim
+        # am_pruned : [T, B * prune_range, encoder_dim]
+        # lm_pruned : [T, B * prune_range, decoder_dim]
+        merged_am_pruned = am_pruned.permute(1, 0, 2, 3).reshape(
+            T, batch_size * prune_range, encoder_dim
         )
-        lm_pruned = lm_pruned.transpose(1, 0).reshape(
-            batch_size * prune_range, T, decoder_dim
+        merged_lm_pruned = lm_pruned.permute(1, 0, 2, 3).reshape(
+            T, batch_size * prune_range, decoder_dim
         )
 
-        pos_emb = self.pos_encode(am_pruned)
+        pos_emb = self.pos_encode(merged_am_pruned)
 
-        attn_weights = self.cross_attn_weights(lm_pruned, am_pruned, pos_emb)
-        label_level_am_representation = self.cross_attn(am_pruned, attn_weights)
-        return label_level_am_representation.reshape(batch_size, T, prune_range, encoder_dim)
+        attn_weights = self.cross_attn_weights(merged_lm_pruned, merged_am_pruned, pos_emb)
+        label_level_am_representation = self.cross_attn(merged_am_pruned, attn_weights)
+        # (T, batch_size * prune_range, encoder_dim)
+
+        return label_level_am_representation \
+            .reshape(T, batch_size, prune_range, encoder_dim) \
+            .permute(1, 0, 2, 3)
 
 
 if __name__ == "__main__":
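The net effect of the patch in AlignmentAttentionModule.forward() is a layout change: the pruned AM/LM tensors are merged from (batch_size, T, prune_range, C) into (T, batch_size * prune_range, C) before the cross-attention, so the attention weights come out as (num_heads, batch_size * prune_range, T, T) and attention runs over the time axis, and the result is unmerged back to (batch_size, T, prune_range, encoder_dim) at the end. The following standalone sketch (not part of the patch; the sizes are made up for illustration) shows the merge/unmerge round trip that the patched forward() relies on:

import torch

batch_size, T, prune_range, encoder_dim = 2, 5, 3, 8  # illustrative sizes only

am_pruned = torch.randn(batch_size, T, prune_range, encoder_dim)

# Patched merge: (B, T, prune_range, C) -> (T, B * prune_range, C), putting the
# time axis first so the attention module treats T as the sequence length.
merged_am_pruned = am_pruned.permute(1, 0, 2, 3).reshape(
    T, batch_size * prune_range, encoder_dim
)

# Patched return path: (T, B * prune_range, C) -> (B, T, prune_range, C).
restored = merged_am_pruned.reshape(T, batch_size, prune_range, encoder_dim).permute(
    1, 0, 2, 3
)

assert torch.equal(am_pruned, restored)  # the round trip is lossless

By contrast, the pre-patch transpose(1, 0).reshape(batch_size * prune_range, T, ...) kept the element count but mixed the time and batch * prune_range indices together, which is presumably the feat-dim issue the commit title refers to.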