Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00
Combine two layers into one.

parent 3dd25d6b2d
commit 7ab1e7f5ec
@@ -124,7 +124,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--num-encoder-layers",
        type=str,
        default="2,4,4,6,4,4",
        default="1,2,2,3,2,2",
        help="Number of zipformer encoder layers per stack, comma separated.",
    )

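The comma-separated defaults above ("1,2,2,3,2,2" and friends) are parsed into one value per encoder stack before the model is built; the get_encoder_model hunk further down does this with to_int_tuple. A minimal sketch of such a parser, assuming the icefall helper does nothing more than split on commas:

def to_int_tuple(s: str) -> tuple:
    # Turn a comma-separated string such as "1,2,2,3,2,2" into a tuple of ints,
    # one entry per zipformer encoder stack.
    return tuple(int(v) for v in s.split(","))

assert to_int_tuple("1,2,2,3,2,2") == (1, 2, 2, 3, 2, 2)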
@@ -151,13 +151,6 @@ def add_model_arguments(parser: argparse.ArgumentParser):
        help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
    )

    parser.add_argument(
        "--attention-share-layers",
        type=str,
        default="2",
        help="Number of layers that share attention weights within each zipformer stack: a single int or comma-separated list.",
    )

    parser.add_argument(
        "--encoder-dim",
        type=str,
@@ -548,7 +541,6 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
        value_head_dim=to_int_tuple(params.value_head_dim),
        pos_dim=params.pos_dim,
        num_heads=to_int_tuple(params.num_heads),
        attention_share_layers=to_int_tuple(params.attention_share_layers),
        feedforward_dim=to_int_tuple(params.feedforward_dim),
        cnn_module_kernel=to_int_tuple(params.cnn_module_kernel),
        dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
@@ -80,8 +80,6 @@ class Zipformer2(EncoderInterface):
           attention head
        num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism.
           Must be at least 4.
        attention_share_layers: (int or Tuple[int]): how many successive layers share
           the same attention weights.  Must be at least 1.
        feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules
        cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module
@@ -115,7 +113,6 @@ class Zipformer2(EncoderInterface):
        pos_head_dim: Union[int, Tuple[int]] = 4,
        value_head_dim: Union[int, Tuple[int]] = 12,
        num_heads: Union[int, Tuple[int]] = 8,
        attention_share_layers: Union[int, Tuple[int]] = 2,
        feedforward_dim: Union[int, Tuple[int]] = 1536,
        cnn_module_kernel: Union[int, Tuple[int]] = 31,
        pos_dim: int = 192,
@@ -160,7 +157,6 @@ class Zipformer2(EncoderInterface):
        value_head_dim = _to_tuple(value_head_dim)
        pos_head_dim = _to_tuple(pos_head_dim)
        num_heads = _to_tuple(num_heads)
        attention_share_layers = _to_tuple(attention_share_layers)
        feedforward_dim = _to_tuple(feedforward_dim)
        self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)

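The _to_tuple calls normalize each hyperparameter so that downstream code can always index it per stack: a single int is broadcast to one value per encoder, while a tuple of the right length passes through. A rough sketch of that behavior (the real helper is defined inside Zipformer2.__init__; the num_encoders argument here is illustrative):

def _to_tuple(x, num_encoders: int = 6):
    # Broadcast a scalar to one value per encoder stack; leave full-length tuples as-is.
    if isinstance(x, int):
        return (x,) * num_encoders
    assert len(x) == num_encoders
    return tuple(x)

assert _to_tuple(8) == (8, 8, 8, 8, 8, 8)
assert _to_tuple((2, 4, 4, 6, 4, 4)) == (2, 4, 4, 6, 4, 4)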
@@ -212,7 +208,6 @@ class Zipformer2(EncoderInterface):
                warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
                warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
                final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
                attention_share_layers=attention_share_layers[i],
            )

            if downsampling_factor[i] != 1:
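The warmup_begin/warmup_end context lines above stagger layer-dropout warmup across the stacks: stack i warms up over the window from warmup_batches * (i + 1) / (num_encoders + 1) to warmup_batches * (i + 2) / (num_encoders + 1), so later stacks start later. A quick illustration with assumed values (warmup_batches=4000, six stacks):

warmup_batches = 4000.0  # assumed value, for illustration only
num_encoders = 6
for i in range(num_encoders):
    begin = warmup_batches * (i + 1) / (num_encoders + 1)
    end = warmup_batches * (i + 2) / (num_encoders + 1)
    print(f"stack {i}: warmup_begin={begin:.0f}, warmup_end={end:.0f}")
# stack 0 warms up over batches ~571..1143, ..., stack 5 over ~3429..4000.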
@@ -297,49 +292,33 @@ class Zipformer2(EncoderInterface):
        # largest of the downsampling_factors.  Can be any integer >= 1.
        downsampling_multiple = 4

        downsampling_factor = [ downsampling_multiple * i for i in self.downsampling_factor ]
        group_size = max(self.downsampling_factor) * downsampling_multiple

        max_downsampling_factor = max(downsampling_factor)
        num_groups = (num_frames0 + group_size - 1) // group_size

        num_frames_max = (num_frames0 + max_downsampling_factor - 1) // max_downsampling_factor
        feature_mask_dropout_prob = 0.2

        # we divide the dropped-out feature dimensions into two equal groups;
        # the first group is dropped out with probability 0.1, the second
        # with probability approximately twice that.
        feature_mask_dropout_prob = 0.125
        # shape: (num_groups, batch_size, 1)
        group_mask = (torch.rand(num_groups, batch_size, 1,
                                 device=x.device) >
                      feature_mask_dropout_prob).to(x.dtype)

        # frame_mask_max1 shape: (num_frames_max, batch_size, 1)
        frame_mask_max1 = (torch.rand(num_frames_max, batch_size, 1,
                                      device=x.device) >
                           feature_mask_dropout_prob).to(x.dtype)

        # frame_mask_max2 has additional frames masked, about twice the number.
        frame_mask_max2 = torch.logical_and(frame_mask_max1,
                                            (torch.rand(num_frames_max, batch_size, 1,
                                                        device=x.device) >
                                             feature_mask_dropout_prob).to(x.dtype))


        # dim: (num_frames_max, batch_size, 3)
        frame_mask_max = torch.cat((frame_mask_max1, frame_mask_max2), dim=-1)

        feature_masks = []
        for i in range(num_encoders):
            ds = self.downsampling_factor[i]
            upsample_factor = (max_downsampling_factor // ds)
            frames_per_group = (group_size // ds)

            frame_mask = (frame_mask_max.unsqueeze(1).expand(num_frames_max, upsample_factor,
                                                             batch_size, 2)
                          .reshape(num_frames_max * upsample_factor, batch_size, 2))
            frame_mask = (group_mask.unsqueeze(1).expand(num_groups, frames_per_group,
                                                         batch_size, 1)
                          .reshape(num_groups * frames_per_group, batch_size, 1))
            num_frames = (num_frames0 + ds - 1) // ds
            frame_mask = frame_mask[:num_frames]
            channels = self.encoder_dim[i]
            feature_mask = torch.ones(num_frames, batch_size, channels,
                                      dtype=x.dtype, device=x.device)
            u1 = self.encoder_unmasked_dim[i]
            u2 = u1 + (channels - u1) // 2
            feature_mask[:, :, u1:u2] *= frame_mask[..., 0:1]
            feature_mask[:, :, u2:] *= frame_mask[..., 1:2]
            feature_mask[:, :, u1:] *= frame_mask
            feature_masks.append(feature_mask)

        return feature_masks
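In the new version of this block, whole groups of frames are masked together: group_size is a multiple of every downsampling factor, so after each stack's downsampling the same groups are still masked consistently, and a single group_mask drives the masked channel range u1: of every stack. A standalone sketch of that scheme with assumed sizes (the per-stack dims below are only examples, not the recipe's values):

import torch

num_frames0, batch_size = 100, 2                        # assumed sizes
downsampling_factor = (1, 2, 4, 8, 4, 2)                # example per-stack factors
encoder_dim = (192, 256, 384, 512, 384, 256)            # example dims
encoder_unmasked_dim = (192, 192, 256, 256, 256, 192)   # example unmasked dims

downsampling_multiple = 4
group_size = max(downsampling_factor) * downsampling_multiple
num_groups = (num_frames0 + group_size - 1) // group_size
feature_mask_dropout_prob = 0.2
# One keep/drop decision per (group, sequence); it broadcasts over channels.
group_mask = (torch.rand(num_groups, batch_size, 1) > feature_mask_dropout_prob).float()

feature_masks = []
for i, ds in enumerate(downsampling_factor):
    frames_per_group = group_size // ds
    frame_mask = (group_mask.unsqueeze(1)
                  .expand(num_groups, frames_per_group, batch_size, 1)
                  .reshape(num_groups * frames_per_group, batch_size, 1))
    num_frames = (num_frames0 + ds - 1) // ds
    frame_mask = frame_mask[:num_frames]
    feature_mask = torch.ones(num_frames, batch_size, encoder_dim[i])
    u1 = encoder_unmasked_dim[i]
    feature_mask[:, :, u1:] *= frame_mask      # only channels above u1 are ever masked
    feature_masks.append(feature_mask)

print([m.shape for m in feature_masks])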
@@ -547,7 +526,8 @@ class Zipformer2EncoderLayer(nn.Module):
        attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
        conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
        const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0),
        ff2_skip_rate: FloatLike = 0.01,
        ff2_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)),
        ff3_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)),
        bypass_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.2), default=0),
        bypass_max: FloatLike = 1.0,
    ) -> None:
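FloatLike values such as ff2_skip_rate and ff3_skip_rate are schedules over the training batch count: each (batch_count, value) pair is a breakpoint, values are interpolated linearly in between and held flat outside the range. The following is a self-contained sketch of that interpretation, not the actual ScheduledFloat implementation from icefall's scaling.py:

def scheduled_value(batch_count: float, points) -> float:
    # points: increasing (batch_count, value) breakpoints,
    # e.g. ((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)).
    if batch_count <= points[0][0]:
        return points[0][1]
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        if batch_count <= x1:
            return y0 + (batch_count - x0) / (x1 - x0) * (y1 - y0)
    return points[-1][1]

ff2_schedule = ((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0))
print(scheduled_value(0.0, ff2_schedule))       # 0.1
print(scheduled_value(2000.0, ff2_schedule))    # 0.055
print(scheduled_value(100000.0, ff2_schedule))  # 0.0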
@@ -565,6 +545,7 @@ class Zipformer2EncoderLayer(nn.Module):
        # ff2_skip_rate is to prevent the ff2 module from having output that's too big
        # compared to its residual.
        self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate)
        self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate)

        # min and max for self.bypass_scale, applied with probability 0.5 to avoid grads
        # ever becoming zero.
@@ -578,7 +559,10 @@ class Zipformer2EncoderLayer(nn.Module):
            dropout=0.0,
        )

        self.self_attn = SelfAttention(embed_dim, num_heads,
        self.self_attn1 = SelfAttention(embed_dim, num_heads,
                                        value_head_dim)

        self.self_attn2 = SelfAttention(embed_dim, num_heads,
                                        value_head_dim)

        self.feed_forward1 = FeedforwardModule(embed_dim,
@@ -586,16 +570,24 @@ class Zipformer2EncoderLayer(nn.Module):
                                               dropout)

        self.feed_forward2 = FeedforwardModule(embed_dim,
                                               feedforward_dim,
                                               dropout)

        self.feed_forward3 = FeedforwardModule(embed_dim,
                                               (feedforward_dim * 5) // 4,
                                               dropout)

        self.nonlin_attention = NonlinAttention(embed_dim,
                                                hidden_channels=3 * embed_dim // 4)

        self.conv_module = ConvolutionModule(embed_dim,
        self.conv_module1 = ConvolutionModule(embed_dim,
                                              cnn_module_kernel,
                                              causal=causal)

        self.conv_module2 = ConvolutionModule(embed_dim,
                                              cnn_module_kernel,
                                              causal=causal)


        #self.attention_squeeze = AttentionSqueeze(embed_dim, embed_dim // 2)

@@ -617,14 +609,6 @@ class Zipformer2EncoderLayer(nn.Module):
            prob=0.05,  # out of concern for memory usage
        )

        # balancer for output of AttentionSqueezeModule
        self.balancer_as = Balancer(
            embed_dim, channel_dim=-1,
            min_positive=0.3, max_positive=0.7,
            min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)),
            prob=0.05,  # out of concern for memory usage
        )

        # balancer for output of feedforward2, prevent it from staying too
        # small. give this a very small probability, even at the start of
        # training, it's to fix a rare problem and it's OK to fix it slowly.
@@ -636,10 +620,19 @@ class Zipformer2EncoderLayer(nn.Module):
            prob=0.05,
        )

        self.balancer_ff3 = Balancer(
            embed_dim, channel_dim=-1,
            min_positive=0.3, max_positive=0.7,
            min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0),
            max_abs=4.0,
            prob=0.05,
        )

        self.whiten = Whiten(num_groups=1,
                             whitening_limit=_whitening_schedule(4.0, ratio=3.0),
                             prob=(0.025, 0.25),
                             grad_scale=0.01)

        self.balancer2 = Balancer(
            embed_dim, channel_dim=-1,
            min_positive=0.45, max_positive=0.55,
@@ -647,9 +640,6 @@ class Zipformer2EncoderLayer(nn.Module):
        )


    def remove_attention_weights(self):
        self.self_attn_weights = None

    def get_bypass_scale(self, batch_size: int):
        # returns bypass-scale of shape (num_channels,),
        # or (batch_size, num_channels,).  This is actually the
@@ -696,8 +686,7 @@ class Zipformer2EncoderLayer(nn.Module):
        chunk_size: int = -1,
        attn_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        attn_weights: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
    ) -> Tensor:
        """
        Pass the input through the encoder layer.
        Args:
@@ -713,8 +702,7 @@ class Zipformer2EncoderLayer(nn.Module):
            masked position.  May be None.

        Returns:
           (x, attn_weights) where x has the same shape as src, and attn_weights are of
           shape (num_heads, batch_size, seq_len, seq_len).
           A tensor which has the same shape as src
        """
        src_orig = src

@@ -722,24 +710,17 @@ class Zipformer2EncoderLayer(nn.Module):
        attention_skip_rate = float(self.attention_skip_rate) if self.training else 0.0

        # attn_weights: (num_heads, batch_size, seq_len, seq_len)
        if self.self_attn_weights is not None:
            attn_weights = self.self_attn_weights(
                src,
                pos_emb=pos_emb,
                attn_mask=attn_mask,
                key_padding_mask=src_key_padding_mask,
        attn_weights = self.self_attn_weights(
            src,
            pos_emb=pos_emb,
            attn_mask=attn_mask,
            key_padding_mask=src_key_padding_mask,
        )
            # else rely on the ones passed in

        # use different heads for nonlin_attention and attention_squeeze, depending
        # whether this module has its on self_attn_weights submodule or is borrowing
        # attention weights from another one.
        head_offset = 0 if self.self_attn_weights is not None else 2

        self_attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate)

        if True:
            selected_attn_weights = attn_weights[head_offset:head_offset+2]
            selected_attn_weights = attn_weights[0:2]
            if random.random() < float(self.const_attention_rate):
                # Make attention weights constant. The intention is to
                # encourage these modules to do something similar to an
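self_attn_dropout_mask comes from get_sequence_dropout_mask: during training, with probability attention_skip_rate the whole branch output is zeroed for a given sequence, and when the rate is 0 the helper returns None so the branch is added unmodified. A hedged sketch of that idea (shapes chosen so the mask broadcasts over time and channels; not the exact icefall implementation):

import torch

def sequence_dropout_mask(x: torch.Tensor, dropout_rate: float):
    # x: (seq_len, batch_size, embed_dim). Returns None when no dropout is applied,
    # otherwise a per-sequence 0/1 mask that broadcasts over time and channels.
    if dropout_rate == 0.0:
        return None
    batch_size = x.shape[1]
    return (torch.rand(1, batch_size, 1, device=x.device) > dropout_rate).to(x.dtype)

src = torch.randn(10, 2, 4)
branch = torch.randn(10, 2, 4)
mask = sequence_dropout_mask(src, dropout_rate=0.05)
src = src + (branch if mask is None else branch * mask)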
@@ -753,21 +734,38 @@ class Zipformer2EncoderLayer(nn.Module):

        na = self.balancer_na(self.nonlin_attention(src,
                                                    selected_attn_weights[0:1]))

        src = src + (na if self_attn_dropout_mask is None else na * self_attn_dropout_mask)

        src = src + self.feed_forward1(src)

        self_attn = self.self_attn(
        self_attn = self.self_attn1(
            src, attn_weights)

        src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)

        src = src + self.sequence_dropout(self.conv_module(src, chunk_size=chunk_size,
                                                           src_key_padding_mask=src_key_padding_mask),
        src = src + self.sequence_dropout(self.conv_module1(src, chunk_size=chunk_size,
                                                            src_key_padding_mask=src_key_padding_mask),
                                          float(self.conv_skip_rate))

        src = src + self.sequence_dropout(self.balancer_ff2(self.feed_forward2(src)),
                                          float(self.ff2_skip_rate))

        self_attn = self.self_attn2(
            src, attn_weights)

        src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)

        src = src + self.sequence_dropout(self.conv_module2(src, chunk_size=chunk_size,
                                                            src_key_padding_mask=src_key_padding_mask),
                                          float(self.conv_skip_rate))


        src = src + self.sequence_dropout(self.balancer_ff3(self.feed_forward3(src)),
                                          float(self.ff3_skip_rate))



        src = self.balancer1(src)
        src = self.norm(src)
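After this hunk, one layer runs what used to be two layers' worth of residual branches, and the single set of attention weights computed by self_attn_weights is consumed twice (by self_attn1 and self_attn2). A toy, runnable sketch of that structure with stand-in submodules (none of these classes are the real zipformer modules; nonlin_attention, balancers, and the skip-rate dropout are omitted):

import torch
import torch.nn as nn

class CombinedLayerSketch(nn.Module):
    # Stand-in layer: one attention-weight computation reused by two
    # value-attention branches, interleaved with convolution and feedforward
    # residual branches, mirroring the branch order in the forward() above.
    def __init__(self, d: int = 64):
        super().__init__()
        self.qk = nn.Linear(d, 2 * d)        # stand-in for self_attn_weights
        self.attn1_value = nn.Linear(d, d)   # stand-in for self_attn1
        self.attn2_value = nn.Linear(d, d)   # stand-in for self_attn2
        self.conv1 = nn.Conv1d(d, d, kernel_size=3, padding=1, groups=d)
        self.conv2 = nn.Conv1d(d, d, kernel_size=3, padding=1, groups=d)
        self.ff1 = nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d))
        self.ff2 = nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d))
        self.ff3 = nn.Sequential(nn.Linear(d, 5 * d), nn.ReLU(), nn.Linear(5 * d, d))

    def attn_weights(self, x):               # x: (seq, batch, d) -> (batch, seq, seq)
        q, k = self.qk(x).chunk(2, dim=-1)
        return (q.permute(1, 0, 2) @ k.permute(1, 2, 0)).softmax(dim=-1)

    def apply_attn(self, value_proj, x, w):  # weighted sum of projected values
        v = value_proj(x).permute(1, 0, 2)   # (batch, seq, d)
        return (w @ v).permute(1, 0, 2)      # back to (seq, batch, d)

    def conv(self, module, x):               # Conv1d wants (batch, d, seq)
        return module(x.permute(1, 2, 0)).permute(2, 0, 1)

    def forward(self, x):
        w = self.attn_weights(x)             # computed once per layer ...
        x = x + self.ff1(x)
        x = x + self.apply_attn(self.attn1_value, x, w)
        x = x + self.conv(self.conv1, x)
        x = x + self.ff2(x)
        x = x + self.apply_attn(self.attn2_value, x, w)   # ... and reused here
        x = x + self.conv(self.conv2, x)
        x = x + self.ff3(x)
        return x

print(CombinedLayerSketch()(torch.randn(10, 2, 64)).shape)  # torch.Size([10, 2, 64])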
@@ -779,7 +777,7 @@ class Zipformer2EncoderLayer(nn.Module):
        src = self.balancer2(src)
        src = self.whiten(src)

        return src, attn_weights
        return src

class Zipformer2Encoder(nn.Module):
    r"""Zipformer2Encoder is a stack of N encoder layers
@@ -805,7 +803,6 @@ class Zipformer2Encoder(nn.Module):
        warmup_end: float,
        initial_layerdrop_rate: float = 0.5,
        final_layerdrop_rate: float = 0.05,
        attention_share_layers: int = 1,
    ) -> None:
        super().__init__()
        self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.15,
@@ -827,8 +824,6 @@ class Zipformer2Encoder(nn.Module):
                                             (cur_end, final_layerdrop_rate),
                                             default=0.0)
            cur_begin = cur_end
            if i % attention_share_layers != 0:
                self.layers[i].remove_attention_weights()

    def forward(
        self,
@@ -860,16 +855,13 @@ class Zipformer2Encoder(nn.Module):

        output = output * feature_mask

        attn_weights = None

        for i, mod in enumerate(self.layers):
            output, attn_weights = mod(
            output = mod(
                output,
                pos_emb,
                chunk_size=chunk_size,
                attn_mask=attn_mask,
                src_key_padding_mask=src_key_padding_mask,
                attn_weights=attn_weights,
            )

            output = output * feature_mask

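With attention weights now private to each layer, the Zipformer2Encoder loop no longer threads attn_weights from one layer to the next; each layer just returns its output, which is re-masked after every step as shown above. A tiny stand-in sketch of the simplified loop (LayerStandIn is not the real Zipformer2EncoderLayer, and pos_emb is omitted):

import torch
import torch.nn as nn

class LayerStandIn(nn.Module):
    def __init__(self, d: int):
        super().__init__()
        self.ff = nn.Linear(d, d)

    def forward(self, x, chunk_size=-1, attn_mask=None, src_key_padding_mask=None):
        # A real layer would compute and consume its own attention weights here.
        return x + self.ff(x)

layers = nn.ModuleList(LayerStandIn(8) for _ in range(3))
output = torch.randn(10, 2, 8)           # (seq_len, batch_size, embed_dim)
feature_mask = torch.ones(1, 2, 8)

output = output * feature_mask
for mod in layers:
    output = mod(output, chunk_size=-1, attn_mask=None, src_key_padding_mask=None)
    output = output * feature_mask       # reapplied after every layer, as in the diff
print(output.shape)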