mirror of https://github.com/k2-fsa/icefall.git
Implement attention weights sharing for successive layers, for Zipformer
parent 121f7e2a45
commit f483f1e0ef
@@ -133,6 +133,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
     )
 
+    parser.add_argument(
+        "--attention-share-layers",
+        type=str,
+        default="2",
+        help="Number of layers that share attention weights within each zipformer stack: a single int or comma-separated list.",
+    )
+
     parser.add_argument(
         "--encoder-dim",
         type=str,
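Not part of the diff, just for orientation: a minimal sketch of how the new flag is expected to behave from the command line, following the same single-int-or-comma-separated-list convention as the existing per-stack options (the example values below are made up).

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--attention-share-layers", type=str, default="2")

    # either one value for all stacks, or one value per zipformer stack
    args = parser.parse_args(["--attention-share-layers", "2,2,4,4,2,2"])
    print(args.attention_share_layers)  # -> "2,2,4,4,2,2"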
@@ -488,6 +495,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         value_head_dim=to_int_tuple(params.value_head_dim),
         pos_dim=params.pos_dim,
         num_heads=to_int_tuple(params.num_heads),
+        attention_share_layers=to_int_tuple(params.attention_share_layers),
         feedforward_dim=to_int_tuple(params.feedforward_dim),
         cnn_module_kernel=to_int_tuple(params.cnn_module_kernel),
         dropout=0.1,
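The call above assumes that to_int_tuple turns the comma-separated string from argparse into a tuple of ints, one per encoder stack. A hedged sketch of that assumed behaviour (not copied from icefall):

    def to_int_tuple(s: str):
        # assumed behaviour: "2" -> (2,), "2,2,4,4,2,2" -> (2, 2, 4, 4, 2, 2)
        return tuple(int(x) for x in s.split(","))

    assert to_int_tuple("2") == (2,)
    assert to_int_tuple("2,2,4,4,2,2") == (2, 2, 4, 4, 2, 2)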
@@ -76,6 +76,8 @@ class Zipformer(EncoderInterface):
             attention head
         num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism.
             Must be at least 4.
+        attention_share_layers: (int or Tuple[int]): how many successive layers share
+            the same attention weights. Must be at least 1.
         feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules
         cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module
@@ -99,6 +101,7 @@ class Zipformer(EncoderInterface):
         pos_head_dim: Union[int, Tuple[int]] = 4,
         value_head_dim: Union[int, Tuple[int]] = 12,
         num_heads: Union[int, Tuple[int]] = 8,
+        attention_share_layers: Union[int, Tuple[int]] = 2,
         feedforward_dim: Union[int, Tuple[int]] = 1536,
         cnn_module_kernel: Union[int, Tuple[int]] = 31,
         pos_dim: int = 192,
@@ -135,6 +138,7 @@ class Zipformer(EncoderInterface):
         value_head_dim = _to_tuple(value_head_dim)
         pos_head_dim = _to_tuple(pos_head_dim)
         num_heads = _to_tuple(num_heads)
+        attention_share_layers = _to_tuple(attention_share_layers)
         feedforward_dim = _to_tuple(feedforward_dim)
         cnn_module_kernel = _to_tuple(cnn_module_kernel)
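Here the new argument goes through the same per-stack normalization as the other options. A sketch of what _to_tuple is assumed to do, with a made-up num_encoders of 6 (the body below is illustrative, not copied from zipformer.py):

    def _to_tuple(x, num_encoders=6):
        # broadcast a single int to every encoder stack, or pass a tuple through
        if isinstance(x, int):
            return (x,) * num_encoders
        assert len(x) == num_encoders
        return tuple(x)

    print(_to_tuple(2))                   # (2, 2, 2, 2, 2, 2)
    print(_to_tuple((2, 2, 4, 4, 2, 2)))  # unchanged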
@@ -179,7 +183,8 @@ class Zipformer(EncoderInterface):
                 pos_dim=pos_dim,
                 dropout=dropout,
                 warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
-                warmup_end=warmup_batches * (i + 2) / (num_encoders + 1)
+                warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
+                attention_share_layers=attention_share_layers[i],
             )
 
             if downsampling_factor[i] != 1:
@@ -442,6 +447,9 @@ class ZipformerEncoderLayer(nn.Module):
                              prob=(0.025, 0.25),
                              grad_scale=0.01)
 
+    def remove_attention_weights(self):
+        self.self_attn_weights = None
+
     def get_bypass_scale(self):
         if torch.jit.is_scripting() or not self.training:
             return self.bypass_scale
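remove_attention_weights() is how sharing is wired up at construction time: a layer whose module is dropped must later be fed attn_weights by its caller. A minimal stand-in (not the real ZipformerEncoderLayer) to show the effect:

    class TinyLayer:
        def __init__(self):
            self.self_attn_weights = object()  # stand-in for the attention-weights module

        def remove_attention_weights(self):
            self.self_attn_weights = None

    layer = TinyLayer()
    layer.remove_attention_weights()
    assert layer.self_attn_weights is None  # this layer now relies on shared weights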
@@ -456,7 +464,8 @@ class ZipformerEncoderLayer(nn.Module):
         pos_emb: Tensor,
         src_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
-    ) -> Tensor:
+        attn_weights: Optional[Tensor] = None,
+    ) -> Tuple[Tensor, Tensor]:
         """
         Pass the input through the encoder layer.
 
@@ -465,7 +474,8 @@ class ZipformerEncoderLayer(nn.Module):
             pos_emb: Positional embedding tensor (required).
             src_mask: the mask for the src sequence (optional).
             src_key_padding_mask: the mask for the src keys per batch (optional).
-            batch_split: if not None, this layer will only be applied to
+            attn_weights: possibly attention weights computed by the previous layer,
+                to be used if self.self_attn_weights is None
 
         Shape:
             src: (S, N, E).
@@ -473,27 +483,32 @@ class ZipformerEncoderLayer(nn.Module):
             src_mask: (S, S).
             src_key_padding_mask: (N, S).
             S is the source sequence length, N is the batch size, E is the feature number
-        """
-        if self.training and random.random() < float(self.layer_skip_rate):
-            # skip the layer
-            return src
+
+        Returns:
+            (x, attn_weights) where x has the same shape as src, and attn_weights are of
+            shape (num_heads, batch_size, seq_len, seq_len).
+        """
         src_orig = src
 
         # dropout rate for non-feedforward submodules
         dynamic_skip_rate = float(self.dynamic_skip_rate) if self.training else 0.0
         # multi-headed self-attention module
-        use_self_attn = (random.random() >= dynamic_skip_rate)
 
-        if torch.jit.is_scripting() or use_self_attn:
-            # attn_weights: (num_heads, batch_size, seq_len, seq_len)
+        # attn_weights: (num_heads, batch_size, seq_len, seq_len)
+        if self.self_attn_weights is not None:
             attn_weights = self.self_attn_weights(
                 src,
                 pos_emb=pos_emb,
                 attn_mask=src_mask,
                 key_padding_mask=src_key_padding_mask,
             )
+        # else rely on the ones passed in
+
+        if self.training and random.random() < float(self.layer_skip_rate):
+            # skip the layer
+            return src, attn_weights
+
+        use_self_attn = (random.random() >= dynamic_skip_rate)
+        if use_self_attn:
             first_attn_weights = attn_weights[0:3]
             if random.random() < float(self.const_attention_rate):
                 # Make attention weights constant. The intention is to
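To make the new control flow easier to follow, here is the layer-level contract in isolation, as a sketch with stand-in objects rather than the real ZipformerEncoderLayer: compute fresh weights only if the layer still owns self_attn_weights, otherwise reuse what was passed in, and return (src, attn_weights) on every exit path, including the layerdrop skip, so later layers that share the weights always receive them.

    import random
    from typing import Optional, Tuple

    def layer_forward(src, attn_weights, self_attn_weights,
                      layer_skip_rate=0.0, training=True) -> Tuple[object, Optional[object]]:
        if self_attn_weights is not None:
            attn_weights = self_attn_weights(src)  # this layer computes its own weights
        # else: rely on the weights handed in by the previous layer

        if training and random.random() < layer_skip_rate:
            return src, attn_weights               # skipped, but the weights still flow on

        # ... self-attention, feedforward and convolution sub-modules would run here ...
        return src, attn_weights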
@@ -530,8 +545,9 @@ class ZipformerEncoderLayer(nn.Module):
         delta = src - src_orig
 
         src = src_orig + delta * self.get_bypass_scale()
+        src = self.whiten(src)
 
-        return self.whiten(src)
+        return src, attn_weights
 
 
 class ZipformerEncoder(nn.Module):
@@ -558,6 +574,7 @@ class ZipformerEncoder(nn.Module):
         warmup_end: float,
         initial_layerdrop_rate: float = 0.5,
         final_layerdrop_rate: float = 0.05,
+        attention_share_layers: int = 1,
     ) -> None:
         super().__init__()
         self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.15,
@@ -579,7 +596,8 @@ class ZipformerEncoder(nn.Module):
                 (cur_end, final_layerdrop_rate),
                 default=0.0)
             cur_begin = cur_end
 
+            if i % attention_share_layers != 0:
+                self.layers[i].remove_attention_weights()
+
     def forward(
         self,
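The modulus check decides which layers keep their own attention module inside each stack; the others have it removed and reuse the weights computed just above them. A small illustration with arbitrary layer counts and sharing factors:

    num_layers = 6
    for share in (1, 2, 3):
        keeps = [i for i in range(num_layers) if i % share == 0]
        print(f"attention_share_layers={share}: layers {keeps} compute weights")
    # attention_share_layers=1: layers [0, 1, 2, 3, 4, 5] compute weights
    # attention_share_layers=2: layers [0, 2, 4] compute weights
    # attention_share_layers=3: layers [0, 3] compute weights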
@@ -614,12 +632,15 @@ class ZipformerEncoder(nn.Module):
 
         output = output * feature_mask
 
+        attn_weights = None
+
         for i, mod in enumerate(self.layers):
-            output = mod(
+            output, attn_weights = mod(
                 output,
                 pos_emb,
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
+                attn_weights=attn_weights,
             )
 
             output = output * feature_mask
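Putting the pieces together, the threading in the encoder forward loop works like the self-contained toy below (not the real ZipformerEncoder; ToyLayer and its arithmetic are invented): attn_weights starts as None, the first layer of each sharing group computes it, and the layers in between receive it through the new attn_weights argument.

    class ToyLayer:
        def __init__(self, owns_attention: bool):
            # stands in for self.self_attn_weights (None after remove_attention_weights)
            self.self_attn_weights = (lambda x: f"weights({x})") if owns_attention else None

        def __call__(self, x, attn_weights=None):
            if self.self_attn_weights is not None:
                attn_weights = self.self_attn_weights(x)
            return x + 1, attn_weights

    attention_share_layers = 2
    layers = [ToyLayer(i % attention_share_layers == 0) for i in range(4)]

    output, attn_weights = 0, None
    for mod in layers:
        output, attn_weights = mod(output, attn_weights=attn_weights)
        print(output, attn_weights)
    # layers 0 and 2 compute fresh weights; layers 1 and 3 reuse them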