Implement attention-weight sharing for successive layers in Zipformer

Daniel Povey 2022-11-28 13:36:16 +08:00
parent 121f7e2a45
commit f483f1e0ef
2 changed files with 43 additions and 14 deletions

View File

@@ -133,6 +133,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
)
parser.add_argument(
"--attention-share-layers",
type=str,
default="2",
help="Number of layers that share attention weights within each zipformer stack: a single int or comma-separated list.",
)
parser.add_argument(
"--encoder-dim",
type=str,
@@ -488,6 +495,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
value_head_dim=to_int_tuple(params.value_head_dim),
pos_dim=params.pos_dim,
num_heads=to_int_tuple(params.num_heads),
attention_share_layers=to_int_tuple(params.attention_share_layers),
feedforward_dim=to_int_tuple(params.feedforward_dim),
cnn_module_kernel=to_int_tuple(params.cnn_module_kernel),
dropout=0.1,
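For context, a minimal sketch (not part of the patch) of how the new --attention-share-layers value is consumed: the string is converted with to_int_tuple, which is assumed here to split on commas, giving one integer per Zipformer stack; a single value is broadcast to every stack inside Zipformer itself (see the _to_tuple sketch further below).

# Hypothetical sketch of to_int_tuple (assumption: the real helper in the
# training script behaves like this).
def to_int_tuple(s: str):
    return tuple(int(x) for x in s.split(","))

print(to_int_tuple("2"))          # (2,)  -> same sharing factor for all stacks
print(to_int_tuple("2,1,2,2,1"))  # (2, 1, 2, 2, 1), one value per encoder stack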

View File

@@ -76,6 +76,8 @@ class Zipformer(EncoderInterface):
attention head
num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism.
Must be at least 4.
attention_share_layers: (int or Tuple[int]): how many successive layers share
the same attention weights. Must be at least 1.
feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules
cnn_module_kernel (int or Tuple[int]): Kernel size of convolution module
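For example, with attention_share_layers=2 a 6-layer stack computes attention weights in layers 0, 2 and 4, while layers 1, 3 and 5 reuse the weights from the layer directly before them; see the sketch after the ZipformerEncoder constructor changes below.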
@@ -99,6 +101,7 @@ class Zipformer(EncoderInterface):
pos_head_dim: Union[int, Tuple[int]] = 4,
value_head_dim: Union[int, Tuple[int]] = 12,
num_heads: Union[int, Tuple[int]] = 8,
attention_share_layers: Union[int, Tuple[int]] = 2,
feedforward_dim: Union[int, Tuple[int]] = 1536,
cnn_module_kernel: Union[int, Tuple[int]] = 31,
pos_dim: int = 192,
@@ -135,6 +138,7 @@ class Zipformer(EncoderInterface):
value_head_dim = _to_tuple(value_head_dim)
pos_head_dim = _to_tuple(pos_head_dim)
num_heads = _to_tuple(num_heads)
attention_share_layers = _to_tuple(attention_share_layers)
feedforward_dim = _to_tuple(feedforward_dim)
cnn_module_kernel = _to_tuple(cnn_module_kernel)
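The broadcasting convention assumed for these tuples (a sketch only; the real _to_tuple is defined elsewhere in zipformer.py): a bare int or 1-element tuple is replicated once per encoder stack, so attention_share_layers=2 applies the same sharing factor to every stack.

# Hypothetical sketch of the _to_tuple convention (assumption, illustrative default).
def _to_tuple(x, num_encoders=5):
    if isinstance(x, int):
        x = (x,)
    if len(x) == 1:
        x = x * num_encoders          # replicate a single value per stack
    assert len(x) == num_encoders     # otherwise one value per stack is required
    return x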
@@ -179,7 +183,8 @@ class Zipformer(EncoderInterface):
pos_dim=pos_dim,
dropout=dropout,
warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
warmup_end=warmup_batches * (i + 2) / (num_encoders + 1)
warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
attention_share_layers=attention_share_layers[i],
)
if downsampling_factor[i] != 1:
@@ -442,6 +447,9 @@ class ZipformerEncoderLayer(nn.Module):
prob=(0.025, 0.25),
grad_scale=0.01)
def remove_attention_weights(self):
self.self_attn_weights = None
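Setting self.self_attn_weights to None marks this layer as one that does not own an attention module: in the updated forward() below it reuses the attn_weights tensor handed in from the preceding layer instead of computing its own.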
def get_bypass_scale(self):
if torch.jit.is_scripting() or not self.training:
return self.bypass_scale
@@ -456,7 +464,8 @@ class ZipformerEncoderLayer(nn.Module):
pos_emb: Tensor,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
attn_weights: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
"""
Pass the input through the encoder layer.
@@ -465,7 +474,8 @@ class ZipformerEncoderLayer(nn.Module):
pos_emb: Positional embedding tensor (required).
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
batch_split: if not None, this layer will only be applied to
attn_weights: attention weights computed by the previous layer (optional);
they are used when self.self_attn_weights is None, i.e. when this layer
shares the previous layer's attention weights.
Shape:
src: (S, N, E).
@@ -473,27 +483,32 @@ class ZipformerEncoderLayer(nn.Module):
src_mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, N is the batch size, E is the feature number
"""
if self.training and random.random() < float(self.layer_skip_rate):
# skip the layer
return src
Returns:
(x, attn_weights) where x has the same shape as src, and attn_weights are of
shape (num_heads, batch_size, seq_len, seq_len).
"""
src_orig = src
# dropout rate for non-feedforward submodules
dynamic_skip_rate = float(self.dynamic_skip_rate) if self.training else 0.0
# multi-headed self-attention module
use_self_attn = (random.random() >= dynamic_skip_rate)
if torch.jit.is_scripting() or use_self_attn:
# attn_weights: (num_heads, batch_size, seq_len, seq_len)
if self.self_attn_weights is not None:
attn_weights = self.self_attn_weights(
src,
pos_emb=pos_emb,
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
)
# else rely on the ones passed in
if self.training and random.random() < float(self.layer_skip_rate):
# skip the layer
return src, attn_weights
use_self_attn = (random.random() >= dynamic_skip_rate)
if use_self_attn:
first_attn_weights = attn_weights[0:3]
if random.random() < float(self.const_attention_rate):
# Make attention weights constant. The intention is to
@@ -530,8 +545,9 @@ class ZipformerEncoderLayer(nn.Module):
delta = src - src_orig
src = src_orig + delta * self.get_bypass_scale()
src = self.whiten(src)
return self.whiten(src)
return src, attn_weights
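Note that the layer now always returns a (src, attn_weights) pair, including on the random layer-skip path, so that a following layer whose attention module was removed still receives weights it can reuse.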
class ZipformerEncoder(nn.Module):
@@ -558,6 +574,7 @@ class ZipformerEncoder(nn.Module):
warmup_end: float,
initial_layerdrop_rate: float = 0.5,
final_layerdrop_rate: float = 0.05,
attention_share_layers: int = 1,
) -> None:
super().__init__()
self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.15,
@@ -579,7 +596,8 @@ class ZipformerEncoder(nn.Module):
(cur_end, final_layerdrop_rate),
default=0.0)
cur_begin = cur_end
if i % attention_share_layers != 0:
self.layers[i].remove_attention_weights()
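As an illustration (hypothetical, not part of the patch), the modulo test above keeps an attention module only in every attention_share_layers-th layer of a stack:

# Which layers compute vs. reuse attention weights in a 6-layer stack
# with attention_share_layers=3 (illustration only):
num_layers = 6
attention_share_layers = 3
for i in range(num_layers):
    if i % attention_share_layers != 0:
        print(f"layer {i}: attention module removed, reuses the previous layer's weights")
    else:
        print(f"layer {i}: computes its own attention weights")
# -> layers 0 and 3 compute; layers 1, 2, 4 and 5 reuse.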
def forward(
self,
@@ -614,12 +632,15 @@ class ZipformerEncoder(nn.Module):
output = output * feature_mask
attn_weights = None
for i, mod in enumerate(self.layers):
output = mod(
output, attn_weights = mod(
output,
pos_emb,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
attn_weights=attn_weights,
)
output = output * feature_mask
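To make the sharing mechanism concrete, here is a small self-contained toy (hypothetical names, no relation to the real modules, and no torch dependency) that mimics the loop above: layers whose attention callable is None simply reuse the weights produced by the most recent layer that computed them.

import random

class ToyLayer:
    # Stands in for ZipformerEncoderLayer; has_attention mirrors whether
    # self_attn_weights was kept or removed.
    def __init__(self, has_attention):
        self.compute_weights = (lambda x: [random.random() for _ in x]) if has_attention else None

    def forward(self, x, attn_weights=None):
        if self.compute_weights is not None:
            attn_weights = self.compute_weights(x)   # fresh weights
        # else: reuse attn_weights passed along from the previous layer
        return x, attn_weights

attention_share_layers = 2
layers = [ToyLayer(has_attention=(i % attention_share_layers == 0)) for i in range(4)]

x, attn_weights = [0.0, 1.0, 2.0], None
for layer in layers:
    x, attn_weights = layer.forward(x, attn_weights)
# layers 0 and 2 computed weights; layers 1 and 3 reused them.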