Implement attention weights sharing for successive layers, for Zipformer

Daniel Povey 2022-11-28 13:36:16 +08:00
parent 121f7e2a45
commit f483f1e0ef
2 changed files with 43 additions and 14 deletions

View File

@@ -133,6 +133,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
     )
 
+    parser.add_argument(
+        "--attention-share-layers",
+        type=str,
+        default="2",
+        help="Number of layers that share attention weights within each zipformer stack: a single int or comma-separated list.",
+    )
+
     parser.add_argument(
         "--encoder-dim",
         type=str,
@@ -488,6 +495,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         value_head_dim=to_int_tuple(params.value_head_dim),
         pos_dim=params.pos_dim,
         num_heads=to_int_tuple(params.num_heads),
+        attention_share_layers=to_int_tuple(params.attention_share_layers),
         feedforward_dim=to_int_tuple(params.feedforward_dim),
         cnn_module_kernel=to_int_tuple(params.cnn_module_kernel),
         dropout=0.1,
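The new option is string-valued so that it can carry either a single value or one value per encoder stack. As a point of reference, the to_int_tuple helper used above presumably just splits on commas; a minimal sketch, not part of this commit, and the actual helper in the recipe may differ in details:

```python
# Sketch of the assumed to_int_tuple conversion for --attention-share-layers.
def to_int_tuple(s: str) -> tuple:
    return tuple(map(int, s.split(",")))

print(to_int_tuple("2"))          # (2,)  a single value, later broadcast to every stack
print(to_int_tuple("1,2,2,2,2"))  # (1, 2, 2, 2, 2)  one value per encoder stack
```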

View File

@@ -76,6 +76,8 @@ class Zipformer(EncoderInterface):
            attention head
         num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism.
               Must be at least 4.
+        attention_share_layers: (int or Tuple[int]): how many successive layers share
+              the same attention weights.  Must be at least 1.
         feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules
         cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module
@@ -99,6 +101,7 @@ class Zipformer(EncoderInterface):
         pos_head_dim: Union[int, Tuple[int]] = 4,
         value_head_dim: Union[int, Tuple[int]] = 12,
         num_heads: Union[int, Tuple[int]] = 8,
+        attention_share_layers: Union[int, Tuple[int]] = 2,
         feedforward_dim: Union[int, Tuple[int]] = 1536,
         cnn_module_kernel: Union[int, Tuple[int]] = 31,
         pos_dim: int = 192,
@@ -135,6 +138,7 @@ class Zipformer(EncoderInterface):
         value_head_dim = _to_tuple(value_head_dim)
         pos_head_dim = _to_tuple(pos_head_dim)
         num_heads = _to_tuple(num_heads)
+        attention_share_layers = _to_tuple(attention_share_layers)
         feedforward_dim = _to_tuple(feedforward_dim)
         cnn_module_kernel = _to_tuple(cnn_module_kernel)
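Inside Zipformer.__init__, _to_tuple then normalizes each int-or-tuple argument, including the new attention_share_layers, to one value per encoder stack. A rough sketch of that normalization, assuming num_encoders stacks; the real helper is a closure inside __init__ and may differ in details:

```python
num_encoders = 5  # illustrative stack count, not a value from this commit

def _to_tuple(x):
    # Broadcast a scalar or 1-tuple to one value per stack; otherwise require
    # exactly one value per stack.  Sketch only.
    if isinstance(x, int):
        x = (x,)
    if len(x) == 1:
        x = x * num_encoders
    else:
        assert len(x) == num_encoders
    return x

print(_to_tuple(2))                # (2, 2, 2, 2, 2)
print(_to_tuple((1, 2, 2, 2, 2)))  # unchanged
```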
@@ -179,7 +183,8 @@ class Zipformer(EncoderInterface):
                 pos_dim=pos_dim,
                 dropout=dropout,
                 warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
-                warmup_end=warmup_batches * (i + 2) / (num_encoders + 1)
+                warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
+                attention_share_layers=attention_share_layers[i],
             )
 
             if downsampling_factor[i] != 1:
@@ -442,6 +447,9 @@ class ZipformerEncoderLayer(nn.Module):
                                    prob=(0.025, 0.25),
                                    grad_scale=0.01)
 
+    def remove_attention_weights(self):
+        self.self_attn_weights = None
+
     def get_bypass_scale(self):
         if torch.jit.is_scripting() or not self.training:
             return self.bypass_scale
@@ -456,7 +464,8 @@ class ZipformerEncoderLayer(nn.Module):
         pos_emb: Tensor,
         src_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
-    ) -> Tensor:
+        attn_weights: Optional[Tensor] = None,
+    ) -> Tuple[Tensor, Tensor]:
         """
         Pass the input through the encoder layer.
@@ -465,7 +474,8 @@ class ZipformerEncoderLayer(nn.Module):
             pos_emb: Positional embedding tensor (required).
             src_mask: the mask for the src sequence (optional).
             src_key_padding_mask: the mask for the src keys per batch (optional).
-            batch_split: if not None, this layer will only be applied to
+            attn_weights: possibly attention weights computed by the previous layer,
+               to be used if self.self_attn_weights is None
 
         Shape:
             src: (S, N, E).
@@ -473,27 +483,32 @@ class ZipformerEncoderLayer(nn.Module):
             src_mask: (S, S).
             src_key_padding_mask: (N, S).
         S is the source sequence length, N is the batch size, E is the feature number
-        """
-        if self.training and random.random() < float(self.layer_skip_rate):
-            # skip the layer
-            return src
+
+        Returns:
+           (x, attn_weights) where x has the same shape as src, and attn_weights are of
+           shape (num_heads, batch_size, seq_len, seq_len).
+        """
         src_orig = src
 
         # dropout rate for non-feedforward submodules
         dynamic_skip_rate = float(self.dynamic_skip_rate) if self.training else 0.0
 
-        # multi-headed self-attention module
-        use_self_attn = (random.random() >= dynamic_skip_rate)
-        if torch.jit.is_scripting() or use_self_attn:
-            # attn_weights: (num_heads, batch_size, seq_len, seq_len)
+        # attn_weights: (num_heads, batch_size, seq_len, seq_len)
+        if self.self_attn_weights is not None:
             attn_weights = self.self_attn_weights(
                 src,
                 pos_emb=pos_emb,
                 attn_mask=src_mask,
                 key_padding_mask=src_key_padding_mask,
             )
+        # else rely on the ones passed in
+
+        if self.training and random.random() < float(self.layer_skip_rate):
+            # skip the layer
+            return src, attn_weights
+
+        use_self_attn = (random.random() >= dynamic_skip_rate)
+        if use_self_attn:
             first_attn_weights = attn_weights[0:3]
             if random.random() < float(self.const_attention_rate):
                 # Make attention weights constant. The intention is to
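To make the rewritten forward() contract concrete: a layer that still owns self_attn_weights computes fresh attention weights, a layer whose module was removed reuses the ones passed in, and either way the weights are returned for the next layer. A self-contained toy with simplified attention, not the real Zipformer modules or shapes:

```python
import torch

def toy_layer(x, attn_weights=None, has_attn_module=True):
    # Stand-in for the layer forward() after this commit.
    if has_attn_module:
        # This layer owns an attention module: compute fresh weights.
        attn_weights = torch.softmax(x @ x.transpose(-2, -1), dim=-1)
    # else: reuse the weights handed in from the previous layer.
    x = x + attn_weights @ x  # apply (fresh or shared) attention with a residual
    return x, attn_weights

x = torch.randn(6, 8)                                        # (seq_len, feature_dim)
x, w = toy_layer(x, has_attn_module=True)                    # computes weights
x, w = toy_layer(x, attn_weights=w, has_attn_module=False)   # reuses them
```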
@@ -530,8 +545,9 @@ class ZipformerEncoderLayer(nn.Module):
         delta = src - src_orig
         src = src_orig + delta * self.get_bypass_scale()
 
-        return self.whiten(src)
+        src = self.whiten(src)
+        return src, attn_weights
 
 
 class ZipformerEncoder(nn.Module):
@@ -558,6 +574,7 @@ class ZipformerEncoder(nn.Module):
         warmup_end: float,
         initial_layerdrop_rate: float = 0.5,
         final_layerdrop_rate: float = 0.05,
+        attention_share_layers: int = 1,
     ) -> None:
         super().__init__()
         self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.15,
@@ -579,7 +596,8 @@ class ZipformerEncoder(nn.Module):
                 (cur_end, final_layerdrop_rate),
                 default=0.0)
             cur_begin = cur_end
+            if i % attention_share_layers != 0:
+                self.layers[i].remove_attention_weights()
 
     def forward(
         self,
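The i % attention_share_layers != 0 check above determines which layers of a stack keep their attention module. For example, with attention_share_layers = 2 in a hypothetical 6-layer stack (plain-Python illustration, not code from this commit):

```python
attention_share_layers = 2
for i in range(6):
    if i % attention_share_layers == 0:
        print(f"layer {i}: keeps self_attn_weights and computes attention")
    else:
        print(f"layer {i}: self_attn_weights removed, reuses the previous layer's weights")
```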
@@ -614,12 +632,15 @@ class ZipformerEncoder(nn.Module):
         output = output * feature_mask
 
+        attn_weights = None
         for i, mod in enumerate(self.layers):
-            output = mod(
+            output, attn_weights = mod(
                 output,
                 pos_emb,
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
+                attn_weights=attn_weights,
             )
             output = output * feature_mask