From 9971ca61733b20ac00f2559273d5d0ddbe7afc5d Mon Sep 17 00:00:00 2001
From: Yifan Yang
Date: Wed, 14 Jun 2023 15:22:00 +0800
Subject: [PATCH] Add cat

---
 egs/librispeech/ASR/zipformer/train.py     |   2 +-
 egs/librispeech/ASR/zipformer/zipformer.py | 156 +++++++++++++++------
 2 files changed, 111 insertions(+), 47 deletions(-)

diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py
index bec9a3986..0decc96a5 100755
--- a/egs/librispeech/ASR/zipformer/train.py
+++ b/egs/librispeech/ASR/zipformer/train.py
@@ -1185,7 +1185,7 @@ def run(rank, world_size, args):
         valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)
 
-    if not params.print_diagnostics:
+    if 0 and not params.print_diagnostics:
         scan_pessimistic_batches_for_oom(
             model=model,
             train_dl=train_dl,
diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py
index 8d90198fd..0e870e3e7 100644
--- a/egs/librispeech/ASR/zipformer/zipformer.py
+++ b/egs/librispeech/ASR/zipformer/zipformer.py
@@ -544,13 +544,15 @@ class Zipformer2EncoderLayer(nn.Module):
         bypass_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.02), default=0),
     ) -> None:
         super(Zipformer2EncoderLayer, self).__init__()
+        embed_dim = embed_dim >> 1
         self.embed_dim = embed_dim
 
         # self.bypass implements layer skipping as well as bypass; see its default values.
-        self.bypass = BypassModule(embed_dim, skip_rate=bypass_skip_rate,
+        self.bypass = BypassModule(embed_dim * 2, skip_rate=bypass_skip_rate,
                                    straight_through_rate=0)
         # bypass_mid is bypass used in the middle of the layer.
-        self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0)
+        self.bypass_mid_a = BypassModule(embed_dim, straight_through_rate=0)
+        self.bypass_mid_b = BypassModule(embed_dim, straight_through_rate=0)
 
         # skip probability for dynamic modules (meaning: anything but feedforward).
         self.attention_skip_rate = copy.deepcopy(attention_skip_rate)
@@ -565,48 +567,71 @@ class Zipformer2EncoderLayer(nn.Module):
 
         self.const_attention_rate = copy.deepcopy(const_attention_rate)
 
-        self.self_attn_weights = RelPositionMultiheadAttentionWeights(
+        self.cross_attn_weights_a = RelPositionMultiheadAttentionWeights(
+            embed_dim, pos_dim=pos_dim, num_heads=num_heads,
+            query_head_dim=query_head_dim, pos_head_dim=pos_head_dim,
+            dropout=0.0,
+        )
+        self.cross_attn_weights_b = RelPositionMultiheadAttentionWeights(
             embed_dim, pos_dim=pos_dim, num_heads=num_heads,
             query_head_dim=query_head_dim, pos_head_dim=pos_head_dim,
             dropout=0.0,
         )
 
-        self.self_attn1 = SelfAttention(embed_dim, num_heads,
+        self.cross_attn1_a = CrossAttention(embed_dim, num_heads,
+                                        value_head_dim)
+        self.cross_attn1_b = CrossAttention(embed_dim, num_heads,
                                         value_head_dim)
 
-        self.self_attn2 = SelfAttention(embed_dim, num_heads,
+        self.cross_attn2_a = CrossAttention(embed_dim, num_heads,
+                                        value_head_dim)
+        self.cross_attn2_b = CrossAttention(embed_dim, num_heads,
                                         value_head_dim)
 
-        self.feed_forward1 = FeedforwardModule(embed_dim,
+        self.feed_forward1_a = FeedforwardModule(embed_dim,
+                                               (feedforward_dim * 3) // 4,
+                                               dropout)
+        self.feed_forward1_b = FeedforwardModule(embed_dim,
                                                (feedforward_dim * 3) // 4,
                                                dropout)
 
-        self.feed_forward2 = FeedforwardModule(embed_dim,
+        self.feed_forward2_a = FeedforwardModule(embed_dim,
+                                               feedforward_dim,
+                                               dropout)
+        self.feed_forward2_b = FeedforwardModule(embed_dim,
                                                feedforward_dim,
                                                dropout)
 
-        self.feed_forward3 = FeedforwardModule(embed_dim,
+        self.feed_forward3_a = FeedforwardModule(embed_dim,
+                                               (feedforward_dim * 5) // 4,
+                                               dropout)
+        self.feed_forward3_b = FeedforwardModule(embed_dim,
                                                (feedforward_dim * 5) // 4,
                                                dropout)
 
-        self.nonlin_attention = NonlinAttention(embed_dim,
+        self.nonlin_attention_a = NonlinAttention(embed_dim,
+                                                hidden_channels=3 * embed_dim // 4)
+        self.nonlin_attention_b = NonlinAttention(embed_dim,
                                                 hidden_channels=3 * embed_dim // 4)
 
-        self.conv_module1 = ConvolutionModule(embed_dim,
+        self.conv_module1_a = ConvolutionModule(embed_dim,
+                                              cnn_module_kernel,
+                                              causal=causal)
+        self.conv_module1_b = ConvolutionModule(embed_dim,
                                               cnn_module_kernel,
                                               causal=causal)
 
-        self.conv_module2 = ConvolutionModule(embed_dim,
+        self.conv_module2_a = ConvolutionModule(embed_dim,
+                                              cnn_module_kernel,
+                                              causal=causal)
+        self.conv_module2_b = ConvolutionModule(embed_dim,
                                               cnn_module_kernel,
                                               causal=causal)
 
-        # TODO: remove it
-        self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
-
-        self.norm = BiasNorm(embed_dim)
+        self.norm = BiasNorm(embed_dim * 2)
 
         self.balancer1 = Balancer(
-            embed_dim, channel_dim=-1,
+            embed_dim * 2, channel_dim=-1,
             min_positive=0.45, max_positive=0.55,
             min_abs=0.2, max_abs=4.0,
         )
@@ -644,7 +669,7 @@ class Zipformer2EncoderLayer(nn.Module):
                             grad_scale=0.01)
 
         self.balancer2 = Balancer(
-            embed_dim, channel_dim=-1,
+            embed_dim * 2, channel_dim=-1,
             min_positive=0.45, max_positive=0.55,
             min_abs=0.1, max_abs=4.0,
         )
@@ -692,7 +717,8 @@ class Zipformer2EncoderLayer(nn.Module):
         Returns:
            A tensor which has the same shape as src
         """
-        src_orig = src
+        src_a, src_b = torch.split(src, self.embed_dim, 2)
+        src_orig_a, src_orig_b = src_a, src_b
 
         # dropout rate for non-feedforward submodules
         if torch.jit.is_scripting():
@@ -701,18 +727,28 @@ class Zipformer2EncoderLayer(nn.Module):
             attention_skip_rate = float(self.attention_skip_rate) if self.training else 0.0
 
         # attn_weights: (num_heads, batch_size, seq_len, seq_len)
-        attn_weights = self.self_attn_weights(
-            src,
+        attn_weights_a = self.cross_attn_weights_a(
+            src_a,
+            src_b,
+            pos_emb=pos_emb,
+            attn_mask=attn_mask,
+            key_padding_mask=src_key_padding_mask,
+        )
+        attn_weights_b = self.cross_attn_weights_b(
+            src_b,
+            src_a,
             pos_emb=pos_emb,
             attn_mask=attn_mask,
             key_padding_mask=src_key_padding_mask,
         )
 
-        src = src + self.feed_forward1(src)
+        src_a = src_a + self.feed_forward1_a(src_a)
+        src_b = src_b + self.feed_forward1_b(src_b)
 
-        self_attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate)
+        cross_attn_dropout_mask = self.get_sequence_dropout_mask(src_a, attention_skip_rate)
 
-        selected_attn_weights = attn_weights[0:1]
+        selected_attn_weights_a = attn_weights_a[0:1]
+        selected_attn_weights_b = attn_weights_b[0:1]
         if torch.jit.is_scripting():
             pass
         elif not self.training and random.random() < float(self.const_attention_rate):
@@ -720,23 +756,33 @@ class Zipformer2EncoderLayer(nn.Module):
             # encourage these modules to do something similar to an
            # averaging-over-time operation.
             # only need the mask, can just use the 1st one and expand later
-            selected_attn_weights = selected_attn_weights[0:1]
-            selected_attn_weights = (selected_attn_weights > 0.0).to(selected_attn_weights.dtype)
-            selected_attn_weights = selected_attn_weights * (1.0 / selected_attn_weights.sum(dim=-1, keepdim=True))
+            selected_attn_weights_a = selected_attn_weights_a[0:1]
+            selected_attn_weights_b = selected_attn_weights_b[0:1]
+            selected_attn_weights_a = (selected_attn_weights_a > 0.0).to(selected_attn_weights_a.dtype)
+            selected_attn_weights_b = (selected_attn_weights_b > 0.0).to(selected_attn_weights_b.dtype)
+            selected_attn_weights_a = selected_attn_weights_a * (1.0 / selected_attn_weights_a.sum(dim=-1, keepdim=True))
+            selected_attn_weights_b = selected_attn_weights_b * (1.0 / selected_attn_weights_b.sum(dim=-1, keepdim=True))
 
-        na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights))
+        na_a = self.balancer_na(self.nonlin_attention_a(src_a, selected_attn_weights_a))
+        na_b = self.balancer_na(self.nonlin_attention_b(src_b, selected_attn_weights_b))
 
-        src = src + (na if self_attn_dropout_mask is None else na * self_attn_dropout_mask)
+        src_a = src_a + (na_a if cross_attn_dropout_mask is None else na_a * cross_attn_dropout_mask)
+        src_b = src_b + (na_b if cross_attn_dropout_mask is None else na_b * cross_attn_dropout_mask)
 
-        self_attn = self.self_attn1(src, attn_weights)
+        cross_attn_a = self.cross_attn1_a(src_b, attn_weights_a)
+        cross_attn_b = self.cross_attn1_b(src_a, attn_weights_b)
 
-        src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)
+        src_a = src_a + (cross_attn_a if cross_attn_dropout_mask is None else cross_attn_a * cross_attn_dropout_mask)
+        src_b = src_b + (cross_attn_b if cross_attn_dropout_mask is None else cross_attn_b * cross_attn_dropout_mask)
 
         if torch.jit.is_scripting():
             conv_skip_rate = 0.0
         else:
             conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
-        src = src + self.sequence_dropout(self.conv_module1(src, chunk_size=chunk_size,
+        src_a = src_a + self.sequence_dropout(self.conv_module1_a(src_a, chunk_size=chunk_size,
+                                                            src_key_padding_mask=src_key_padding_mask),
+                                            conv_skip_rate)
+        src_b = src_b + self.sequence_dropout(self.conv_module1_b(src_b, chunk_size=chunk_size,
                                                             src_key_padding_mask=src_key_padding_mask),
                                           conv_skip_rate)
 
@@ -744,21 +790,29 @@ class Zipformer2EncoderLayer(nn.Module):
             ff2_skip_rate = 0.0
         else:
             ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0
-        src = src + self.sequence_dropout(self.balancer_ff2(self.feed_forward2(src)),
+        src_a = src_a + self.sequence_dropout(self.balancer_ff2(self.feed_forward2_a(src_a)),
+                                          ff2_skip_rate)
+        src_b = src_b + self.sequence_dropout(self.balancer_ff2(self.feed_forward2_b(src_b)),
                                           ff2_skip_rate)
 
         # bypass in the middle of the layer.
-        src = self.bypass_mid(src_orig, src)
+        src_a = self.bypass_mid_a(src_orig_a, src_a)
+        src_b = self.bypass_mid_b(src_orig_b, src_b)
 
-        self_attn = self.self_attn2(src, attn_weights)
+        cross_attn_a = self.cross_attn2_a(src_b, attn_weights_a)
+        cross_attn_b = self.cross_attn2_b(src_a, attn_weights_b)
 
-        src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)
+        src_a = src_a + (cross_attn_a if cross_attn_dropout_mask is None else cross_attn_a * cross_attn_dropout_mask)
+        src_b = src_b + (cross_attn_b if cross_attn_dropout_mask is None else cross_attn_b * cross_attn_dropout_mask)
 
         if torch.jit.is_scripting():
             conv_skip_rate = 0.0
         else:
             conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
-        src = src + self.sequence_dropout(self.conv_module2(src, chunk_size=chunk_size,
+        src_a = src_a + self.sequence_dropout(self.conv_module2_a(src_a, chunk_size=chunk_size,
+                                                            src_key_padding_mask=src_key_padding_mask),
+                                            conv_skip_rate)
+        src_b = src_b + self.sequence_dropout(self.conv_module2_b(src_b, chunk_size=chunk_size,
                                                             src_key_padding_mask=src_key_padding_mask),
                                           conv_skip_rate)
 
@@ -766,12 +820,17 @@ class Zipformer2EncoderLayer(nn.Module):
             ff3_skip_rate = 0.0
         else:
             ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0
-        src = src + self.sequence_dropout(self.balancer_ff3(self.feed_forward3(src)),
+        src_a = src_a + self.sequence_dropout(self.balancer_ff3(self.feed_forward3_a(src_a)),
                                           ff3_skip_rate)
+        src_b = src_b + self.sequence_dropout(self.balancer_ff3(self.feed_forward3_b(src_b)),
+                                          ff3_skip_rate)
+
+        src = torch.cat([src_a, src_b], 2)
 
         src = self.balancer1(src)
         src = self.norm(src)
 
+        src_orig = torch.cat([src_orig_a, src_orig_b], 2)
         src = self.bypass(src_orig, src)
 
         src = self.balancer2(src)
@@ -827,7 +886,7 @@ class Zipformer2EncoderLayer(nn.Module):
         src_orig = src
 
         # attn_weights: (num_heads, batch_size, seq_len, seq_len)
-        attn_weights, cached_key = self.self_attn_weights.streaming_forward(
+        attn_weights, cached_key = self.cross_attn_weights.streaming_forward(
            src,
            pos_emb=pos_emb,
            cached_key=cached_key,
@@ -845,13 +904,13 @@ class Zipformer2EncoderLayer(nn.Module):
         )
         src = src + na
 
-        self_attn, cached_val1 = self.self_attn1.streaming_forward(
+        cross_attn, cached_val1 = self.cross_attn1.streaming_forward(
            src,
            attn_weights=attn_weights,
            cached_val=cached_val1,
            left_context_len=left_context_len,
         )
-        src = src + self_attn
+        src = src + cross_attn
 
         src_conv, cached_conv1 = self.conv_module1.streaming_forward(
            src,
@@ -865,13 +924,13 @@ class Zipformer2EncoderLayer(nn.Module):
         # bypass in the middle of the layer.
         src = self.bypass_mid(src_orig, src)
 
-        self_attn, cached_val2 = self.self_attn2.streaming_forward(
+        cross_attn, cached_val2 = self.cross_attn2.streaming_forward(
            src,
            attn_weights=attn_weights,
            cached_val=cached_val2,
            left_context_len=left_context_len,
         )
-        src = src + self_attn
+        src = src + cross_attn
 
         src_conv, cached_conv2 = self.conv_module2.streaming_forward(
            src,
@@ -1438,7 +1497,9 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         # dividing it between the query and key.   Note: this module is intended
         # to be used with the ScaledAdam optimizer; with most other optimizers,
         # it would be necessary to apply the scaling factor in the forward function.
-        self.in_proj = ScaledLinear(embed_dim, in_proj_dim, bias=True,
+        self.in_proj_a = ScaledLinear(embed_dim, in_proj_dim, bias=True,
+                                    initial_scale=query_head_dim**-0.25)
+        self.in_proj_b = ScaledLinear(embed_dim, in_proj_dim, bias=True,
                                     initial_scale=query_head_dim**-0.25)
 
         self.whiten_keys = Whiten(num_groups=num_heads,
@@ -1475,6 +1536,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
     def forward(
         self,
         x: Tensor,
+        y: Tensor,
         pos_emb: Tensor,
         key_padding_mask: Optional[Tensor] = None,
         attn_mask: Optional[Tensor] = None,
@@ -1482,6 +1544,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         r"""
         Args:
             x: input of shape (seq_len, batch_size, embed_dim)
+            y: input of shape (seq_len, batch_size, embed_dim)
             pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim)
             key_padding_mask: a bool tensor of shape (batch_size, seq_len).  Positions that
               are True in this mask will be ignored as sources in the attention weighting.
@@ -1492,7 +1555,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
           a tensor of attention weights, of shape (hum_heads, batch_size, seq_len, seq_len)
           interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len).
         """
-        x = self.in_proj(x)
+        x = self.in_proj_a(x)
+        y = self.in_proj_b(y)
         query_head_dim = self.query_head_dim
         pos_head_dim = self.pos_head_dim
         num_heads = self.num_heads
@@ -1503,7 +1567,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
 
         # self-attention
         q = x[...,0:query_dim]
-        k = x[...,query_dim:2*query_dim]
+        k = y[...,query_dim:2*query_dim]
         # p is the position-encoding query
         p = x[...,2*query_dim:]
         assert p.shape[-1] == num_heads * pos_head_dim
@@ -1711,7 +1775,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
                logging.info(f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}")
 
 
-class SelfAttention(nn.Module):
+class CrossAttention(nn.Module):
     """
     The simplest possible attention module.  This one works with already-computed
     attention weights, e.g. as computed by RelPositionMultiheadAttentionWeights.
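In effect, the patch turns each Zipformer2EncoderLayer into two half-width parallel branches: the input is split channel-wise into halves a and b, each branch runs its own feed-forward, convolution and nonlinear-attention modules while its attention queries one half and keys/values the other, and the halves are concatenated back with torch.cat before the final bypass and norm. The standalone sketch below (not part of the patch) only illustrates that split / cross-attend / concat dataflow; ToyCatLayer and its use of nn.MultiheadAttention are hypothetical stand-ins for the patch's RelPositionMultiheadAttentionWeights and CrossAttention modules, with the balancers, bypasses and skip rates omitted.

import torch
import torch.nn as nn


class ToyCatLayer(nn.Module):
    """Toy two-branch layer: split channels, cross-attend between halves, concat.

    nn.MultiheadAttention is only a stand-in here; the real patch uses Zipformer's
    relative-positional attention-weight module plus CrossAttention.
    """

    def __init__(self, embed_dim: int, num_heads: int = 4) -> None:
        super().__init__()
        assert embed_dim % 2 == 0
        half = embed_dim >> 1  # same halving as `embed_dim = embed_dim >> 1` in the patch
        self.half = half
        # branch a attends to branch b, and vice versa
        self.attn_a = nn.MultiheadAttention(half, num_heads)
        self.attn_b = nn.MultiheadAttention(half, num_heads)
        self.ff_a = nn.Sequential(nn.Linear(half, 4 * half), nn.ReLU(), nn.Linear(4 * half, half))
        self.ff_b = nn.Sequential(nn.Linear(half, 4 * half), nn.ReLU(), nn.Linear(4 * half, half))

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        # src: (seq_len, batch_size, embed_dim); split the channel dim into two halves,
        # mirroring `src_a, src_b = torch.split(src, self.embed_dim, 2)`
        src_a, src_b = torch.split(src, self.half, dim=2)
        # queries come from one half, keys/values from the other half
        attn_a, _ = self.attn_a(src_a, src_b, src_b)
        attn_b, _ = self.attn_b(src_b, src_a, src_a)
        src_a = src_a + attn_a + self.ff_a(src_a)
        src_b = src_b + attn_b + self.ff_b(src_b)
        # re-join the halves, mirroring `src = torch.cat([src_a, src_b], 2)`
        return torch.cat([src_a, src_b], dim=2)


if __name__ == "__main__":
    layer = ToyCatLayer(embed_dim=256)
    x = torch.randn(50, 8, 256)  # (seq_len, batch_size, embed_dim)
    assert layer(x).shape == x.shape

Note that the output keeps the full embed_dim, so the layer is a drop-in replacement shape-wise; in the patch this is why only the final norm, balancers and outer bypass are built with embed_dim * 2 while every per-branch module uses the halved dimension.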