Combine two layers into one.

Daniel Povey 2023-04-04 12:14:18 +08:00
parent 3dd25d6b2d
commit 7ab1e7f5ec
2 changed files with 68 additions and 84 deletions


@@ -124,7 +124,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--num-encoder-layers",
type=str,
default="2,4,4,6,4,4",
default="1,2,2,3,2,2",
help="Number of zipformer encoder layers per stack, comma separated.",
)
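Note: the default drops from "2,4,4,6,4,4" to "1,2,2,3,2,2" because each combined layer now holds two attention/convolution/feedforward blocks, so half as many layers give roughly the same depth. The comma-separated string is converted to a per-stack tuple downstream; a minimal sketch of that parsing, assuming to_int_tuple simply splits on commas (the recipe's real helper may differ in details):

```python
# Hedged sketch of parsing a comma-separated per-stack spec into a tuple of ints.
def to_int_tuple(s: str):
    return tuple(int(v) for v in s.split(","))

to_int_tuple("1,2,2,3,2,2")  # -> (1, 2, 2, 3, 2, 2), one entry per zipformer stack
```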
@@ -151,13 +151,6 @@
help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
)
parser.add_argument(
"--attention-share-layers",
type=str,
default="2",
help="Number of layers that share attention weights within each zipformer stack: a single int or comma-separated list.",
)
parser.add_argument(
"--encoder-dim",
type=str,
@@ -548,7 +541,6 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
value_head_dim=to_int_tuple(params.value_head_dim),
pos_dim=params.pos_dim,
num_heads=to_int_tuple(params.num_heads),
attention_share_layers=to_int_tuple(params.attention_share_layers),
feedforward_dim=to_int_tuple(params.feedforward_dim),
cnn_module_kernel=to_int_tuple(params.cnn_module_kernel),
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
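Note: the dropout value here is a ScheduledFloat, i.e. it varies with the global batch count rather than being fixed. A minimal sketch of the assumed semantics (piecewise-linear interpolation between (batch, value) breakpoints, clamped at the ends; not the actual icefall class):

```python
# Hedged sketch: a scheduled value is piecewise-linear in the batch index and
# held constant outside the given (batch, value) breakpoints.
def scheduled_float(batch: float, points) -> float:
    points = sorted(points)
    if batch <= points[0][0]:
        return points[0][1]
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        if batch <= x1:
            return y0 + (batch - x0) * (y1 - y0) / (x1 - x0)
    return points[-1][1]

scheduled_float(10000.0, [(0.0, 0.3), (20000.0, 0.1)])  # -> 0.2, dropout decaying from 0.3 to 0.1
```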


@@ -80,8 +80,6 @@ class Zipformer2(EncoderInterface):
attention head
num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism.
Must be at least 4.
attention_share_layers: (int or Tuple[int]): how many successive layers share
the same attention weights. Must be at least 1.
feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules
cnn_module_kernel (int or Tuple[int]): Kernel size of convolution module
@@ -115,7 +113,6 @@ class Zipformer2(EncoderInterface):
pos_head_dim: Union[int, Tuple[int]] = 4,
value_head_dim: Union[int, Tuple[int]] = 12,
num_heads: Union[int, Tuple[int]] = 8,
attention_share_layers: Union[int, Tuple[int]] = 2,
feedforward_dim: Union[int, Tuple[int]] = 1536,
cnn_module_kernel: Union[int, Tuple[int]] = 31,
pos_dim: int = 192,
@@ -160,7 +157,6 @@ class Zipformer2(EncoderInterface):
value_head_dim = _to_tuple(value_head_dim)
pos_head_dim = _to_tuple(pos_head_dim)
num_heads = _to_tuple(num_heads)
attention_share_layers = _to_tuple(attention_share_layers)
feedforward_dim = _to_tuple(feedforward_dim)
self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)
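Note: each of these hyperparameters accepts either a single value or one value per stack; a minimal sketch of the broadcasting that _to_tuple is assumed to perform (the real helper likely also validates the length):

```python
# Hedged sketch: broadcast a scalar (or length-1 tuple) to one value per encoder
# stack; a full-length tuple is passed through unchanged.
def _to_tuple(x, num_encoders: int = 6):
    if isinstance(x, int):
        x = (x,)
    if len(x) == 1:
        x = x * num_encoders
    assert len(x) == num_encoders
    return x
```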
@@ -212,7 +208,6 @@ class Zipformer2(EncoderInterface):
warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
attention_share_layers=attention_share_layers[i],
)
if downsampling_factor[i] != 1:
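Note: these warmup windows split warmup_batches evenly across the stacks. As a worked example, assuming warmup_batches=4000 and num_encoders=6 (values not shown in this hunk), stack i=0 warms up over batches 4000*1/7 ≈ 571 to 4000*2/7 ≈ 1143, while the last stack, i=5, warms up over ≈ 3429 to 4000.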
@@ -297,49 +292,33 @@ class Zipformer2(EncoderInterface):
# largest of the downsampling_factors. Can be any integer >= 1.
downsampling_multiple = 4
downsampling_factor = [ downsampling_multiple * i for i in self.downsampling_factor ]
group_size = max(self.downsampling_factor) * downsampling_multiple
max_downsampling_factor = max(downsampling_factor)
num_groups = (num_frames0 + group_size - 1) // group_size
num_frames_max = (num_frames0 + max_downsampling_factor - 1) // max_downsampling_factor
feature_mask_dropout_prob = 0.2
# we divide the dropped-out feature dimensions into two equal groups;
# the first group is dropped out with probability 0.1, the second
# with probability approximately twice that.
feature_mask_dropout_prob = 0.125
# shape: (num_groups, batch_size, 1)
group_mask = (torch.rand(num_groups, batch_size, 1,
device=x.device) >
feature_mask_dropout_prob).to(x.dtype)
# frame_mask_max1 shape: (num_frames_max, batch_size, 1)
frame_mask_max1 = (torch.rand(num_frames_max, batch_size, 1,
device=x.device) >
feature_mask_dropout_prob).to(x.dtype)
# frame_mask_max2 has additional frames masked, about twice the number.
frame_mask_max2 = torch.logical_and(frame_mask_max1,
(torch.rand(num_frames_max, batch_size, 1,
device=x.device) >
feature_mask_dropout_prob).to(x.dtype))
# dim: (num_frames_max, batch_size, 2)
frame_mask_max = torch.cat((frame_mask_max1, frame_mask_max2), dim=-1)
feature_masks = []
for i in range(num_encoders):
ds = self.downsampling_factor[i]
upsample_factor = (max_downsampling_factor // ds)
frames_per_group = (group_size // ds)
frame_mask = (frame_mask_max.unsqueeze(1).expand(num_frames_max, upsample_factor,
batch_size, 2)
.reshape(num_frames_max * upsample_factor, batch_size, 2))
frame_mask = (group_mask.unsqueeze(1).expand(num_groups, frames_per_group,
batch_size, 1)
.reshape(num_groups * frames_per_group, batch_size, 1))
num_frames = (num_frames0 + ds - 1) // ds
frame_mask = frame_mask[:num_frames]
channels = self.encoder_dim[i]
feature_mask = torch.ones(num_frames, batch_size, channels,
dtype=x.dtype, device=x.device)
u1 = self.encoder_unmasked_dim[i]
u2 = u1 + (channels - u1) // 2
feature_mask[:, :, u1:u2] *= frame_mask[..., 0:1]
feature_mask[:, :, u2:] *= frame_mask[..., 1:2]
feature_mask[:, :, u1:] *= frame_mask
feature_masks.append(feature_mask)
return feature_masks
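Note: a condensed, self-contained sketch of the group-based masking as this hunk reads (names and shapes follow the diff; values such as num_frames0 and the per-stack dims come from surrounding code not shown here):

```python
import torch

# Hedged sketch of the group-based feature masking: whole blocks of group_size
# input-rate frames are kept or dropped together per sequence, and only channels
# above encoder_unmasked_dim[i] are affected in stack i.
def make_feature_masks(num_frames0, batch_size, downsampling_factor,
                       encoder_dim, encoder_unmasked_dim,
                       feature_mask_dropout_prob=0.2, downsampling_multiple=4,
                       device="cpu", dtype=torch.float32):
    group_size = max(downsampling_factor) * downsampling_multiple
    num_groups = (num_frames0 + group_size - 1) // group_size
    # one keep/drop decision per (group, sequence)
    group_mask = (torch.rand(num_groups, batch_size, 1, device=device)
                  > feature_mask_dropout_prob).to(dtype)
    feature_masks = []
    for ds, channels, unmasked in zip(downsampling_factor, encoder_dim,
                                      encoder_unmasked_dim):
        frames_per_group = group_size // ds
        # expand each group's decision over its frames at this stack's frame rate
        frame_mask = (group_mask.unsqueeze(1)
                      .expand(num_groups, frames_per_group, batch_size, 1)
                      .reshape(num_groups * frames_per_group, batch_size, 1))
        num_frames = (num_frames0 + ds - 1) // ds
        frame_mask = frame_mask[:num_frames]
        feature_mask = torch.ones(num_frames, batch_size, channels,
                                  dtype=dtype, device=device)
        feature_mask[:, :, unmasked:] *= frame_mask
        feature_masks.append(feature_mask)
    return feature_masks
```

Compared with the removed two-group scheme, whole blocks of frames are now masked together across all stacks, rather than masking two halves of the feature dimension with independent frame masks.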
@@ -547,7 +526,8 @@ class Zipformer2EncoderLayer(nn.Module):
attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0),
ff2_skip_rate: FloatLike = 0.01,
ff2_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)),
ff3_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)),
bypass_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.2), default=0),
bypass_max: FloatLike = 1.0,
) -> None:
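Note: both skip rates are consumed by sequence_dropout further down in forward(); a minimal sketch of the assumed semantics (not the recipe's actual helper): with probability rate, a sequence's entire module output is zeroed, so only the residual path survives for that sequence.

```python
import torch

# Hedged sketch of the sequence-level skipping assumed for the *_skip_rate values:
# each sequence in the batch independently has the module output zeroed with
# probability `rate`; the residual connection still carries the input through.
def sequence_dropout(x: torch.Tensor, rate: float) -> torch.Tensor:
    # x: (seq_len, batch_size, embed_dim); intended for training only
    if rate == 0.0:
        return x
    keep = (torch.rand(1, x.shape[1], 1, device=x.device) > rate).to(x.dtype)
    return x * keep
```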
@@ -565,6 +545,7 @@ class Zipformer2EncoderLayer(nn.Module):
# ff2_skip_rate is to prevent the ff2 module from having output that's too big
# compared to its residual.
self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate)
self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate)
# min and max for self.bypass_scale, applied with probability 0.5 to avoid grads
# ever becoming zero.
@@ -578,7 +559,10 @@ class Zipformer2EncoderLayer(nn.Module):
dropout=0.0,
)
self.self_attn = SelfAttention(embed_dim, num_heads,
self.self_attn1 = SelfAttention(embed_dim, num_heads,
value_head_dim)
self.self_attn2 = SelfAttention(embed_dim, num_heads,
value_head_dim)
self.feed_forward1 = FeedforwardModule(embed_dim,
@@ -586,16 +570,24 @@ class Zipformer2EncoderLayer(nn.Module):
dropout)
self.feed_forward2 = FeedforwardModule(embed_dim,
feedforward_dim,
dropout)
self.feed_forward3 = FeedforwardModule(embed_dim,
(feedforward_dim * 5) // 4,
dropout)
self.nonlin_attention = NonlinAttention(embed_dim,
hidden_channels=3 * embed_dim // 4)
self.conv_module = ConvolutionModule(embed_dim,
self.conv_module1 = ConvolutionModule(embed_dim,
cnn_module_kernel,
causal=causal)
self.conv_module2 = ConvolutionModule(embed_dim,
cnn_module_kernel,
causal=causal)
#self.attention_squeeze = AttentionSqueeze(embed_dim, embed_dim // 2)
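Note, for a sense of the sizes involved (a worked example using the feedforward_dim default of 1536 from the constructor above and an assumed embed_dim of 384): feed_forward3 gets a hidden dimension of (1536 * 5) // 4 = 1920, slightly wider than feed_forward1/2, and NonlinAttention uses hidden_channels = 3 * 384 // 4 = 288.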
@@ -617,14 +609,6 @@ class Zipformer2EncoderLayer(nn.Module):
prob=0.05, # out of concern for memory usage
)
# balancer for output of AttentionSqueezeModule
self.balancer_as = Balancer(
embed_dim, channel_dim=-1,
min_positive=0.3, max_positive=0.7,
min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)),
prob=0.05, # out of concern for memory usage
)
# balancer for output of feedforward2, prevent it from staying too
# small. give this a very small probability, even at the start of
# training, it's to fix a rare problem and it's OK to fix it slowly.
@@ -636,10 +620,19 @@ class Zipformer2EncoderLayer(nn.Module):
prob=0.05,
)
self.balancer_ff3 = Balancer(
embed_dim, channel_dim=-1,
min_positive=0.3, max_positive=0.7,
min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0),
max_abs=4.0,
prob=0.05,
)
self.whiten = Whiten(num_groups=1,
whitening_limit=_whitening_schedule(4.0, ratio=3.0),
prob=(0.025, 0.25),
grad_scale=0.01)
self.balancer2 = Balancer(
embed_dim, channel_dim=-1,
min_positive=0.45, max_positive=0.55,
@@ -647,9 +640,6 @@ class Zipformer2EncoderLayer(nn.Module):
)
def remove_attention_weights(self):
self.self_attn_weights = None
def get_bypass_scale(self, batch_size: int):
# returns bypass-scale of shape (num_channels,),
# or (batch_size, num_channels,). This is actually the
@@ -696,8 +686,7 @@ class Zipformer2EncoderLayer(nn.Module):
chunk_size: int = -1,
attn_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
attn_weights: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
) -> Tensor:
"""
Pass the input through the encoder layer.
Args:
@@ -713,8 +702,7 @@ class Zipformer2EncoderLayer(nn.Module):
masked position. May be None.
Returns:
(x, attn_weights) where x has the same shape as src, and attn_weights are of
shape (num_heads, batch_size, seq_len, seq_len).
A tensor which has the same shape as src
"""
src_orig = src
@@ -722,24 +710,17 @@ class Zipformer2EncoderLayer(nn.Module):
attention_skip_rate = float(self.attention_skip_rate) if self.training else 0.0
# attn_weights: (num_heads, batch_size, seq_len, seq_len)
if self.self_attn_weights is not None:
attn_weights = self.self_attn_weights(
src,
pos_emb=pos_emb,
attn_mask=attn_mask,
key_padding_mask=src_key_padding_mask,
attn_weights = self.self_attn_weights(
src,
pos_emb=pos_emb,
attn_mask=attn_mask,
key_padding_mask=src_key_padding_mask,
)
# else rely on the ones passed in
# use different heads for nonlin_attention and attention_squeeze, depending on
# whether this module has its own self_attn_weights submodule or is borrowing
# attention weights from another one.
head_offset = 0 if self.self_attn_weights is not None else 2
self_attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate)
if True:
selected_attn_weights = attn_weights[head_offset:head_offset+2]
selected_attn_weights = attn_weights[0:2]
if random.random() < float(self.const_attention_rate):
# Make attention weights constant. The intention is to
# encourage these modules to do something similar to an
@@ -753,21 +734,38 @@ class Zipformer2EncoderLayer(nn.Module):
na = self.balancer_na(self.nonlin_attention(src,
selected_attn_weights[0:1]))
src = src + (na if self_attn_dropout_mask is None else na * self_attn_dropout_mask)
src = src + self.feed_forward1(src)
self_attn = self.self_attn(
self_attn = self.self_attn1(
src, attn_weights)
src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)
src = src + self.sequence_dropout(self.conv_module(src, chunk_size=chunk_size,
src_key_padding_mask=src_key_padding_mask),
src = src + self.sequence_dropout(self.conv_module1(src, chunk_size=chunk_size,
src_key_padding_mask=src_key_padding_mask),
float(self.conv_skip_rate))
src = src + self.sequence_dropout(self.balancer_ff2(self.feed_forward2(src)),
float(self.ff2_skip_rate))
self_attn = self.self_attn2(
src, attn_weights)
src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)
src = src + self.sequence_dropout(self.conv_module2(src, chunk_size=chunk_size,
src_key_padding_mask=src_key_padding_mask),
float(self.conv_skip_rate))
src = src + self.sequence_dropout(self.balancer_ff3(self.feed_forward3(src)),
float(self.ff3_skip_rate))
src = self.balancer1(src)
src = self.norm(src)
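Note: putting the pieces of this hunk together, a condensed sketch of the combined layer's residual sequence after this change (names follow the hunks above; balancers, skip-rate dropout, the const-attention branch and bypass scaling are omitted). The attention weights are now computed once per layer and reused by the nonlinear-attention module and both self-attention blocks.

```python
# Hedged sketch of Zipformer2EncoderLayer.forward after this commit (simplified).
attn_weights = self.self_attn_weights(src, pos_emb=pos_emb, attn_mask=attn_mask,
                                      key_padding_mask=src_key_padding_mask)
src = src + self.nonlin_attention(src, attn_weights[0:1])
src = src + self.feed_forward1(src)
src = src + self.self_attn1(src, attn_weights)
src = src + self.conv_module1(src, chunk_size=chunk_size,
                              src_key_padding_mask=src_key_padding_mask)
src = src + self.feed_forward2(src)
src = src + self.self_attn2(src, attn_weights)
src = src + self.conv_module2(src, chunk_size=chunk_size,
                              src_key_padding_mask=src_key_padding_mask)
src = src + self.feed_forward3(src)
src = self.norm(src)
```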
@@ -779,7 +777,7 @@ class Zipformer2EncoderLayer(nn.Module):
src = self.balancer2(src)
src = self.whiten(src)
return src, attn_weights
return src
class Zipformer2Encoder(nn.Module):
r"""Zipformer2Encoder is a stack of N encoder layers
@@ -805,7 +803,6 @@ class Zipformer2Encoder(nn.Module):
warmup_end: float,
initial_layerdrop_rate: float = 0.5,
final_layerdrop_rate: float = 0.05,
attention_share_layers: int = 1,
) -> None:
super().__init__()
self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.15,
@@ -827,8 +824,6 @@ class Zipformer2Encoder(nn.Module):
(cur_end, final_layerdrop_rate),
default=0.0)
cur_begin = cur_end
if i % attention_share_layers != 0:
self.layers[i].remove_attention_weights()
def forward(
self,
@@ -860,16 +855,13 @@ class Zipformer2Encoder(nn.Module):
output = output * feature_mask
attn_weights = None
for i, mod in enumerate(self.layers):
output, attn_weights = mod(
output = mod(
output,
pos_emb,
chunk_size=chunk_size,
attn_mask=attn_mask,
src_key_padding_mask=src_key_padding_mask,
attn_weights=attn_weights,
)
output = output * feature_mask