diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 83bcc3f3e..a81777353 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -186,7 +186,7 @@ class ConformerEncoderLayer(nn.Module): self.balancer = ActivationBalancer(channel_dim=-1, min_positive=0.45, max_positive=0.55, - max_abs=10.0) + max_abs=6.0) self.dropout = nn.Dropout(dropout) @@ -198,7 +198,6 @@ class ConformerEncoderLayer(nn.Module): src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, warmup: float = 1.0, - position: float = 0.0 ) -> Tensor: """ Pass the input through the encoder layer. @@ -208,11 +207,10 @@ class ConformerEncoderLayer(nn.Module): pos_emb: Positional embedding tensor (required). src_mask: the mask for the src sequence (optional). src_key_padding_mask: the mask for the src keys per batch (optional). - warmup: controls selective activation of layers; if < 1.0, it's possible that - not all modules will be included. - position: the position of this module in the encoder stack (relates to - warmup); a value 0 <= position < 1.0. - + warmup: controls selective activation of layers; if < 0.5, it's possible that + not all modules will be included. Specifically, the + feed_forward_macaron and self_attn modules are enabled once warmup >= 0.0, + and the conv_module and feed_forward modules once warmup >= 0.5. Shape: src: (S, N, E). 
@@ -223,7 +221,7 @@ class ConformerEncoderLayer(nn.Module): """ # macaron style feed forward module src = torch.add(src, self.dropout(self.feed_forward_macaron(src)), - alpha=(0.0 if warmup < 0.2 * (position + 1) else 1.0)) + alpha=(0.0 if warmup < 0.0 else 1.0)) # multi-headed self-attention module @@ -236,15 +234,15 @@ class ConformerEncoderLayer(nn.Module): key_padding_mask=src_key_padding_mask, )[0] src = torch.add(src, self.dropout(src_att), - alpha=(0.0 if warmup < 0.2 * (position + 2) else 1.0)) + alpha=(0.0 if warmup < 0.0 else 1.0)) # convolution module src = torch.add(src, self.dropout(self.conv_module(src)), - alpha=(0.0 if warmup < 0.2 * (position + 3) else 1.0)) + alpha=(0.0 if warmup < 0.5 else 1.0)) # feed forward module src = torch.add(src, self.dropout(self.feed_forward(src)), - alpha=(0.0 if warmup < 0.2 * (position + 4) else 1.0)) + alpha=(0.0 if warmup < 0.5 else 1.0)) src = self.norm_final(self.balancer(src)) @@ -311,8 +309,7 @@ class ConformerEncoder(nn.Module): pos_emb, src_mask=mask, src_key_padding_mask=src_key_padding_mask, - warmup=warmup, - position=(i / num_layers), + warmup=warmup-0.5*(i / num_layers) ) return output