diff --git a/egs/librispeech/ASR/zipformer_label_level_algn/alignment_attention_module.py b/egs/librispeech/ASR/zipformer_label_level_algn/alignment_attention_module.py index 41d8a9744..41f5f8b95 100644 --- a/egs/librispeech/ASR/zipformer_label_level_algn/alignment_attention_module.py +++ b/egs/librispeech/ASR/zipformer_label_level_algn/alignment_attention_module.py @@ -15,7 +15,142 @@ from scaling import ( softmax, ) from torch import Tensor, nn -from zipformer import CompactRelPositionalEncoding, SelfAttention, _whitening_schedule +from zipformer import CompactRelPositionalEncoding, _whitening_schedule +from icefall.utils import make_pad_mask + + +class CrossAttention(nn.Module): + """ + The simplest possible attention module. This one works with already-computed attention + weights, e.g. as computed by RelPositionMultiheadAttentionWeights. + + Args: + embed_dim: the input and output embedding dimension + num_heads: the number of attention heads + value_head_dim: the value dimension per head + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + value_head_dim: int, + ) -> None: + super().__init__() + self.in_proj = nn.Linear(embed_dim, num_heads * value_head_dim, bias=True) + + self.out_proj = ScaledLinear( + num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05 + ) + + self.whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(7.5, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + def forward( + self, + x: Tensor, + attn_weights: Tensor, + ) -> Tensor: + """ + Args: + x: input tensor, of shape (seq_len, batch_size, embed_dim) + attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), + with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect + attn_weights.sum(dim=-1) == 1. + Returns: + a tensor with the same shape as x. + """ + (am_seq_len, batch_size, embed_dim) = x.shape + (_, _, lm_seq_len, _) = attn_weights.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == (num_heads, batch_size, lm_seq_len, am_seq_len) + + x = self.in_proj(x) # (am_seq_len, batch_size, num_heads * value_head_dim) + # print("projected x.shape", x.shape) + + x = x.reshape(am_seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, am_seq_len, value_head_dim) + # print("permuted x.shape", x.shape) + + value_head_dim = x.shape[-1] + + # todo: see whether there is benefit in overriding matmul + x = torch.matmul(attn_weights, x) + # v: (num_heads, batch_size, lm_seq_len, value_head_dim) + # print("attended x.shape", x.shape) + + x = ( + x.permute(2, 1, 0, 3) + .contiguous() + .view(lm_seq_len, batch_size, num_heads * value_head_dim) + ) + + # returned value is of shape (lm_seq_len, batch_size, embed_dim), like the input. + x = self.out_proj(x) + x = self.whiten(x) + # print("returned x.shape", x.shape) + + return x + + def streaming_forward( + self, + x: Tensor, + attn_weights: Tensor, + cached_val: Tensor, + left_context_len: int, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + x: input tensor, of shape (seq_len, batch_size, embed_dim) + attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), + with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect + attn_weights.sum(dim=-1) == 1. + cached_val: cached attention value tensor of left context, + of shape (left_context_len, batch_size, value_dim) + left_context_len: number of left context frames. + + Returns: + - attention weighted output, a tensor with the same shape as x. 
+ - updated cached attention value tensor of left context. + """ + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + seq_len2 = seq_len + left_context_len + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len2) + + x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) + + # Pad cached left contexts + assert cached_val.shape[0] == left_context_len, ( + cached_val.shape[0], + left_context_len, + ) + x = torch.cat([cached_val, x], dim=0) + # Update cached left contexts + cached_val = x[-left_context_len:, ...] + + x = x.reshape(seq_len2, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, value_head_dim) + value_head_dim = x.shape[-1] + + # todo: see whether there is benefit in overriding matmul + x = torch.matmul(attn_weights, x) + # v: (num_heads, batch_size, seq_len, value_head_dim) + + x = ( + x.permute(2, 1, 0, 3) + .contiguous() + .view(seq_len, batch_size, num_heads * value_head_dim) + ) + + # returned value is of shape (seq_len, batch_size, embed_dim), like the input. + x = self.out_proj(x) + + return x, cached_val class RelPositionMultiheadAttentionWeights(nn.Module): @@ -40,7 +175,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module): def __init__( self, - embed_dim: int = 512, + lm_embed_dim: int = 512, + am_embed_dim: int = 512, pos_dim: int = 192, num_heads: int = 5, query_head_dim: int = 32, @@ -49,7 +185,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module): pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)), ) -> None: super().__init__() - self.embed_dim = embed_dim + self.lm_embed_dim = lm_embed_dim + self.am_embed_dim = am_embed_dim self.num_heads = num_heads self.query_head_dim = query_head_dim self.pos_head_dim = pos_head_dim @@ -67,10 +204,10 @@ class RelPositionMultiheadAttentionWeights(nn.Module): # to be used with the ScaledAdam optimizer; with most other optimizers, # it would be necessary to apply the scaling factor in the forward function. self.in_lm_proj = ScaledLinear( - embed_dim, in_lm_dim, bias=True, initial_scale=query_head_dim**-0.25 + lm_embed_dim, in_lm_dim, bias=True, initial_scale=query_head_dim**-0.25 ) self.in_am_proj = ScaledLinear( - embed_dim, in_am_dim, bias=True, initial_scale=query_head_dim**-0.25 + am_embed_dim, in_am_dim, bias=True, initial_scale=query_head_dim**-0.25 ) self.whiten_keys = Whiten( @@ -117,10 +254,10 @@ class RelPositionMultiheadAttentionWeights(nn.Module): ) -> Tensor: r""" Args: - lm_pruned: input of shape (seq_len, batch_size * prune_range, decoder_embed_dim) - am_pruned: input of shape (seq_len, batch_size * prune_range, encoder_embed_dim) - pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim) - key_padding_mask: a bool tensor of shape (batch_size * prune_range, seq_len). Positions + lm_pruned: input of shape (lm_seq_len, batch_size * prune_range, decoder_embed_dim) + am_pruned: input of shape (am_seq_len, batch_size * prune_range, encoder_embed_dim) + pos_emb: Positional embedding tensor, of shape (1, 2*lm_seq_len - 1, pos_dim) + key_padding_mask: a bool tensor of shape (batch_size * prune_range, am_seq_len). Positions that are True in this mask will be ignored as sources in the attention weighting. 
attn_mask: mask of shape (seq_len, seq_len) or (seq_len, batch_size * prune_range, batch_size * prune_range), @@ -137,38 +274,45 @@ class RelPositionMultiheadAttentionWeights(nn.Module): num_heads = self.num_heads ( - seq_len, + lm_seq_len, b_p_dim, _, - ) = lm_pruned.shape # actual dim: (seq_len, batch * prune_range, _) + ) = lm_pruned.shape # actual dim: (lm_seq_len, batch * prune_range, _) + ( + am_seq_len, + _, + _, + ) = am_pruned.shape query_dim = query_head_dim * num_heads # self-attention - q = lm_pruned[..., 0:query_dim] # (seq_len, batch * prune_range, query_dim) - k = am_pruned # (seq_len, batch * prune_range, query_dim) + q = lm_pruned[..., 0:query_dim] # (lm_seq_len, batch * prune_range, query_dim) + k = am_pruned # (am_seq_len, batch * prune_range, query_dim) # p is the position-encoding query p = lm_pruned[ ..., query_dim: - ] # (seq_len, batch * prune_range, pos_head_dim * num_heads) + ] # (lm_seq_len, batch * prune_range, pos_head_dim * num_heads) assert p.shape[-1] == num_heads * pos_head_dim q = self.copy_query(q) # for diagnostics only, does nothing. k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass. p = self.copy_pos_query(p) # for diagnostics only, does nothing. - q = q.reshape(seq_len, b_p_dim, num_heads, query_head_dim) - p = p.reshape(seq_len, b_p_dim, num_heads, pos_head_dim) - k = k.reshape(seq_len, b_p_dim, num_heads, query_head_dim) + q = q.reshape(lm_seq_len, b_p_dim, num_heads, query_head_dim) + p = p.reshape(lm_seq_len, b_p_dim, num_heads, pos_head_dim) + k = k.reshape(am_seq_len, b_p_dim, num_heads, query_head_dim) - # time1 refers to target, time2 refers to source. + # time1 refers to target (query: lm), time2 refers to source (key: am). q = q.permute( 2, 1, 0, 3 - ) # (head, batch * prune_range, seq_len, query_head_dim) - p = p.permute(2, 1, 0, 3) # (head, batch * prune_range, seq_len, pos_head_dim) - k = k.permute(2, 1, 3, 0) # (head, batch * prune_range, d_k, seq_len) + ) # (head, batch * prune_range, lm_seq_len, query_head_dim) + p = p.permute( + 2, 1, 0, 3 + ) # (head, batch * prune_range, lm_seq_len, pos_head_dim) + k = k.permute(2, 1, 3, 0) # (head, batch * prune_range, d_k, am_seq_len) - attn_scores = torch.matmul(q, k) + attn_scores = torch.matmul(q, k) # (head, batch, lm_seq_len, am_seq_len) use_pos_scores = False if torch.jit.is_scripting() or torch.jit.is_tracing(): @@ -179,7 +323,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): if use_pos_scores: pos_emb = self.linear_pos(pos_emb) - seq_len2 = 2 * seq_len - 1 + seq_len2 = 2 * lm_seq_len - 1 pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute( 2, 0, 3, 1 ) @@ -194,24 +338,23 @@ class RelPositionMultiheadAttentionWeights(nn.Module): if torch.jit.is_tracing(): (num_heads, b_p_dim, time1, n) = pos_scores.shape rows = torch.arange(start=time1 - 1, end=-1, step=-1) - cols = torch.arange(seq_len) + cols = torch.arange(lm_seq_len) rows = rows.repeat(b_p_dim * num_heads).unsqueeze(-1) indexes = rows + cols pos_scores = pos_scores.reshape(-1, n) pos_scores = torch.gather(pos_scores, dim=1, index=indexes) - pos_scores = pos_scores.reshape(num_heads, b_p_dim, time1, seq_len) + pos_scores = pos_scores.reshape(num_heads, b_p_dim, time1, lm_seq_len) else: pos_scores = pos_scores.as_strided( - (num_heads, b_p_dim, seq_len, seq_len), + (num_heads, b_p_dim, lm_seq_len, lm_seq_len), ( pos_scores.stride(0), pos_scores.stride(1), pos_scores.stride(2) - pos_scores.stride(3), pos_scores.stride(3), ), - storage_offset=pos_scores.stride(3) * (seq_len - 1), + 
storage_offset=pos_scores.stride(3) * (lm_seq_len - 1), ) - attn_scores = attn_scores + pos_scores if torch.jit.is_scripting() or torch.jit.is_tracing(): @@ -232,8 +375,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): attn_scores = penalize_abs_values_gt( attn_scores, limit=25.0, penalty=1.0e-04, name=self.name ) - - assert attn_scores.shape == (num_heads, b_p_dim, seq_len, seq_len) + assert attn_scores.shape == (num_heads, b_p_dim, lm_seq_len, am_seq_len) if attn_mask is not None: assert attn_mask.dtype == torch.bool @@ -246,7 +388,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): if key_padding_mask is not None: assert key_padding_mask.shape == ( b_p_dim, - seq_len, + am_seq_len, ), key_padding_mask.shape attn_scores = attn_scores.masked_fill( key_padding_mask.unsqueeze(1), @@ -271,7 +413,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): return attn_weights def _print_attn_entropy(self, attn_weights: Tensor): - # attn_weights: (num_heads, batch_size, seq_len, seq_len) + # attn_weights: (num_heads, batch_size, lm_seq_len, am_seq_len) (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape with torch.no_grad(): @@ -307,7 +449,7 @@ class AlignmentAttentionModule(nn.Module): pos_head_dim=pos_head_dim, dropout=dropout, ) - self.cross_attn = SelfAttention( + self.cross_attn = CrossAttention( embed_dim=embed_dim, num_heads=num_heads, value_head_dim=value_head_dim, @@ -317,57 +459,80 @@ class AlignmentAttentionModule(nn.Module): ) def forward(self, am_pruned: Tensor, lm_pruned: Tensor) -> Tensor: - # am_pruned : [B, T, prune_range, encoder_dim] - # lm_pruned : [B, T, prune_range, decoder_dim] - (batch_size, T, prune_range, encoder_dim) = am_pruned.shape - (batch_size, T, prune_range, decoder_dim) = lm_pruned.shape + if len(am_pruned.shape) == 4 and len(lm_pruned.shape) == 4: + # src_key_padding_mask = make_pad_mask(am_pruned_lens) - # am_pruned : [T, B * prune_range, encoder_dim] - # lm_pruned : [T, B * prune_range, decoder_dim] - merged_am_pruned = am_pruned.permute(1, 0, 2, 3).reshape( - T, batch_size * prune_range, encoder_dim - ) - merged_lm_pruned = lm_pruned.permute(1, 0, 2, 3).reshape( - T, batch_size * prune_range, decoder_dim - ) + # am_pruned : [B, am_T, prune_range, encoder_dim] + # lm_pruned : [B, lm_T, prune_range, decoder_dim] + (batch_size, am_T, prune_range, encoder_dim) = am_pruned.shape + (batch_size, lm_T, prune_range, decoder_dim) = lm_pruned.shape - pos_emb = self.pos_encode(merged_am_pruned) + # merged_am_pruned : [am_T, B * prune_range, encoder_dim] + # merged_lm_pruned : [lm_T, B * prune_range, decoder_dim] + merged_am_pruned = am_pruned.permute(1, 0, 2, 3).reshape( + am_T, batch_size * prune_range, encoder_dim + ) + merged_lm_pruned = lm_pruned.permute(1, 0, 2, 3).reshape( + lm_T, batch_size * prune_range, decoder_dim + ) + pos_emb = self.pos_encode(merged_lm_pruned) - attn_weights = self.cross_attn_weights( - merged_lm_pruned, merged_am_pruned, pos_emb - ) - label_level_am_representation = self.cross_attn(merged_am_pruned, attn_weights) - # (T, batch_size * prune_range, encoder_dim) + attn_weights = self.cross_attn_weights( + merged_lm_pruned, merged_am_pruned, pos_emb + ) + # (num_heads, b_p_dim, lm_seq_len, am_seq_len) + # print("attn_weights.shape", attn_weights.shape) + label_level_am_representation = self.cross_attn( + merged_am_pruned, attn_weights + ) + # print( + # "label_level_am_representation.shape", + # label_level_am_representation.shape, + # ) + # (lm_seq_len, batch_size * prune_range, encoder_dim) - return 
label_level_am_representation.reshape( - T, batch_size, prune_range, encoder_dim - ).permute(1, 0, 2, 3) + return label_level_am_representation.reshape( + lm_T, batch_size, prune_range, encoder_dim + ).permute(1, 0, 2, 3) + elif len(am_pruned.shape) == 3 and len(lm_pruned.shape) == 3: + # am_pruned : [am_T, B, encoder_dim] + # lm_pruned : [lm_T, B, decoder_dim] + (am_T, batch_size, encoder_dim) = am_pruned.shape + (lm_T, batch_size, decoder_dim) = lm_pruned.shape + + pos_emb = self.pos_encode(lm_pruned) + + attn_weights = self.cross_attn_weights(lm_pruned, am_pruned, pos_emb) + label_level_am_representation = self.cross_attn(am_pruned, attn_weights) + # (T, batch_size, encoder_dim) + + return label_level_am_representation + else: + raise NotImplementedError("Dim Error") if __name__ == "__main__": - # am_pruned : [B, T, prune_range, encoder_dim] - # lm_pruned : [B, T, prune_range, decoder_dim] - # am_pruned = torch.rand(2, 100, 5, 512).transpose(1, 0).reshape(2 * 5, 100, 512) - # lm_pruned = torch.rand(2, 100, 5, 512).transpose(1, 0).reshape(2 * 5, 100, 512) + attn = AlignmentAttentionModule() - # # am_pruned : [B * prune_range, T, encoder_dim] - # # lm_pruned : [B * prune_range, T, decoder_dim] - - # pos_emb = torch.rand(1, 19, 192) + print("__main__ === for inference ===") + # am : [T, B, encoder_dim] + # lm : [1, B, decoder_dim] + am = torch.rand(100, 2, 512) + lm = torch.rand(1, 2, 512) + # q / K separate seq_len # weights = RelPositionMultiheadAttentionWeights() - # attn = SelfAttention(512, 5, 12) - - # attn_weights = weights(lm_pruned, am_pruned, pos_emb) + # attn = CrossAttention(512, 5, 12) + # attn_weights = weights(lm, am, pos_emb) # print("weights(am_pruned, lm_pruned, pos_emb).shape", attn_weights.shape) - # res = attn(am_pruned, attn_weights) - # print("res", res.shape) + # res = attn(am, attn_weights) + res = attn(am, lm) + print("__main__ res", res.shape) + print("__main__ === for training ===") # am_pruned : [B, T, prune_range, encoder_dim] # lm_pruned : [B, T, prune_range, decoder_dim] am_pruned = torch.rand(2, 100, 5, 512) lm_pruned = torch.rand(2, 100, 5, 512) - - attn = AlignmentAttentionModule() res = attn(am_pruned, lm_pruned) - print("res", res.shape) + print("__main__ res", res.shape) diff --git a/egs/librispeech/ASR/zipformer_label_level_algn/alignment_attention_module_debug.py b/egs/librispeech/ASR/zipformer_label_level_algn/alignment_attention_module_debug.py index 5a9a425cb..41f5f8b95 100644 --- a/egs/librispeech/ASR/zipformer_label_level_algn/alignment_attention_module_debug.py +++ b/egs/librispeech/ASR/zipformer_label_level_algn/alignment_attention_module_debug.py @@ -15,7 +15,142 @@ from scaling import ( softmax, ) from torch import Tensor, nn -from zipformer import CompactRelPositionalEncoding, SelfAttention, _whitening_schedule +from zipformer import CompactRelPositionalEncoding, _whitening_schedule +from icefall.utils import make_pad_mask + + +class CrossAttention(nn.Module): + """ + The simplest possible attention module. This one works with already-computed attention + weights, e.g. as computed by RelPositionMultiheadAttentionWeights. 
+ + Args: + embed_dim: the input and output embedding dimension + num_heads: the number of attention heads + value_head_dim: the value dimension per head + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + value_head_dim: int, + ) -> None: + super().__init__() + self.in_proj = nn.Linear(embed_dim, num_heads * value_head_dim, bias=True) + + self.out_proj = ScaledLinear( + num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05 + ) + + self.whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(7.5, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + def forward( + self, + x: Tensor, + attn_weights: Tensor, + ) -> Tensor: + """ + Args: + x: input tensor, of shape (seq_len, batch_size, embed_dim) + attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), + with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect + attn_weights.sum(dim=-1) == 1. + Returns: + a tensor with the same shape as x. + """ + (am_seq_len, batch_size, embed_dim) = x.shape + (_, _, lm_seq_len, _) = attn_weights.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == (num_heads, batch_size, lm_seq_len, am_seq_len) + + x = self.in_proj(x) # (am_seq_len, batch_size, num_heads * value_head_dim) + # print("projected x.shape", x.shape) + + x = x.reshape(am_seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, am_seq_len, value_head_dim) + # print("permuted x.shape", x.shape) + + value_head_dim = x.shape[-1] + + # todo: see whether there is benefit in overriding matmul + x = torch.matmul(attn_weights, x) + # v: (num_heads, batch_size, lm_seq_len, value_head_dim) + # print("attended x.shape", x.shape) + + x = ( + x.permute(2, 1, 0, 3) + .contiguous() + .view(lm_seq_len, batch_size, num_heads * value_head_dim) + ) + + # returned value is of shape (lm_seq_len, batch_size, embed_dim), like the input. + x = self.out_proj(x) + x = self.whiten(x) + # print("returned x.shape", x.shape) + + return x + + def streaming_forward( + self, + x: Tensor, + attn_weights: Tensor, + cached_val: Tensor, + left_context_len: int, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + x: input tensor, of shape (seq_len, batch_size, embed_dim) + attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), + with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect + attn_weights.sum(dim=-1) == 1. + cached_val: cached attention value tensor of left context, + of shape (left_context_len, batch_size, value_dim) + left_context_len: number of left context frames. + + Returns: + - attention weighted output, a tensor with the same shape as x. + - updated cached attention value tensor of left context. + """ + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + seq_len2 = seq_len + left_context_len + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len2) + + x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) + + # Pad cached left contexts + assert cached_val.shape[0] == left_context_len, ( + cached_val.shape[0], + left_context_len, + ) + x = torch.cat([cached_val, x], dim=0) + # Update cached left contexts + cached_val = x[-left_context_len:, ...] 
+ + x = x.reshape(seq_len2, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, value_head_dim) + value_head_dim = x.shape[-1] + + # todo: see whether there is benefit in overriding matmul + x = torch.matmul(attn_weights, x) + # v: (num_heads, batch_size, seq_len, value_head_dim) + + x = ( + x.permute(2, 1, 0, 3) + .contiguous() + .view(seq_len, batch_size, num_heads * value_head_dim) + ) + + # returned value is of shape (seq_len, batch_size, embed_dim), like the input. + x = self.out_proj(x) + + return x, cached_val class RelPositionMultiheadAttentionWeights(nn.Module): @@ -40,7 +175,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module): def __init__( self, - embed_dim: int = 512, + lm_embed_dim: int = 512, + am_embed_dim: int = 512, pos_dim: int = 192, num_heads: int = 5, query_head_dim: int = 32, @@ -49,7 +185,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module): pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)), ) -> None: super().__init__() - self.embed_dim = embed_dim + self.lm_embed_dim = lm_embed_dim + self.am_embed_dim = am_embed_dim self.num_heads = num_heads self.query_head_dim = query_head_dim self.pos_head_dim = pos_head_dim @@ -67,10 +204,10 @@ class RelPositionMultiheadAttentionWeights(nn.Module): # to be used with the ScaledAdam optimizer; with most other optimizers, # it would be necessary to apply the scaling factor in the forward function. self.in_lm_proj = ScaledLinear( - embed_dim, in_lm_dim, bias=True, initial_scale=query_head_dim**-0.25 + lm_embed_dim, in_lm_dim, bias=True, initial_scale=query_head_dim**-0.25 ) self.in_am_proj = ScaledLinear( - embed_dim, in_am_dim, bias=True, initial_scale=query_head_dim**-0.25 + am_embed_dim, in_am_dim, bias=True, initial_scale=query_head_dim**-0.25 ) self.whiten_keys = Whiten( @@ -117,10 +254,10 @@ class RelPositionMultiheadAttentionWeights(nn.Module): ) -> Tensor: r""" Args: - lm_pruned: input of shape (seq_len, batch_size * prune_range, decoder_embed_dim) - am_pruned: input of shape (seq_len, batch_size * prune_range, encoder_embed_dim) - pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim) - key_padding_mask: a bool tensor of shape (batch_size * prune_range, seq_len). Positions + lm_pruned: input of shape (lm_seq_len, batch_size * prune_range, decoder_embed_dim) + am_pruned: input of shape (am_seq_len, batch_size * prune_range, encoder_embed_dim) + pos_emb: Positional embedding tensor, of shape (1, 2*lm_seq_len - 1, pos_dim) + key_padding_mask: a bool tensor of shape (batch_size * prune_range, am_seq_len). Positions that are True in this mask will be ignored as sources in the attention weighting. 
attn_mask: mask of shape (seq_len, seq_len) or (seq_len, batch_size * prune_range, batch_size * prune_range), @@ -135,52 +272,47 @@ class RelPositionMultiheadAttentionWeights(nn.Module): query_head_dim = self.query_head_dim pos_head_dim = self.pos_head_dim num_heads = self.num_heads - print( - "query_head_dim", - query_head_dim, - "pos_head_dim", - pos_head_dim, - "num_heads", - num_heads, - ) ( - seq_len, + lm_seq_len, b_p_dim, _, - ) = lm_pruned.shape # actual dim: (batch * prune_range, seq_len, _) + ) = lm_pruned.shape # actual dim: (lm_seq_len, batch * prune_range, _) + ( + am_seq_len, + _, + _, + ) = am_pruned.shape query_dim = query_head_dim * num_heads # self-attention - q = lm_pruned[..., 0:query_dim] # (batch * prune_range, seq_len, query_dim) - k = am_pruned # (batch * prune_range, seq_len, query_dim) + q = lm_pruned[..., 0:query_dim] # (lm_seq_len, batch * prune_range, query_dim) + k = am_pruned # (am_seq_len, batch * prune_range, query_dim) # p is the position-encoding query p = lm_pruned[ ..., query_dim: - ] # (batch * prune_range, seq_len, pos_head_dim * num_heads) + ] # (lm_seq_len, batch * prune_range, pos_head_dim * num_heads) assert p.shape[-1] == num_heads * pos_head_dim q = self.copy_query(q) # for diagnostics only, does nothing. k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass. p = self.copy_pos_query(p) # for diagnostics only, does nothing. - q = q.reshape(seq_len, b_p_dim, num_heads, query_head_dim) - p = p.reshape(seq_len, b_p_dim, num_heads, pos_head_dim) - k = k.reshape(seq_len, b_p_dim, num_heads, query_head_dim) - print("q.shape after reshape", q.shape) - print("p.shape after reshape", p.shape) - print("k.shape after reshape", k.shape) + q = q.reshape(lm_seq_len, b_p_dim, num_heads, query_head_dim) + p = p.reshape(lm_seq_len, b_p_dim, num_heads, pos_head_dim) + k = k.reshape(am_seq_len, b_p_dim, num_heads, query_head_dim) - # time1 refers to target, time2 refers to source. + # time1 refers to target (query: lm), time2 refers to source (key: am). q = q.permute( 2, 1, 0, 3 - ) # (head, seq_len, batch * prune_range, query_head_dim) - p = p.permute(2, 1, 0, 3) # (head, seq_len, batch * prune_range, pos_head_dim) - k = k.permute(2, 1, 3, 0) # (head, seq_len, d_k, batch * prune_range) + ) # (head, batch * prune_range, lm_seq_len, query_head_dim) + p = p.permute( + 2, 1, 0, 3 + ) # (head, batch * prune_range, lm_seq_len, pos_head_dim) + k = k.permute(2, 1, 3, 0) # (head, batch * prune_range, d_k, am_seq_len) - attn_scores = torch.matmul(q, k) - print("attn_scores", attn_scores.shape) + attn_scores = torch.matmul(q, k) # (head, batch, lm_seq_len, am_seq_len) use_pos_scores = False if torch.jit.is_scripting() or torch.jit.is_tracing(): @@ -191,14 +323,11 @@ class RelPositionMultiheadAttentionWeights(nn.Module): if use_pos_scores: pos_emb = self.linear_pos(pos_emb) - print("pos_emb before proj", pos_emb.shape) - seq_len2 = 2 * seq_len - 1 + seq_len2 = 2 * lm_seq_len - 1 pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute( 2, 0, 3, 1 ) # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) - print("p", p.shape) - print("pos_emb after proj", pos_emb.shape) # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) # [where seq_len2 represents relative position.] 
@@ -209,24 +338,23 @@ class RelPositionMultiheadAttentionWeights(nn.Module): if torch.jit.is_tracing(): (num_heads, b_p_dim, time1, n) = pos_scores.shape rows = torch.arange(start=time1 - 1, end=-1, step=-1) - cols = torch.arange(seq_len) + cols = torch.arange(lm_seq_len) rows = rows.repeat(b_p_dim * num_heads).unsqueeze(-1) indexes = rows + cols pos_scores = pos_scores.reshape(-1, n) pos_scores = torch.gather(pos_scores, dim=1, index=indexes) - pos_scores = pos_scores.reshape(num_heads, b_p_dim, time1, seq_len) + pos_scores = pos_scores.reshape(num_heads, b_p_dim, time1, lm_seq_len) else: pos_scores = pos_scores.as_strided( - (num_heads, b_p_dim, seq_len, seq_len), + (num_heads, b_p_dim, lm_seq_len, lm_seq_len), ( pos_scores.stride(0), pos_scores.stride(1), pos_scores.stride(2) - pos_scores.stride(3), pos_scores.stride(3), ), - storage_offset=pos_scores.stride(3) * (seq_len - 1), + storage_offset=pos_scores.stride(3) * (lm_seq_len - 1), ) - attn_scores = attn_scores + pos_scores if torch.jit.is_scripting() or torch.jit.is_tracing(): @@ -247,8 +375,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): attn_scores = penalize_abs_values_gt( attn_scores, limit=25.0, penalty=1.0e-04, name=self.name ) - - assert attn_scores.shape == (num_heads, b_p_dim, seq_len, seq_len) + assert attn_scores.shape == (num_heads, b_p_dim, lm_seq_len, am_seq_len) if attn_mask is not None: assert attn_mask.dtype == torch.bool @@ -261,7 +388,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): if key_padding_mask is not None: assert key_padding_mask.shape == ( b_p_dim, - seq_len, + am_seq_len, ), key_padding_mask.shape attn_scores = attn_scores.masked_fill( key_padding_mask.unsqueeze(1), @@ -286,7 +413,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): return attn_weights def _print_attn_entropy(self, attn_weights: Tensor): - # attn_weights: (num_heads, batch_size, seq_len, seq_len) + # attn_weights: (num_heads, batch_size, lm_seq_len, am_seq_len) (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape with torch.no_grad(): @@ -322,7 +449,7 @@ class AlignmentAttentionModule(nn.Module): pos_head_dim=pos_head_dim, dropout=dropout, ) - self.cross_attn = SelfAttention( + self.cross_attn = CrossAttention( embed_dim=embed_dim, num_heads=num_heads, value_head_dim=value_head_dim, @@ -332,52 +459,80 @@ class AlignmentAttentionModule(nn.Module): ) def forward(self, am_pruned: Tensor, lm_pruned: Tensor) -> Tensor: - # am_pruned : [B, T, prune_range, encoder_dim] - # lm_pruned : [B, T, prune_range, decoder_dim] - (batch_size, T, prune_range, encoder_dim) = am_pruned.shape - (batch_size, T, prune_range, decoder_dim) = lm_pruned.shape + if len(am_pruned.shape) == 4 and len(lm_pruned.shape) == 4: + # src_key_padding_mask = make_pad_mask(am_pruned_lens) - # am_pruned : [B * prune_range, T, encoder_dim] - # lm_pruned : [B * prune_range, T, decoder_dim] - am_pruned = am_pruned.permute(1, 0, 2, 3).reshape( - T, batch_size * prune_range, encoder_dim - ) - lm_pruned = lm_pruned.permute(1, 0, 2, 3).reshape( - T, batch_size * prune_range, decoder_dim - ) + # am_pruned : [B, am_T, prune_range, encoder_dim] + # lm_pruned : [B, lm_T, prune_range, decoder_dim] + (batch_size, am_T, prune_range, encoder_dim) = am_pruned.shape + (batch_size, lm_T, prune_range, decoder_dim) = lm_pruned.shape - pos_emb = self.pos_encode(am_pruned) - print("input pos_emb.shape", pos_emb.shape) + # merged_am_pruned : [am_T, B * prune_range, encoder_dim] + # merged_lm_pruned : [lm_T, B * prune_range, decoder_dim] + merged_am_pruned = 
am_pruned.permute(1, 0, 2, 3).reshape( + am_T, batch_size * prune_range, encoder_dim + ) + merged_lm_pruned = lm_pruned.permute(1, 0, 2, 3).reshape( + lm_T, batch_size * prune_range, decoder_dim + ) + pos_emb = self.pos_encode(merged_lm_pruned) - attn_weights = self.cross_attn_weights(lm_pruned, am_pruned, pos_emb) - label_level_am_representation = self.cross_attn(am_pruned, attn_weights) - return label_level_am_representation + attn_weights = self.cross_attn_weights( + merged_lm_pruned, merged_am_pruned, pos_emb + ) + # (num_heads, b_p_dim, lm_seq_len, am_seq_len) + # print("attn_weights.shape", attn_weights.shape) + label_level_am_representation = self.cross_attn( + merged_am_pruned, attn_weights + ) + # print( + # "label_level_am_representation.shape", + # label_level_am_representation.shape, + # ) + # (lm_seq_len, batch_size * prune_range, encoder_dim) + + return label_level_am_representation.reshape( + lm_T, batch_size, prune_range, encoder_dim + ).permute(1, 0, 2, 3) + elif len(am_pruned.shape) == 3 and len(lm_pruned.shape) == 3: + # am_pruned : [am_T, B, encoder_dim] + # lm_pruned : [lm_T, B, decoder_dim] + (am_T, batch_size, encoder_dim) = am_pruned.shape + (lm_T, batch_size, decoder_dim) = lm_pruned.shape + + pos_emb = self.pos_encode(lm_pruned) + + attn_weights = self.cross_attn_weights(lm_pruned, am_pruned, pos_emb) + label_level_am_representation = self.cross_attn(am_pruned, attn_weights) + # (T, batch_size, encoder_dim) + + return label_level_am_representation + else: + raise NotImplementedError("Dim Error") if __name__ == "__main__": - # am_pruned : [B, T, prune_range, encoder_dim] - # lm_pruned : [B, T, prune_range, decoder_dim] - # am_pruned = torch.rand(2, 100, 5, 512).transpose(1, 0).reshape(2 * 5, 100, 512) - # lm_pruned = torch.rand(2, 100, 5, 512).transpose(1, 0).reshape(2 * 5, 100, 512) + attn = AlignmentAttentionModule() - # # am_pruned : [B * prune_range, T, encoder_dim] - # # lm_pruned : [B * prune_range, T, decoder_dim] - - # pos_emb = torch.rand(1, 19, 192) + print("__main__ === for inference ===") + # am : [T, B, encoder_dim] + # lm : [1, B, decoder_dim] + am = torch.rand(100, 2, 512) + lm = torch.rand(1, 2, 512) + # q / K separate seq_len # weights = RelPositionMultiheadAttentionWeights() - # attn = SelfAttention(512, 5, 12) - - # attn_weights = weights(lm_pruned, am_pruned, pos_emb) + # attn = CrossAttention(512, 5, 12) + # attn_weights = weights(lm, am, pos_emb) # print("weights(am_pruned, lm_pruned, pos_emb).shape", attn_weights.shape) - # res = attn(am_pruned, attn_weights) - # print("res", res.shape) + # res = attn(am, attn_weights) + res = attn(am, lm) + print("__main__ res", res.shape) + print("__main__ === for training ===") # am_pruned : [B, T, prune_range, encoder_dim] # lm_pruned : [B, T, prune_range, decoder_dim] am_pruned = torch.rand(2, 100, 5, 512) lm_pruned = torch.rand(2, 100, 5, 512) - - attn = AlignmentAttentionModule() res = attn(am_pruned, lm_pruned) - print("res", res.shape) + print("__main__ res", res.shape) diff --git a/egs/librispeech/ASR/zipformer_label_level_algn/joiner.py b/egs/librispeech/ASR/zipformer_label_level_algn/joiner.py index 1ab7a3662..10d584eb4 100644 --- a/egs/librispeech/ASR/zipformer_label_level_algn/joiner.py +++ b/egs/librispeech/ASR/zipformer_label_level_algn/joiner.py @@ -39,6 +39,7 @@ class Joiner(nn.Module): self, encoder_out: torch.Tensor, decoder_out: torch.Tensor, + apply_attn: bool = True, project_input: bool = True, ) -> torch.Tensor: """ @@ -47,7 +48,9 @@ class Joiner(nn.Module): Output from the encoder. 
Its shape is (N, T, s_range, C). decoder_out: Output from the decoder. Its shape is (N, T, s_range, C). - project_input: + apply_attn: + If true, apply the label-level alignment attention module to encoder_out before joining. + project_input: If true, apply input projections encoder_proj and decoder_proj. If this is false, it is the user's responsibility to do this manually. @@ -59,7 +62,8 @@ decoder_out.shape, ) - encoder_out = self.label_level_am_attention(encoder_out, decoder_out) + if apply_attn: + encoder_out = self.label_level_am_attention(encoder_out, decoder_out) if project_input: logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
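For reviewers, here is a minimal standalone sketch (not part of the patch) of how the two new building blocks added above compose, mirroring the commented-out lines in the module's __main__ block. The sizes and the hand-built pos_emb are illustrative assumptions, and it presumes the snippet is run from the zipformer_label_level_algn recipe directory so that scaling, zipformer and icefall imports resolve.

import torch

from alignment_attention_module import (
    CrossAttention,
    RelPositionMultiheadAttentionWeights,
)

batch, prune_range, seq_len, embed_dim = 2, 5, 100, 512

# lm_pruned / am_pruned: (seq_len, batch * prune_range, embed_dim); equal
# lm/am lengths are used here so the relative-position path stays valid.
lm_pruned = torch.rand(seq_len, batch * prune_range, embed_dim)
am_pruned = torch.rand(seq_len, batch * prune_range, embed_dim)

# Positional embedding of shape (1, 2*lm_seq_len - 1, pos_dim); pos_dim
# defaults to 192 in RelPositionMultiheadAttentionWeights.
pos_emb = torch.rand(1, 2 * seq_len - 1, 192)

# lm_embed_dim and am_embed_dim both default to 512.
weights = RelPositionMultiheadAttentionWeights()
attn_weights = weights(lm_pruned, am_pruned, pos_emb)
# attn_weights: (num_heads, batch * prune_range, lm_seq_len, am_seq_len)

cross_attn = CrossAttention(embed_dim=512, num_heads=5, value_head_dim=12)
out = cross_attn(am_pruned, attn_weights)
# out: (lm_seq_len, batch * prune_range, embed_dim)
print(out.shape)  # torch.Size([100, 10, 512])

The joiner change is then just a gate around this module: passing apply_attn=False to Joiner.forward skips label_level_am_attention and restores the original joiner behaviour.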