diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index b58743944..b6ee6cdd4 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -124,7 +124,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--num-encoder-layers",
         type=str,
-        default="2,4,4,6,4,4",
+        default="1,2,2,3,2,2",
         help="Number of zipformer encoder layers per stack, comma separated.",
     )
 
@@ -151,13 +151,6 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
     )
 
-    parser.add_argument(
-        "--attention-share-layers",
-        type=str,
-        default="2",
-        help="Number of layers that share attention weights within each zipformer stack: a single int or comma-separated list.",
-    )
-
     parser.add_argument(
         "--encoder-dim",
         type=str,
@@ -548,7 +541,6 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         value_head_dim=to_int_tuple(params.value_head_dim),
         pos_dim=params.pos_dim,
         num_heads=to_int_tuple(params.num_heads),
-        attention_share_layers=to_int_tuple(params.attention_share_layers),
         feedforward_dim=to_int_tuple(params.feedforward_dim),
         cnn_module_kernel=to_int_tuple(params.cnn_module_kernel),
         dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
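
Each encoder layer now contains two self-attention modules, two convolution modules, and a third feedforward module (see the zipformer.py changes below), so the per-stack layer counts are roughly halved, presumably to keep total parameter count and compute comparable. The comma-separated defaults are parsed into per-stack tuples by to_int_tuple; a minimal sketch of what that helper does (the real definition lives in train.py and may differ in detail):

    # Hedged sketch: how "--num-encoder-layers 1,2,2,3,2,2" becomes a
    # per-stack tuple; assumes to_int_tuple behaves as in icefall's train.py.
    def to_int_tuple(s: str) -> tuple:
        return tuple(map(int, s.split(",")))

    assert to_int_tuple("1,2,2,3,2,2") == (1, 2, 2, 3, 2, 2)
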
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 866e0a174..a549da5a6 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -80,8 +80,6 @@ class Zipformer2(EncoderInterface):
            attention head
        num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism.
               Must be at least 4.
-       attention_share_layers: (int or Tuple[int]): how many successive layers share
-              the same attention weights.  Must be at least 1.
        feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules
        cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module
@@ -115,7 +113,6 @@ class Zipformer2(EncoderInterface):
        pos_head_dim: Union[int, Tuple[int]] = 4,
        value_head_dim: Union[int, Tuple[int]] = 12,
        num_heads: Union[int, Tuple[int]] = 8,
-       attention_share_layers: Union[int, Tuple[int]] = 2,
        feedforward_dim: Union[int, Tuple[int]] = 1536,
        cnn_module_kernel: Union[int, Tuple[int]] = 31,
        pos_dim: int = 192,
@@ -160,7 +157,6 @@ class Zipformer2(EncoderInterface):
        value_head_dim = _to_tuple(value_head_dim)
        pos_head_dim = _to_tuple(pos_head_dim)
        num_heads = _to_tuple(num_heads)
-       attention_share_layers = _to_tuple(attention_share_layers)
        feedforward_dim = _to_tuple(feedforward_dim)
        self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)
@@ -212,7 +208,6 @@ class Zipformer2(EncoderInterface):
                warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
                warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
                final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
-               attention_share_layers=attention_share_layers[i],
            )

            if downsampling_factor[i] != 1:
@@ -297,49 +292,33 @@
        # largest of the downsampling_factors.  Can be any integer >= 1.
        downsampling_multiple = 4
-       downsampling_factor = [ downsampling_multiple * i for i in self.downsampling_factor ]
+       group_size = max(self.downsampling_factor) * downsampling_multiple

-       max_downsampling_factor = max(downsampling_factor)
+       num_groups = (num_frames0 + group_size - 1) // group_size

-       num_frames_max = (num_frames0 + max_downsampling_factor - 1) // max_downsampling_factor
+       feature_mask_dropout_prob = 0.2

-       # we divide the dropped-out feature dimensions into two equal groups;
-       # the first group is dropped out with probability 0.1, the second
-       # with probability approximately twice that.
-       feature_mask_dropout_prob = 0.125
+       # shape: (num_groups, batch_size, 1)
+       group_mask = (torch.rand(num_groups, batch_size, 1,
+                                device=x.device) >
+                     feature_mask_dropout_prob).to(x.dtype)

-       # frame_mask_max1 shape: (num_frames_max, batch_size, 1)
-       frame_mask_max1 = (torch.rand(num_frames_max, batch_size, 1,
-                                     device=x.device) >
-                          feature_mask_dropout_prob).to(x.dtype)
-
-       # frame_mask_max2 has additional frames masked, about twice the number.
-       frame_mask_max2 = torch.logical_and(frame_mask_max1,
-                                           (torch.rand(num_frames_max, batch_size, 1,
-                                                       device=x.device) >
-                                            feature_mask_dropout_prob).to(x.dtype))
-
-
-       # dim: (num_frames_max, batch_size, 3)
-       frame_mask_max = torch.cat((frame_mask_max1, frame_mask_max2), dim=-1)

        feature_masks = []
        for i in range(num_encoders):
            ds = self.downsampling_factor[i]
-           upsample_factor = (max_downsampling_factor // ds)
+           frames_per_group = (group_size // ds)

-           frame_mask = (frame_mask_max.unsqueeze(1).expand(num_frames_max, upsample_factor,
-                                                            batch_size, 2)
-                         .reshape(num_frames_max * upsample_factor, batch_size, 2))
+           frame_mask = (group_mask.unsqueeze(1).expand(num_groups, frames_per_group,
+                                                        batch_size, 1)
+                         .reshape(num_groups * frames_per_group, batch_size, 1))
            num_frames = (num_frames0 + ds - 1) // ds
            frame_mask = frame_mask[:num_frames]
            channels = self.encoder_dim[i]
            feature_mask = torch.ones(num_frames, batch_size, channels,
                                      dtype=x.dtype, device=x.device)
            u1 = self.encoder_unmasked_dim[i]
-           u2 = u1 + (channels - u1) // 2
-           feature_mask[:, :, u1:u2] *= frame_mask[..., 0:1]
-           feature_mask[:, :, u2:] *= frame_mask[..., 1:2]
+           feature_mask[:, :, u1:] *= frame_mask
            feature_masks.append(feature_mask)

        return feature_masks
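
The rewritten feature masking drops the old two-tier scheme (two masks, the second dropping roughly twice as many frames, applied to two halves of the maskable channel range) in favor of a single mask drawn once per group of group_size top-resolution frames. Because group_size is a multiple of every stack's downsampling factor, the masked time regions line up across all stacks. A standalone sketch of the new behaviour for one stack; the example shapes and values here are illustrative assumptions:

    import torch

    num_frames0, batch_size, channels = 100, 2, 384
    downsampling_factor = [1, 2, 4, 8, 4, 2]   # example per-stack factors
    unmasked_dim = 256                         # channels that are never masked

    group_size = max(downsampling_factor) * 4  # downsampling_multiple = 4
    num_groups = (num_frames0 + group_size - 1) // group_size

    # one Bernoulli draw per (group, sequence): 1.0 = keep, 0.0 = drop
    group_mask = (torch.rand(num_groups, batch_size, 1) > 0.2).to(torch.float32)

    ds = downsampling_factor[2]                # the stack with ds == 4
    frames_per_group = group_size // ds
    frame_mask = (group_mask.unsqueeze(1)
                  .expand(num_groups, frames_per_group, batch_size, 1)
                  .reshape(num_groups * frames_per_group, batch_size, 1))
    num_frames = (num_frames0 + ds - 1) // ds
    frame_mask = frame_mask[:num_frames]

    feature_mask = torch.ones(num_frames, batch_size, channels)
    feature_mask[:, :, unmasked_dim:] *= frame_mask  # dims below unmasked_dim stay on
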
@@ -547,7 +526,8 @@
        attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
        conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
        const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0),
-       ff2_skip_rate: FloatLike = 0.01,
+       ff2_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)),
+       ff3_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)),
        bypass_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.2), default=0),
        bypass_max: FloatLike = 1.0,
    ) -> None:
@@ -565,6 +545,7 @@
        # ff2_skip_rate is to prevent the ff2 module from having output that's too big
        # compared to its residual.
        self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate)
+       self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate)

        # min and max for self.bypass_scale, applied with probability 0.5 to avoid grads
        # ever becoming zero.
@@ -578,7 +559,10 @@
            dropout=0.0,
        )

-       self.self_attn = SelfAttention(embed_dim, num_heads,
+       self.self_attn1 = SelfAttention(embed_dim, num_heads,
+                                       value_head_dim)
+
+       self.self_attn2 = SelfAttention(embed_dim, num_heads,
                                        value_head_dim)

        self.feed_forward1 = FeedforwardModule(embed_dim,
@@ -586,16 +570,24 @@
                                               dropout)

        self.feed_forward2 = FeedforwardModule(embed_dim,
+                                              feedforward_dim,
+                                              dropout)
+
+       self.feed_forward3 = FeedforwardModule(embed_dim,
                                               (feedforward_dim * 5) // 4,
                                               dropout)

        self.nonlin_attention = NonlinAttention(embed_dim,
                                                hidden_channels=3 * embed_dim // 4)

-       self.conv_module = ConvolutionModule(embed_dim,
+       self.conv_module1 = ConvolutionModule(embed_dim,
                                             cnn_module_kernel,
                                             causal=causal)

+       self.conv_module2 = ConvolutionModule(embed_dim,
+                                             cnn_module_kernel,
+                                             causal=causal)
+
        #self.attention_squeeze = AttentionSqueeze(embed_dim, embed_dim // 2)

@@ -617,14 +609,6 @@
            prob=0.05,  # out of concern for memory usage
        )

-       # balancer for output of AttentionSqueezeModule
-       self.balancer_as = Balancer(
-           embed_dim, channel_dim=-1,
-           min_positive=0.3, max_positive=0.7,
-           min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)),
-           prob=0.05,  # out of concern for memory usage
-       )
-
        # balancer for output of feedforward2, prevent it from staying too
        # small.  give this a very small probability, even at the start of
        # training, it's to fix a rare problem and it's OK to fix it slowly.
@@ -636,10 +620,19 @@
            prob=0.05,
        )

+       self.balancer_ff3 = Balancer(
+           embed_dim, channel_dim=-1,
+           min_positive=0.3, max_positive=0.7,
+           min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0),
+           max_abs=4.0,
+           prob=0.05,
+       )
+
        self.whiten = Whiten(num_groups=1,
                             whitening_limit=_whitening_schedule(4.0, ratio=3.0),
                             prob=(0.025, 0.25),
                             grad_scale=0.01)
+
        self.balancer2 = Balancer(
            embed_dim, channel_dim=-1,
            min_positive=0.45, max_positive=0.55,
@@ -647,9 +640,6 @@
        )

-   def remove_attention_weights(self):
-       self.self_attn_weights = None
-
    def get_bypass_scale(self, batch_size: int):
        # returns bypass-scale of shape (num_channels,),
        # or (batch_size, num_channels,).  This is actually the
@@ -696,8 +686,7 @@
        chunk_size: int = -1,
        attn_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
-       attn_weights: Optional[Tensor] = None,
-   ) -> Tuple[Tensor, Tensor]:
+   ) -> Tensor:
        """
        Pass the input through the encoder layer.
        Args:
@@ -713,8 +702,7 @@
           masked position.  May be None.

        Returns:
-          (x, attn_weights) where x has the same shape as src, and attn_weights are of
-          shape (num_heads, batch_size, seq_len, seq_len).
+          A tensor which has the same shape as src
        """
        src_orig = src
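
Note the change at the top of this range: ff2_skip_rate goes from a constant 0.01 to a ScheduledFloat, and the new ff3_skip_rate uses the same schedule, starting at 0.1 and decaying to 0.01 by batch 4000 and to 0.0 by batch 50000. A simplified standalone sketch of the piecewise-linear interpolation such a schedule implies (icefall's real ScheduledFloat tracks the batch count internally, so this function is only an approximation for illustration):

    def scheduled_float(batch: float, points) -> float:
        """points: [(batch_count, value), ...], sorted by batch_count."""
        if batch <= points[0][0]:
            return points[0][1]
        for (x0, y0), (x1, y1) in zip(points, points[1:]):
            if batch <= x1:
                # linear interpolation between the two schedule points
                return y0 + (y1 - y0) * (batch - x0) / (x1 - x0)
        return points[-1][1]

    schedule = [(0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)]
    print(scheduled_float(0, schedule))      # 0.1
    print(scheduled_float(2000, schedule))   # 0.055
    print(scheduled_float(50000, schedule))  # 0.0
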
@@ -722,24 +710,17 @@
        attention_skip_rate = float(self.attention_skip_rate) if self.training else 0.0

        # attn_weights: (num_heads, batch_size, seq_len, seq_len)
-       if self.self_attn_weights is not None:
-           attn_weights = self.self_attn_weights(
-               src,
-               pos_emb=pos_emb,
-               attn_mask=attn_mask,
-               key_padding_mask=src_key_padding_mask,
+       attn_weights = self.self_attn_weights(
+           src,
+           pos_emb=pos_emb,
+           attn_mask=attn_mask,
+           key_padding_mask=src_key_padding_mask,
        )
-       # else rely on the ones passed in
-
-       # use different heads for nonlin_attention and attention_squeeze, depending
-       # whether this module has its on self_attn_weights submodule or is borrowing
-       # attention weights from another one.
-       head_offset = 0 if self.self_attn_weights is not None else 2

        self_attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate)

        if True:
-           selected_attn_weights = attn_weights[head_offset:head_offset+2]
+           selected_attn_weights = attn_weights[0:2]
            if random.random() < float(self.const_attention_rate):
                # Make attention weights constant.  The intention is to
                # encourage these modules to do something similar to an
@@ -753,21 +734,38 @@
        na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights[0:1]))
+
        src = src + (na if self_attn_dropout_mask is None else na * self_attn_dropout_mask)

        src = src + self.feed_forward1(src)

-       self_attn = self.self_attn(
+       self_attn = self.self_attn1(
            src, attn_weights)
+
        src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)

-       src = src + self.sequence_dropout(self.conv_module(src, chunk_size=chunk_size,
-                                                          src_key_padding_mask=src_key_padding_mask),
+       src = src + self.sequence_dropout(self.conv_module1(src, chunk_size=chunk_size,
+                                                           src_key_padding_mask=src_key_padding_mask),
                                          float(self.conv_skip_rate))

        src = src + self.sequence_dropout(self.balancer_ff2(self.feed_forward2(src)),
                                          float(self.ff2_skip_rate))

+       self_attn = self.self_attn2(
+           src, attn_weights)
+
+       src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)
+
+       src = src + self.sequence_dropout(self.conv_module2(src, chunk_size=chunk_size,
+                                                           src_key_padding_mask=src_key_padding_mask),
+                                         float(self.conv_skip_rate))
+
+       src = src + self.sequence_dropout(self.balancer_ff3(self.feed_forward3(src)),
+                                         float(self.ff3_skip_rate))
+
        src = self.balancer1(src)
        src = self.norm(src)
@@ -779,7 +777,7 @@
        src = self.balancer2(src)
        src = self.whiten(src)

-       return src, attn_weights
+       return src

class Zipformer2Encoder(nn.Module):
    r"""Zipformer2Encoder is a stack of N encoder layers
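
With weight sharing across layers removed, each layer computes one set of attention weights and reuses it internally: once for the nonlinear-attention module and once for each of the two SelfAttention modules. The new layer is, roughly, a doubled Zipformer2 layer. A schematic sketch of the residual ordering (dropout masks, balancers and skip rates elided; not the literal implementation):

    # Schematic of the new Zipformer2EncoderLayer.forward ordering.
    def layer_forward(self, src, pos_emb):
        w = self.self_attn_weights(src, pos_emb=pos_emb)  # computed once per layer
        src = src + self.nonlin_attention(src, w[0:1])
        src = src + self.feed_forward1(src)
        src = src + self.self_attn1(src, w)
        src = src + self.conv_module1(src)
        src = src + self.feed_forward2(src)
        src = src + self.self_attn2(src, w)               # reuses the same weights
        src = src + self.conv_module2(src)
        src = src + self.feed_forward3(src)
        return self.norm(src)
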
@@ -805,7 +803,6 @@
        warmup_end: float,
        initial_layerdrop_rate: float = 0.5,
        final_layerdrop_rate: float = 0.05,
-       attention_share_layers: int = 1,
    ) -> None:
        super().__init__()
        self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.15,
@@ -827,8 +824,6 @@
                                                 (cur_end, final_layerdrop_rate),
                                                 default=0.0)
            cur_begin = cur_end
-           if i % attention_share_layers != 0:
-               self.layers[i].remove_attention_weights()

    def forward(
        self,
@@ -860,16 +855,13 @@

        output = output * feature_mask

-       attn_weights = None
-
        for i, mod in enumerate(self.layers):
-           output, attn_weights = mod(
+           output = mod(
                output,
                pos_emb,
                chunk_size=chunk_size,
                attn_mask=attn_mask,
                src_key_padding_mask=src_key_padding_mask,
-               attn_weights=attn_weights,
            )

            output = output * feature_mask
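
Taken together, the layer API becomes self-contained: forward returns a single tensor and nothing is threaded between layers except the hidden state. A hypothetical construction matching the new train.py defaults; the num_encoder_layers keyword is assumed from the option name, and per the option help text a single value is broadcast to all stacks:

    # Hypothetical: building the encoder with the new defaults from train.py
    # (only arguments touched by this patch are shown; others keep defaults).
    encoder = Zipformer2(
        num_encoder_layers=(1, 2, 2, 3, 2, 2),  # was (2, 4, 4, 6, 4, 4)
        num_heads=(8,),                         # broadcast to all stacks
        # attention_share_layers is gone: every layer owns its attention weights
    )
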