From f483f1e0ef905fa43b4d89c2615ccdefb447a6dc Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Mon, 28 Nov 2022 13:36:16 +0800
Subject: [PATCH] Implement attention weights sharing for successive layers,
 for Zipformer

---
 .../ASR/pruned_transducer_stateless7/train.py |  8 +++
 .../pruned_transducer_stateless7/zipformer.py | 49 +++++++++++++------
 2 files changed, 43 insertions(+), 14 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index 9350465e7..b134c3b07 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -133,6 +133,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
     )
 
+    parser.add_argument(
+        "--attention-share-layers",
+        type=str,
+        default="2",
+        help="Number of layers that share attention weights within each zipformer stack: a single int or comma-separated list.",
+    )
+
     parser.add_argument(
         "--encoder-dim",
         type=str,
@@ -488,6 +495,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         value_head_dim=to_int_tuple(params.value_head_dim),
         pos_dim=params.pos_dim,
         num_heads=to_int_tuple(params.num_heads),
+        attention_share_layers=to_int_tuple(params.attention_share_layers),
         feedforward_dim=to_int_tuple(params.feedforward_dim),
         cnn_module_kernel=to_int_tuple(params.cnn_module_kernel),
         dropout=0.1,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index c771b2895..6c7b73f2b 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -76,6 +76,8 @@ class Zipformer(EncoderInterface):
            attention head
        num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism.
            Must be at least 4.
+       attention_share_layers: (int or Tuple[int]): how many successive layers share
+           the same attention weights.  Must be at least 1.
       feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules
       cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module
@@ -99,6 +101,7 @@ class Zipformer(EncoderInterface):
         pos_head_dim: Union[int, Tuple[int]] = 4,
         value_head_dim: Union[int, Tuple[int]] = 12,
         num_heads: Union[int, Tuple[int]] = 8,
+        attention_share_layers: Union[int, Tuple[int]] = 2,
         feedforward_dim: Union[int, Tuple[int]] = 1536,
         cnn_module_kernel: Union[int, Tuple[int]] = 31,
         pos_dim: int = 192,
@@ -135,6 +138,7 @@ class Zipformer(EncoderInterface):
         value_head_dim = _to_tuple(value_head_dim)
         pos_head_dim = _to_tuple(pos_head_dim)
         num_heads = _to_tuple(num_heads)
+        attention_share_layers = _to_tuple(attention_share_layers)
         feedforward_dim = _to_tuple(feedforward_dim)
         cnn_module_kernel = _to_tuple(cnn_module_kernel)
@@ -179,7 +183,8 @@ class Zipformer(EncoderInterface):
                 pos_dim=pos_dim,
                 dropout=dropout,
                 warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
-                warmup_end=warmup_batches * (i + 2) / (num_encoders + 1)
+                warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
+                attention_share_layers=attention_share_layers[i],
             )
 
             if downsampling_factor[i] != 1:
@@ -442,6 +447,9 @@ class ZipformerEncoderLayer(nn.Module):
                                      prob=(0.025, 0.25),
                                      grad_scale=0.01)
 
+    def remove_attention_weights(self):
+        self.self_attn_weights = None
+
     def get_bypass_scale(self):
         if torch.jit.is_scripting() or not self.training:
             return self.bypass_scale
@@ -456,7 +464,8 @@ class ZipformerEncoderLayer(nn.Module):
         pos_emb: Tensor,
         src_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
-    ) -> Tensor:
+        attn_weights: Optional[Tensor] = None,
+    ) -> Tuple[Tensor, Tensor]:
         """
         Pass the input through the encoder layer.
 
         Args:
             src: the sequence to the encoder layer (required).
             pos_emb: Positional embedding tensor (required).
             src_mask: the mask for the src sequence (optional).
             src_key_padding_mask: the mask for the src keys per batch (optional).
-            batch_split: if not None, this layer will only be applied to
+            attn_weights: possibly attention weights computed by the previous layer,
+               to be used if self.self_attn_weights is None
 
         Shape:
             src: (S, N, E).
             pos_emb: (N, 2*S-1, E)
             src_mask: (S, S).
             src_key_padding_mask: (N, S).
             S is the source sequence length, N is the batch size, E is the feature number
-        """
-        if self.training and random.random() < float(self.layer_skip_rate):
-            # skip the layer
-            return src
 
+        Returns:
+           (x, attn_weights) where x has the same shape as src, and attn_weights are of
+           shape (num_heads, batch_size, seq_len, seq_len).
+        """
         src_orig = src
 
         # dropout rate for non-feedforward submodules
         dynamic_skip_rate = float(self.dynamic_skip_rate) if self.training else 0.0
 
-        # multi-headed self-attention module
-        use_self_attn = (random.random() >= dynamic_skip_rate)
-        if torch.jit.is_scripting() or use_self_attn:
-            # attn_weights: (num_heads, batch_size, seq_len, seq_len)
+        # attn_weights: (num_heads, batch_size, seq_len, seq_len)
+        if self.self_attn_weights is not None:
             attn_weights = self.self_attn_weights(
                 src,
                 pos_emb=pos_emb,
                 attn_mask=src_mask,
                 key_padding_mask=src_key_padding_mask,
             )
+        # else rely on the ones passed in
+        if self.training and random.random() < float(self.layer_skip_rate):
+            # skip the layer
+            return src, attn_weights
+
+        use_self_attn = (random.random() >= dynamic_skip_rate)
+        if use_self_attn:
             first_attn_weights = attn_weights[0:3]
             if random.random() < float(self.const_attention_rate):
                 # Make attention weights constant.  The intention is to
@@ -530,8 +545,9 @@ class ZipformerEncoderLayer(nn.Module):
         delta = src - src_orig
 
         src = src_orig + delta * self.get_bypass_scale()
+        src = self.whiten(src)
 
-        return self.whiten(src)
+        return src, attn_weights
 
 
 class ZipformerEncoder(nn.Module):
@@ -558,6 +574,7 @@ class ZipformerEncoder(nn.Module):
         warmup_end: float,
         initial_layerdrop_rate: float = 0.5,
         final_layerdrop_rate: float = 0.05,
+        attention_share_layers: int = 1,
     ) -> None:
         super().__init__()
         self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.15,
@@ -579,7 +596,8 @@ class ZipformerEncoder(nn.Module):
                                           (cur_end, final_layerdrop_rate),
                                           default=0.0)
             cur_begin = cur_end
-
+            if i % attention_share_layers != 0:
+                self.layers[i].remove_attention_weights()
 
     def forward(
         self,
@@ -614,12 +632,15 @@ class ZipformerEncoder(nn.Module):
 
         output = output * feature_mask
 
+        attn_weights = None
+
         for i, mod in enumerate(self.layers):
-            output = mod(
+            output, attn_weights = mod(
                 output,
                 pos_emb,
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
+                attn_weights=attn_weights,
            )
            output = output * feature_mask
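
Illustration (a minimal sketch, not icefall code): the pattern this patch implements is that only every attention_share_layers-th layer in a stack keeps its self_attn_weights module; the layers in between have it set to None and reuse the attention-weight tensor returned by the previous layer. The ToyAttentionWeights, ToyLayer and ToyStack classes below are hypothetical stand-ins that strip out relative positional encoding, masking, dropout and the bypass/whitening machinery of the real ZipformerEncoderLayer.

import torch
import torch.nn as nn


class ToyAttentionWeights(nn.Module):
    """Computes (num_heads, batch, seq, seq) attention weights from the input."""

    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.in_proj = nn.Linear(embed_dim, 2 * embed_dim)  # query and key projections

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (seq_len, batch, embed_dim)
        seq_len, batch, _ = x.shape
        q, k = self.in_proj(x).chunk(2, dim=-1)
        q = q.reshape(seq_len, batch, self.num_heads, self.head_dim).permute(2, 1, 0, 3)
        k = k.reshape(seq_len, batch, self.num_heads, self.head_dim).permute(2, 1, 0, 3)
        scores = torch.matmul(q, k.transpose(-2, -1)) / self.head_dim**0.5
        return scores.softmax(dim=-1)  # (num_heads, batch, seq_len, seq_len)


class ToyLayer(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.self_attn_weights = ToyAttentionWeights(embed_dim, num_heads)
        self.value_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def remove_attention_weights(self):
        # Mirrors the patch: a layer without its own weight module reuses
        # the weights handed in by the caller.
        self.self_attn_weights = None

    def forward(self, x, attn_weights=None):
        if self.self_attn_weights is not None:
            attn_weights = self.self_attn_weights(x)  # recompute the weights
        # else: reuse the weights computed by an earlier layer
        seq_len, batch, _ = x.shape
        v = self.value_proj(x)
        v = v.reshape(seq_len, batch, self.num_heads, self.head_dim).permute(2, 1, 0, 3)
        attn = torch.matmul(attn_weights, v)               # (heads, batch, seq, head_dim)
        attn = attn.permute(2, 1, 0, 3).reshape(seq_len, batch, -1)
        return x + self.out_proj(attn), attn_weights


class ToyStack(nn.Module):
    def __init__(self, num_layers: int, embed_dim: int, num_heads: int, share_layers: int = 2):
        super().__init__()
        self.layers = nn.ModuleList(
            ToyLayer(embed_dim, num_heads) for _ in range(num_layers)
        )
        # Keep the weight module only on every share_layers-th layer, as
        # ZipformerEncoder does with `i % attention_share_layers != 0`.
        for i, layer in enumerate(self.layers):
            if i % share_layers != 0:
                layer.remove_attention_weights()

    def forward(self, x):
        attn_weights = None
        for layer in self.layers:
            x, attn_weights = layer(x, attn_weights=attn_weights)
        return x


if __name__ == "__main__":
    stack = ToyStack(num_layers=4, embed_dim=32, num_heads=4, share_layers=2)
    x = torch.randn(10, 3, 32)   # (seq_len, batch, embed_dim)
    print(stack(x).shape)        # torch.Size([10, 3, 32])

Under the default --attention-share-layers=2 (a single int or a comma-separated list, one value per encoder stack), every second layer of each Zipformer stack has its weight-computation module dropped from the model, so its projection parameters are not kept and its attention weights are not recomputed; it simply applies the weights passed along from the layer before it.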