From 90cb5183989fc3fea4439df6dbd3deda22c1fd8c Mon Sep 17 00:00:00 2001
From: JinZr <60612200+JinZr@users.noreply.github.com>
Date: Tue, 25 Jul 2023 16:02:48 +0800
Subject: [PATCH] update

---
 .../alignment_attention_module.py            | 178 +++++++++++-------
 .../ASR/zipformer_label_level_algn/joiner.py |   7 +-
 .../ASR/zipformer_label_level_algn/model.py  |   4 +-
 3 files changed, 115 insertions(+), 74 deletions(-)

diff --git a/egs/librispeech/ASR/zipformer_label_level_algn/alignment_attention_module.py b/egs/librispeech/ASR/zipformer_label_level_algn/alignment_attention_module.py
index 41f5f8b95..b93f8a4da 100644
--- a/egs/librispeech/ASR/zipformer_label_level_algn/alignment_attention_module.py
+++ b/egs/librispeech/ASR/zipformer_label_level_algn/alignment_attention_module.py
@@ -4,6 +4,7 @@ import random
 from typing import Optional, Tuple
 
 import torch
+import torch.nn.functional as F
 from scaling import (
     Balancer,
     FloatLike,
@@ -67,7 +68,12 @@ class CrossAttention(nn.Module):
         (am_seq_len, batch_size, embed_dim) = x.shape
         (_, _, lm_seq_len, _) = attn_weights.shape
         num_heads = attn_weights.shape[0]
-        assert attn_weights.shape == (num_heads, batch_size, lm_seq_len, am_seq_len)
+        assert attn_weights.shape == (
+            num_heads,
+            batch_size,
+            lm_seq_len,
+            am_seq_len,
+        ), f"{attn_weights.shape}"
 
         x = self.in_proj(x)  # (am_seq_len, batch_size, num_heads * value_head_dim)
         # print("projected x.shape", x.shape)
@@ -181,6 +187,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         num_heads: int = 5,
         query_head_dim: int = 32,
         pos_head_dim: int = 4,
+        prune_range: int = 5,
         dropout: float = 0.0,
         pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)),
     ) -> None:
@@ -190,6 +197,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         self.num_heads = num_heads
         self.query_head_dim = query_head_dim
         self.pos_head_dim = pos_head_dim
+        self.prune_range = prune_range
         self.dropout = dropout
         self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate)
         self.name = None  # will be overwritten in training code; for diagnostics.
@@ -201,7 +209,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         # the initial_scale is supposed to take over the "scaling" factor of
         # head_dim ** -0.5 that has been used in previous forms of attention,
         # dividing it between the query and key. Note: this module is intended
-        # to be used with the ScaledAdam optimizer; with most other optimizers,
+        # to be used with the ScaledAdam optimizer;
+        # with most other optimizers,
         # it would be necessary to apply the scaling factor in the forward function.
         self.in_lm_proj = ScaledLinear(
             lm_embed_dim, in_lm_dim, bias=True, initial_scale=query_head_dim**-0.25
@@ -294,6 +303,9 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
             ..., query_dim:
         ]  # (lm_seq_len, batch * prune_range, pos_head_dim * num_heads)
         assert p.shape[-1] == num_heads * pos_head_dim
+        # print("q.shape", q.shape)
+        # print("p.shape", p.shape)
+        # print("k.shape", k.shape)
 
         q = self.copy_query(q)  # for diagnostics only, does nothing.
         k = self.whiten_keys(self.balance_keys(k))  # does nothing in the forward pass.
@@ -303,7 +315,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         p = p.reshape(lm_seq_len, b_p_dim, num_heads, pos_head_dim)
         k = k.reshape(am_seq_len, b_p_dim, num_heads, query_head_dim)
 
-        # time1 refers to target (query: lm), time2 refers to source (key: am).
+        # time1 refers to target (query: lm),
+        # time2 refers to source (key: am).
         q = q.permute(
             2, 1, 0, 3
         )  # (head, batch * prune_range, lm_seq_len, query_head_dim)
@@ -314,48 +327,48 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
 
         attn_scores = torch.matmul(q, k)  # (head, batch, lm_seq_len, am_seq_len)
 
-        use_pos_scores = False
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
-            # We can't put random.random() in the same line
-            use_pos_scores = True
-        elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
-            use_pos_scores = True
+        # use_pos_scores = False
+        # if torch.jit.is_scripting() or torch.jit.is_tracing():
+        #     # We can't put random.random() in the same line
+        #     use_pos_scores = True
+        # elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
+        #     use_pos_scores = True
 
-        if use_pos_scores:
-            pos_emb = self.linear_pos(pos_emb)
-            seq_len2 = 2 * lm_seq_len - 1
-            pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
-                2, 0, 3, 1
-            )
-            # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
+        # if use_pos_scores:
+        #     pos_emb = self.linear_pos(pos_emb)
+        #     seq_len2 = 2 * lm_seq_len - 1
+        #     pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
+        #         2, 0, 3, 1
+        #     )
+        #     # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
 
-            # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
-            # [where seq_len2 represents relative position.]
-            pos_scores = torch.matmul(p, pos_emb)
-            # the following .as_strided() expression converts the last axis of pos_scores from relative
-            # to absolute position. I don't know whether I might have got the time-offsets backwards or
-            # not, but let this code define which way round it is supposed to be.
-            if torch.jit.is_tracing():
-                (num_heads, b_p_dim, time1, n) = pos_scores.shape
-                rows = torch.arange(start=time1 - 1, end=-1, step=-1)
-                cols = torch.arange(lm_seq_len)
-                rows = rows.repeat(b_p_dim * num_heads).unsqueeze(-1)
-                indexes = rows + cols
-                pos_scores = pos_scores.reshape(-1, n)
-                pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
-                pos_scores = pos_scores.reshape(num_heads, b_p_dim, time1, lm_seq_len)
-            else:
-                pos_scores = pos_scores.as_strided(
-                    (num_heads, b_p_dim, lm_seq_len, lm_seq_len),
-                    (
-                        pos_scores.stride(0),
-                        pos_scores.stride(1),
-                        pos_scores.stride(2) - pos_scores.stride(3),
-                        pos_scores.stride(3),
-                    ),
-                    storage_offset=pos_scores.stride(3) * (lm_seq_len - 1),
-                )
-            attn_scores = attn_scores + pos_scores
+        #     # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
+        #     # [where seq_len2 represents relative position.]
+        #     pos_scores = torch.matmul(p, pos_emb)
+        #     # the following .as_strided() expression converts the last axis of pos_scores from relative
+        #     # to absolute position. I don't know whether I might have got the time-offsets backwards or
+        #     # not, but let this code define which way round it is supposed to be.
+        #     if torch.jit.is_tracing():
+        #         (num_heads, b_p_dim, time1, n) = pos_scores.shape
+        #         rows = torch.arange(start=time1 - 1, end=-1, step=-1)
+        #         cols = torch.arange(lm_seq_len)
+        #         rows = rows.repeat(b_p_dim * num_heads).unsqueeze(-1)
+        #         indexes = rows + cols
+        #         pos_scores = pos_scores.reshape(-1, n)
+        #         pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
+        #         pos_scores = pos_scores.reshape(num_heads, b_p_dim, time1, lm_seq_len)
+        #     else:
+        #         pos_scores = pos_scores.as_strided(
+        #             (num_heads, b_p_dim, lm_seq_len, lm_seq_len),
+        #             (
+        #                 pos_scores.stride(0),
+        #                 pos_scores.stride(1),
+        #                 pos_scores.stride(2) - pos_scores.stride(3),
+        #                 pos_scores.stride(3),
+        #             ),
+        #             storage_offset=pos_scores.stride(3) * (lm_seq_len - 1),
+        #         )
+        #     attn_scores = attn_scores + pos_scores
 
         if torch.jit.is_scripting() or torch.jit.is_tracing():
             pass
@@ -375,7 +388,12 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         attn_scores = penalize_abs_values_gt(
             attn_scores, limit=25.0, penalty=1.0e-04, name=self.name
         )
-        assert attn_scores.shape == (num_heads, b_p_dim, lm_seq_len, am_seq_len)
+        assert attn_scores.shape == (
+            num_heads,
+            b_p_dim,
+            lm_seq_len,
+            am_seq_len,
+        ), attn_scores.shape
 
         if attn_mask is not None:
             assert attn_mask.dtype == torch.bool
@@ -386,12 +404,20 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
             attn_scores = attn_scores.masked_fill(attn_mask, -1000)
 
         if key_padding_mask is not None:
-            assert key_padding_mask.shape == (
-                b_p_dim,
-                am_seq_len,
-            ), key_padding_mask.shape
+            # (batch, max_len)
+
+            key_padding_mask = (
+                (
+                    key_padding_mask.unsqueeze(0)
+                    .repeat(1, self.prune_range, 1)
+                    .unsqueeze(2)
+                )
+                if key_padding_mask.shape[0] != attn_scores.shape[1]
+                else key_padding_mask.unsqueeze(0).unsqueeze(2)
+            )
+
             attn_scores = attn_scores.masked_fill(
-                key_padding_mask.unsqueeze(1),
+                key_padding_mask,
                 -1000,
             )
 
@@ -438,6 +464,7 @@ class AlignmentAttentionModule(nn.Module):
         query_head_dim: int = 32,
         value_head_dim: int = 12,
         pos_head_dim: int = 4,
+        prune_range: int = 5,
         dropout: float = 0.0,
     ):
         super().__init__()
@@ -458,10 +485,13 @@ class AlignmentAttentionModule(nn.Module):
             embed_dim=pos_dim, dropout_rate=0.15
         )
 
-    def forward(self, am_pruned: Tensor, lm_pruned: Tensor) -> Tensor:
-        if len(am_pruned.shape) == 4 and len(lm_pruned.shape) == 4:
-            # src_key_padding_mask = make_pad_mask(am_pruned_lens)
+    def forward(
+        self, am_pruned: Tensor, lm_pruned: Tensor, lengths: torch.Tensor
+    ) -> Tensor:
+        src_key_padding_mask = make_pad_mask(lengths)
+        # (batch, max_len)
 
+        if am_pruned.ndim == 4 and lm_pruned.ndim == 4:
             # am_pruned : [B, am_T, prune_range, encoder_dim]
             # lm_pruned : [B, lm_T, prune_range, decoder_dim]
             (batch_size, am_T, prune_range, encoder_dim) = am_pruned.shape
@@ -478,35 +508,40 @@ class AlignmentAttentionModule(nn.Module):
             pos_emb = self.pos_encode(merged_lm_pruned)
 
             attn_weights = self.cross_attn_weights(
-                merged_lm_pruned, merged_am_pruned, pos_emb
+                merged_lm_pruned,
+                merged_am_pruned,
+                pos_emb,
+                key_padding_mask=src_key_padding_mask,
             )  # (num_heads, b_p_dim, lm_seq_len, am_seq_len)
-            # print("attn_weights.shape", attn_weights.shape)
             label_level_am_representation = self.cross_attn(
                 merged_am_pruned, attn_weights
             )
-            # print(
-            #     "label_level_am_representation.shape",
-            #     label_level_am_representation.shape,
-            # )
-            # (lm_seq_len, batch_size * prune_range, encoder_dim)
+
             return label_level_am_representation.reshape(
                 lm_T, batch_size, prune_range, encoder_dim
             ).permute(1, 0, 2, 3)
-        elif len(am_pruned.shape) == 3 and len(lm_pruned.shape) == 3:
-            # am_pruned : [am_T, B, encoder_dim]
-            # lm_pruned : [lm_T, B, decoder_dim]
-            (am_T, batch_size, encoder_dim) = am_pruned.shape
-            (lm_T, batch_size, decoder_dim) = lm_pruned.shape
+        # elif len(am_pruned.shape) == 3 and len(lm_pruned.shape) == 3:
+        #     am_pruned = am_pruned.permute(1, 0, 2)
+        #     lm_pruned = lm_pruned.permute(1, 0, 2)
 
-            pos_emb = self.pos_encode(lm_pruned)
+        #     # am_pruned : [am_T, B, encoder_dim]
+        #     # lm_pruned : [lm_T, B, decoder_dim]
+        #     (am_T, batch_size, encoder_dim) = am_pruned.shape
+        #     (lm_T, batch_size, decoder_dim) = lm_pruned.shape
 
-            attn_weights = self.cross_attn_weights(lm_pruned, am_pruned, pos_emb)
-            label_level_am_representation = self.cross_attn(am_pruned, attn_weights)
-            # (T, batch_size, encoder_dim)
+        #     pos_emb = self.pos_encode(lm_pruned)
 
-            return label_level_am_representation
+        #     attn_weights = self.cross_attn_weights(
+        #         lm_pruned,
+        #         am_pruned,
+        #         pos_emb,
+        #         key_padding_mask=src_key_padding_mask,
+        #     )
+        #     label_level_am_representation = self.cross_attn(am_pruned, attn_weights)
+        #     # (T, batch_size, encoder_dim)
+
+        #     return label_level_am_representation
         else:
             raise NotImplementedError("Dim Error")
@@ -526,7 +561,7 @@ if __name__ == "__main__":
     # attn_weights = weights(lm, am, pos_emb)
     # print("weights(am_pruned, lm_pruned, pos_emb).shape", attn_weights.shape)
     # res = attn(am, attn_weights)
-    res = attn(am, lm)
+    res = attn(am, lm, torch.Tensor([70, 80]))
     print("__main__ res", res.shape)
 
     print("__main__ === for training ===")
@@ -534,5 +569,6 @@ if __name__ == "__main__":
     # am_pruned : [B, T, prune_range, encoder_dim]
    # lm_pruned : [B, T, prune_range, decoder_dim]
     am_pruned = torch.rand(2, 100, 5, 512)
     lm_pruned = torch.rand(2, 100, 5, 512)
-    res = attn(am_pruned, lm_pruned)
+    lengths = Tensor([100, 100])
+    res = attn(am_pruned, lm_pruned, lengths)
     print("__main__ res", res.shape)
diff --git a/egs/librispeech/ASR/zipformer_label_level_algn/joiner.py b/egs/librispeech/ASR/zipformer_label_level_algn/joiner.py
index 10d584eb4..040f4f40e 100644
--- a/egs/librispeech/ASR/zipformer_label_level_algn/joiner.py
+++ b/egs/librispeech/ASR/zipformer_label_level_algn/joiner.py
@@ -39,6 +39,7 @@ class Joiner(nn.Module):
         self,
         encoder_out: torch.Tensor,
         decoder_out: torch.Tensor,
+        lengths: torch.Tensor,
         apply_attn: bool = True,
         project_input: bool = True,
     ) -> torch.Tensor:
@@ -62,8 +63,10 @@ class Joiner(nn.Module):
             decoder_out.shape,
         )
 
-        if apply_attn:
-            encoder_out = self.label_level_am_attention(encoder_out, decoder_out)
+        if apply_attn and lengths is not None:
+            encoder_out = self.label_level_am_attention(
+                encoder_out, decoder_out, lengths
+            )
 
         if project_input:
             logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
diff --git a/egs/librispeech/ASR/zipformer_label_level_algn/model.py b/egs/librispeech/ASR/zipformer_label_level_algn/model.py
index a7ce4e495..d4ed9441c 100644
--- a/egs/librispeech/ASR/zipformer_label_level_algn/model.py
+++ b/egs/librispeech/ASR/zipformer_label_level_algn/model.py
@@ -264,7 +264,9 @@ class AsrModel(nn.Module):
 
         # project_input=False since we applied the decoder's input projections
         # prior to do_rnnt_pruning (this is an optimization for speed).
-        logits = self.joiner(am_pruned, lm_pruned, project_input=False)
+        logits = self.joiner(
+            am_pruned, lm_pruned, encoder_out_lens, project_input=False
+        )
 
         with torch.cuda.amp.autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
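--

Note (not part of the patch): below is a minimal, self-contained sketch of the
shape bookkeeping behind the new key_padding_mask branch above, mirroring the
call chain model.py -> Joiner.forward(..., lengths, ...) ->
AlignmentAttentionModule.forward(..., lengths). All sizes are made up for
illustration, and the padding-mask construction only imitates what
make_pad_mask() returns (True at padded frames).

    import torch

    num_heads, batch, prune_range = 5, 2, 5
    lm_seq_len, am_seq_len = 100, 100
    b_p_dim = batch * prune_range  # the merged batch * prune_range axis

    # Stand-in for make_pad_mask(lengths): True at padded frames.
    lengths = torch.tensor([70, 100])
    key_padding_mask = torch.arange(am_seq_len).unsqueeze(0) >= lengths.unsqueeze(1)
    assert key_padding_mask.shape == (batch, am_seq_len)

    attn_scores = torch.rand(num_heads, b_p_dim, lm_seq_len, am_seq_len)

    # The patched expansion: make the (batch, am_seq_len) mask broadcastable
    # over heads, the merged batch * prune_range axis, and lm_seq_len.
    mask = (
        key_padding_mask.unsqueeze(0)  # (1, batch, am_seq_len)
        .repeat(1, prune_range, 1)     # (1, batch * prune_range, am_seq_len)
        .unsqueeze(2)                  # (1, batch * prune_range, 1, am_seq_len)
    )
    masked = attn_scores.masked_fill(mask, -1000)
    assert masked.shape == (num_heads, b_p_dim, lm_seq_len, am_seq_len)

One point worth double-checking in review: .repeat(1, prune_range, 1) tiles
the batch axis in prune-major order ([b0, b1, b0, b1, ...]). If
merged_am_pruned flattens its (batch, prune_range) axes in batch-major order
instead, key_padding_mask.repeat_interleave(prune_range, dim=0) would be the
matching expansion. The merge code is not visible in this diff, so this is a
caution, not a bug report.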