diff --git a/egs/aishell/ASR/conformer_ctc/.transformer.py.swp b/egs/aishell/ASR/conformer_ctc/.transformer.py.swp new file mode 100644 index 000000000..1554aba78 Binary files /dev/null and b/egs/aishell/ASR/conformer_ctc/.transformer.py.swp differ diff --git a/egs/aishell/ASR/conformer_ctc/transformer.py b/egs/aishell/ASR/conformer_ctc/transformer.py index d01fa29ef..1012b0921 100644 --- a/egs/aishell/ASR/conformer_ctc/transformer.py +++ b/egs/aishell/ASR/conformer_ctc/transformer.py @@ -160,6 +160,12 @@ class Transformer(nn.Module): else: self.decoder_criterion = None + self.group_num = group_num + self.group_layer_num = int(self.encoder_layers // self.group_num) + self.alpha = nn.Parameter(torch.rand(self.group_num)) + self.sigmoid = nn.Sigmoid() + self.layer_norm = nn.LayerNorm(d_model) + def forward( self, x: torch.Tensor, supervision: Optional[Supervisions] = None ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: