Simplify the warmup code; max_abs 10->6

Daniel Povey 2022-03-24 15:06:06 +08:00
parent aab72bc2a5
commit 1f548548d2


@@ -186,7 +186,7 @@ class ConformerEncoderLayer(nn.Module):
         self.balancer = ActivationBalancer(channel_dim=-1,
                                            min_positive=0.45,
                                            max_positive=0.55,
-                                           max_abs=10.0)
+                                           max_abs=6.0)
 
         self.dropout = nn.Dropout(dropout)
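Note on the max_abs change: ActivationBalancer discourages activations whose magnitude exceeds max_abs in the input to norm_final (see the hunk at line 236 below), and this hunk tightens that limit from 10 to 6. The sketch below only illustrates the general idea of a max_abs constraint, assuming a simple gradient-level penalty; it is not the actual ActivationBalancer implementation in icefall, which also balances the fraction of positive values via min_positive/max_positive and uses different scaling details.

    import torch

    class MaxAbsPenalty(torch.autograd.Function):
        # Identity in the forward pass; the backward pass adds a small gradient
        # term that pushes activations whose magnitude exceeds max_abs back
        # toward zero.  Illustrative only.

        @staticmethod
        def forward(ctx, x, max_abs=6.0, scale=0.01):
            ctx.save_for_backward(x)
            ctx.max_abs = max_abs
            ctx.scale = scale
            return x

        @staticmethod
        def backward(ctx, grad_output):
            (x,) = ctx.saved_tensors
            too_large = (x.abs() > ctx.max_abs).to(grad_output.dtype)
            # An extra gradient with the same sign as x shrinks oversized
            # activations at the next optimizer step.
            penalty = ctx.scale * too_large * x.sign() * grad_output.abs()
            return grad_output + penalty, None, None

MaxAbsPenalty.apply(x) behaves exactly like x in the forward pass, so lowering the cap (here from 10.0 to 6.0) only changes how strongly out-of-range activations are discouraged during training.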
@@ -198,7 +198,6 @@ class ConformerEncoderLayer(nn.Module):
         src_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
         warmup: float = 1.0,
-        position: float = 0.0
     ) -> Tensor:
         """
         Pass the input through the encoder layer.
@@ -208,11 +207,10 @@ class ConformerEncoderLayer(nn.Module):
             pos_emb: Positional embedding tensor (required).
             src_mask: the mask for the src sequence (optional).
             src_key_padding_mask: the mask for the src keys per batch (optional).
-            warmup: controls selective activation of layers; if < 1.0, it's possible that
-               not all modules will be included.
-            position: the position of this module in the encoder stack (relates to
-               warmup); a value 0 <= position < 1.0.
+            warmup: controls selective activation of layers; if < 0.5, it's possible that
+               not all modules will be included.  Actually we add the
+               feed_forward_macaron and self_attn modules at warmup=0.0
+               and the conv_module and feed_forward at warmup=0.5.
 
         Shape:
             src: (S, N, E).
@@ -223,7 +221,7 @@ class ConformerEncoderLayer(nn.Module):
         """
 
         # macaron style feed forward module
         src = torch.add(src, self.dropout(self.feed_forward_macaron(src)),
-                        alpha=(0.0 if warmup < 0.2 * (position + 1) else 1.0))
+                        alpha=(0.0 if warmup < 0.0 else 1.0))
 
         # multi-headed self-attention module
@@ -236,15 +234,15 @@ class ConformerEncoderLayer(nn.Module):
             key_padding_mask=src_key_padding_mask,
         )[0]
         src = torch.add(src, self.dropout(src_att),
-                        alpha=(0.0 if warmup < 0.2 * (position + 2) else 1.0))
+                        alpha=(0.0 if warmup < 0.0 else 1.0))
 
         # convolution module
         src = torch.add(src, self.dropout(self.conv_module(src)),
-                        alpha=(0.0 if warmup < 0.2 * (position + 3) else 1.0))
+                        alpha=(0.0 if warmup < 0.5 else 1.0))
 
         # feed forward module
         src = torch.add(src, self.dropout(self.feed_forward(src)),
-                        alpha=(0.0 if warmup < 0.2 * (position + 4) else 1.0))
+                        alpha=(0.0 if warmup < 0.5 else 1.0))
 
         src = self.norm_final(self.balancer(src))
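Each residual branch above is gated by an alpha that is either 0.0 (the branch's output is not added) or 1.0 (a normal residual connection); the old per-position thresholds 0.2 * (position + k) are replaced by fixed thresholds of 0.0 for the macaron feed-forward and self-attention and 0.5 for the convolution and final feed-forward. A minimal sketch of that gating pattern, with dropout omitted and `branches` as a hypothetical stand-in for the layer's four modules in order:

    import torch

    def gated_residuals(src: torch.Tensor, branches, warmup: float) -> torch.Tensor:
        # Thresholds as in the diff: feed_forward_macaron and self_attn gate at
        # 0.0, conv_module and feed_forward gate at 0.5.
        thresholds = [0.0, 0.0, 0.5, 0.5]
        for branch, threshold in zip(branches, thresholds):
            alpha = 0.0 if warmup < threshold else 1.0
            # The branch is still evaluated even when alpha is 0.0; only its
            # contribution to the residual sum is dropped.
            src = torch.add(src, branch(src), alpha=alpha)
        return src

So for a layer-level warmup in [0.0, 0.5) only the macaron feed-forward and self-attention branches contribute, and for warmup >= 0.5 all four do, matching the updated docstring.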
@@ -311,8 +309,7 @@ class ConformerEncoder(nn.Module):
                 pos_emb,
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
-                warmup=warmup,
-                position=(i / num_layers),
+                warmup=warmup-0.5*(i / num_layers)
             )
 
         return output
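Rather than passing a separate position argument, the encoder now folds the layer index into the warmup value itself: layer i receives warmup - 0.5 * (i / num_layers), so deeper layers cross the 0.0 and 0.5 thresholds later in training (and can even see a negative effective warmup, in which case all four residual branches are bypassed). A small sketch of the resulting schedule; num_layers = 12 and the sampled warmup values are arbitrary illustrations, not values from the commit:

    num_layers = 12
    for warmup in (0.0, 0.25, 0.5, 1.0):
        for i in range(num_layers):
            eff = warmup - 0.5 * (i / num_layers)
            macaron_and_attn = eff >= 0.0  # branches gated at threshold 0.0
            conv_and_ff = eff >= 0.5       # branches gated at threshold 0.5
            print(f"warmup={warmup:.2f} layer={i:2d} eff={eff:+.2f} "
                  f"macaron/attn={macaron_and_attn} conv/ff={conv_and_ff}")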