Simplify the warmup code; max_abs 10->6

Daniel Povey 2022-03-24 15:06:06 +08:00
parent aab72bc2a5
commit 1f548548d2


@@ -186,7 +186,7 @@ class ConformerEncoderLayer(nn.Module):
         self.balancer = ActivationBalancer(channel_dim=-1,
                                            min_positive=0.45,
                                            max_positive=0.55,
-                                           max_abs=10.0)
+                                           max_abs=6.0)

         self.dropout = nn.Dropout(dropout)
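For context: ActivationBalancer (from icefall's scaling.py) behaves as the identity in the forward pass and only modifies gradients in the backward pass, nudging activation statistics back into the configured range; max_abs caps how large the activations' absolute values may grow before that correction engages, and this commit tightens the cap from 10.0 to 6.0. A rough conceptual sketch of a max_abs-style constraint follows; the class name, the 0.01 scale, and the exact gradient rule are made up for illustration and are not icefall's implementation:

```python
import torch

class MaxAbsPenalty(torch.autograd.Function):
    """Hypothetical stand-in for ActivationBalancer's max_abs behaviour:
    identity in the forward pass, gradient nudge in the backward pass."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, max_abs: float) -> torch.Tensor:
        ctx.save_for_backward(x)
        ctx.max_abs = max_abs
        return x

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        (x,) = ctx.saved_tensors
        # Where |x| exceeds max_abs, add a small gradient component with
        # the same sign as x; a gradient-descent step then pulls those
        # activations back toward the permitted range.
        over = (x.abs() > ctx.max_abs).to(x.dtype)
        push = 0.01 * over * x.sign() * grad_out.abs()
        return grad_out + push, None

# usage: y = MaxAbsPenalty.apply(x, 6.0)
```

Lowering max_abs from 10.0 to 6.0 simply makes this kind of correction engage sooner, keeping the layer's output in a tighter range before norm_final.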
@@ -198,7 +198,6 @@ class ConformerEncoderLayer(nn.Module):
         src_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
         warmup: float = 1.0,
-        position: float = 0.0
     ) -> Tensor:
         """
         Pass the input through the encoder layer.
@@ -208,11 +207,10 @@ class ConformerEncoderLayer(nn.Module):
             pos_emb: Positional embedding tensor (required).
             src_mask: the mask for the src sequence (optional).
             src_key_padding_mask: the mask for the src keys per batch (optional).
-            warmup: controls selective activation of layers; if < 1.0, it's possible that
-                not all modules will be included.
-            position: the position of this module in the encoder stack (relates to
-                warmup); a value 0 <= position < 1.0.
+            warmup: controls selective activation of layers; if < 0.5, it's possible
+                that not all modules will be included.  Specifically, the
+                feed_forward_macaron and self_attn modules are enabled at warmup=0.0,
+                and the conv_module and feed_forward modules at warmup=0.5.

         Shape:
             src: (S, N, E).
@@ -223,7 +221,7 @@
         """
         # macaron style feed forward module
         src = torch.add(src, self.dropout(self.feed_forward_macaron(src)),
-                        alpha=(0.0 if warmup < 0.2 * (position + 1) else 1.0))
+                        alpha=(0.0 if warmup < 0.0 else 1.0))

         # multi-headed self-attention module
@@ -236,15 +234,15 @@
             key_padding_mask=src_key_padding_mask,
         )[0]
         src = torch.add(src, self.dropout(src_att),
-                        alpha=(0.0 if warmup < 0.2 * (position + 2) else 1.0))
+                        alpha=(0.0 if warmup < 0.0 else 1.0))

         # convolution module
         src = torch.add(src, self.dropout(self.conv_module(src)),
-                        alpha=(0.0 if warmup < 0.2 * (position + 3) else 1.0))
+                        alpha=(0.0 if warmup < 0.5 else 1.0))

         # feed forward module
         src = torch.add(src, self.dropout(self.feed_forward(src)),
-                        alpha=(0.0 if warmup < 0.2 * (position + 4) else 1.0))
+                        alpha=(0.0 if warmup < 0.5 else 1.0))

         src = self.norm_final(self.balancer(src))
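The simplification: instead of per-module thresholds 0.2 * (position + k) that depended on both the layer's position and the branch index k, every branch now compares warmup against one of two fixed thresholds, 0.0 for feed_forward_macaron and self_attn, 0.5 for conv_module and feed_forward. Each branch is still evaluated, but torch.add(..., alpha=0.0) discards its output (and zeroes the gradient flowing into it) until the layer's warmup value reaches the threshold. A minimal sketch of this gating pattern, with a hypothetical helper name:

```python
import torch

def gated_residual(src: torch.Tensor, branch_out: torch.Tensor,
                   warmup: float, threshold: float) -> torch.Tensor:
    # Below the threshold alpha is 0.0, so the branch contributes nothing
    # and its parameters receive zero gradient; at or above the threshold
    # this is an ordinary residual connection.
    alpha = 0.0 if warmup < threshold else 1.0
    return torch.add(src, branch_out, alpha=alpha)

# Thresholds used in this commit:
#   feed_forward_macaron, self_attn  -> 0.0
#   conv_module, feed_forward        -> 0.5
```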
@@ -311,8 +309,7 @@ class ConformerEncoder(nn.Module):
                 pos_emb,
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
-                warmup=warmup,
-                position=(i / num_layers),
+                warmup=warmup - 0.5 * (i / num_layers),
             )

         return output
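With the position argument gone, the layer index now enters through the warmup value itself: layer i of num_layers receives warmup - 0.5 * (i / num_layers), so deeper layers sit further below the 0.0 and 0.5 thresholds and switch their modules on later as the global warmup ramps up. A small illustration, assuming a global warmup that ramps from 0.0 to 1.0 (the 0.6 snapshot value is arbitrary):

```python
num_layers = 12
global_warmup = 0.6  # example snapshot part-way through the warmup ramp

for i in range(num_layers):
    layer_warmup = global_warmup - 0.5 * (i / num_layers)
    attn_on = layer_warmup >= 0.0  # feed_forward_macaron / self_attn
    conv_on = layer_warmup >= 0.5  # conv_module / feed_forward
    print(f"layer {i:2d}: warmup={layer_warmup:+.3f}  "
          f"attn={attn_on}  conv={conv_on}")
```

At this snapshot the first few layers run all four modules, the rest run only attention and the macaron feed-forward, and the whole stack fills in front to back as global_warmup approaches 1.0.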