diff --git a/egs/librispeech/ASR/conformer_ctc2/.conformer.py.swp b/egs/librispeech/ASR/conformer_ctc2/.conformer.py.swp index f9f11eb4d..aa55848ac 100644 Binary files a/egs/librispeech/ASR/conformer_ctc2/.conformer.py.swp and b/egs/librispeech/ASR/conformer_ctc2/.conformer.py.swp differ diff --git a/egs/librispeech/ASR/conformer_ctc2/conformer.py b/egs/librispeech/ASR/conformer_ctc2/conformer.py index 6a3739841..49b45ec7f 100644 --- a/egs/librispeech/ASR/conformer_ctc2/conformer.py +++ b/egs/librispeech/ASR/conformer_ctc2/conformer.py @@ -64,6 +64,7 @@ class Conformer(Transformer): dropout: float = 0.1, layer_dropout: float = 0.075, cnn_module_kernel: int = 31, + group_num: int = 0, ) -> None: super(Conformer, self).__init__( num_features=num_features, @@ -101,12 +102,13 @@ class Conformer(Transformer): cnn_module_kernel, ) self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) - + self.group_num = group_num - self.group_layer_num = int(self.encoder_layers // self.group_num) - self.alpha = nn.Parameter(torch.rand(self.group_num)) - self.sigmoid = nn.Sigmoid() - self.layer_norm = nn.LayerNorm(d_model) + if self.group_num != 0: + self.group_layer_num = int(self.encoder_layers // self.group_num) + self.alpha = nn.Parameter(torch.rand(self.group_num)) + self.sigmoid = nn.Sigmoid() + self.layer_norm = nn.LayerNorm(d_model) def run_encoder( self, @@ -146,8 +148,10 @@ class Conformer(Transformer): x, layer_outputs = self.encoder( x, pos_emb, src_key_padding_mask=mask, warmup=warmup ) # (T, N, C) - + + if self.group_num != 0: # x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + # return x, lengths return x, mask