mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 01:52:41 +00:00)

Simplify the warmup code; max_abs 10->6

parent aab72bc2a5
commit 1f548548d2
```diff
@@ -186,7 +186,7 @@ class ConformerEncoderLayer(nn.Module):
         self.balancer = ActivationBalancer(channel_dim=-1,
                                            min_positive=0.45,
                                            max_positive=0.55,
-                                           max_abs=10.0)
+                                           max_abs=6.0)
 
         self.dropout = nn.Dropout(dropout)
 
```
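ActivationBalancer (defined in icefall's scaling.py) is an identity in the forward pass that adjusts gradients in the backward pass when activation statistics drift outside the configured bounds, so lowering max_abs from 10.0 to 6.0 tightens the limit on activation magnitudes. Below is a heavily simplified, hypothetical sketch of just the max_abs mechanism; the real class also enforces min_positive/max_positive and uses its own gradient scaling.

```python
import torch

class MaxAbsLimiter(torch.autograd.Function):
    """Identity forward; backward nudges activations toward zero wherever
    their mean absolute value exceeds max_abs. Illustrative only -- not
    the icefall implementation."""

    @staticmethod
    def forward(ctx, x, channel_dim: int = -1, max_abs: float = 6.0,
                grad_scale: float = 0.01):
        ctx.save_for_backward(x)
        ctx.channel_dim, ctx.max_abs, ctx.grad_scale = channel_dim, max_abs, grad_scale
        return x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # Mean |x| per channel, averaged over all non-channel dims.
        dims = [d for d in range(x.dim()) if d != ctx.channel_dim % x.dim()]
        too_big = x.abs().mean(dim=dims, keepdim=True) > ctx.max_abs
        # Adding +grad_scale * sign(x) to the gradient shrinks |x| under SGD.
        extra = ctx.grad_scale * x.sign() * too_big
        return grad_output + extra, None, None, None
```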
```diff
@@ -198,7 +198,6 @@ class ConformerEncoderLayer(nn.Module):
         src_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
         warmup: float = 1.0,
-        position: float = 0.0
     ) -> Tensor:
         """
         Pass the input through the encoder layer.
```
```diff
@@ -208,11 +207,10 @@ class ConformerEncoderLayer(nn.Module):
             pos_emb: Positional embedding tensor (required).
             src_mask: the mask for the src sequence (optional).
             src_key_padding_mask: the mask for the src keys per batch (optional).
-            warmup: controls selective activation of layers; if < 1.0, it's possible that
-               not all modules will be included.
-            position: the position of this module in the encoder stack (relates to
-               warmup); a value 0 <= position < 1.0.
-
+            warmup: controls selective activation of layers; if < 0.5, it's possible that
+               not all modules will be included.  Actually we add the
+               feed_forward_macaron and self_attn modules at warmup=0.0
+               and the conv_module and feed_forward at warmup=0.5.
 
         Shape:
             src: (S, N, E).
```
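The new docstring implies a simple two-threshold schedule. A standalone sketch (not from the commit) that tabulates which submodules contribute at a given effective warmup value:

```python
def active_modules(warmup: float) -> list:
    """Submodules whose output is kept, per the simplified rule:
    feed_forward_macaron and self_attn at warmup >= 0.0,
    conv_module and feed_forward at warmup >= 0.5."""
    mods = []
    if warmup >= 0.0:
        mods += ["feed_forward_macaron", "self_attn"]
    if warmup >= 0.5:
        mods += ["conv_module", "feed_forward"]
    return mods

for w in (-0.1, 0.0, 0.4, 0.5, 1.0):
    print(f"warmup={w:5.2f}: {active_modules(w)}")
```

Negative effective warmup values can occur for deeper layers once the per-layer offset below is applied; at those values every submodule is bypassed and the layer reduces to norm_final(balancer(src)).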
```diff
@@ -223,7 +221,7 @@ class ConformerEncoderLayer(nn.Module):
         """
         # macaron style feed forward module
         src = torch.add(src, self.dropout(self.feed_forward_macaron(src)),
-                        alpha=(0.0 if warmup < 0.2 * (position + 1) else 1.0))
+                        alpha=(0.0 if warmup < 0.0 else 1.0))
 
 
         # multi-headed self-attention module
```
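The bypass relies on the alpha argument of torch.add, which scales the second operand before the addition, so alpha=0.0 discards the module's contribution entirely. A quick check with hypothetical shapes:

```python
import torch

src = torch.randn(4, 2, 8)    # (S, N, E)
delta = torch.randn(4, 2, 8)  # stand-in for a submodule's output

# torch.add(input, other, alpha=a) computes input + a * other.
assert torch.allclose(torch.add(src, delta, alpha=0.0), src)
assert torch.allclose(torch.add(src, delta, alpha=1.0), src + delta)
```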
```diff
@@ -236,15 +234,15 @@ class ConformerEncoderLayer(nn.Module):
             key_padding_mask=src_key_padding_mask,
         )[0]
         src = torch.add(src, self.dropout(src_att),
-                        alpha=(0.0 if warmup < 0.2 * (position + 2) else 1.0))
+                        alpha=(0.0 if warmup < 0.0 else 1.0))
 
         # convolution module
         src = torch.add(src, self.dropout(self.conv_module(src)),
-                        alpha=(0.0 if warmup < 0.2 * (position + 3) else 1.0))
+                        alpha=(0.0 if warmup < 0.5 else 1.0))
 
         # feed forward module
         src = torch.add(src, self.dropout(self.feed_forward(src)),
-                        alpha=(0.0 if warmup < 0.2 * (position + 4) else 1.0))
+                        alpha=(0.0 if warmup < 0.5 else 1.0))
 
         src = self.norm_final(self.balancer(src))
 
```
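Note that a bypassed submodule still runs in the forward pass; alpha=0.0 only zeros its contribution, which in turn zeros the gradients reaching its parameters. A small demonstration (hypothetical shapes, nn.Linear standing in for a submodule):

```python
import torch
import torch.nn as nn

ff = nn.Linear(8, 8)                      # stand-in for feed_forward
src = torch.randn(4, 2, 8)                # (S, N, E)

out = torch.add(src, ff(src), alpha=0.0)  # module executes, output discarded
out.sum().backward()

print(ff.weight.grad.abs().max())         # tensor(0.) -- no learning signal
```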
```diff
@@ -311,8 +309,7 @@ class ConformerEncoder(nn.Module):
                 pos_emb,
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
-                warmup=warmup,
-                position=(i / num_layers),
+                warmup=warmup-0.5*(i / num_layers)
             )
 
         return output
```
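With position gone, the staggering across depth is folded into the warmup argument itself: layer i sees warmup - 0.5 * (i / num_layers), so deeper layers switch their modules on later in training. A sketch with hypothetical depth and warmup values:

```python
num_layers = 12  # hypothetical depth

def effective_warmup(warmup: float, i: int) -> float:
    # Per-layer offset applied by ConformerEncoder after this commit.
    return warmup - 0.5 * (i / num_layers)

for warmup in (0.25, 0.5, 0.75, 1.0):
    fully_on = sum(effective_warmup(warmup, i) >= 0.5
                   for i in range(num_layers))
    print(f"warmup={warmup:.2f}: {fully_on:2d}/{num_layers} layers "
          f"run all four modules")
```

At warmup=1.0 every layer clears both thresholds, so a fully warmed-up model behaves identically under the old and new schemes; the schedules differ only during the early, partially-activated phase.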