diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_gtrans/.conformer.py.swp b/egs/librispeech/ASR/pruned_transducer_stateless_gtrans/.conformer.py.swp
index a8959371d..d0a63c3e4 100644
Binary files a/egs/librispeech/ASR/pruned_transducer_stateless_gtrans/.conformer.py.swp and b/egs/librispeech/ASR/pruned_transducer_stateless_gtrans/.conformer.py.swp differ
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_gtrans/.conformer_randomcombine.py.swp b/egs/librispeech/ASR/pruned_transducer_stateless_gtrans/.conformer_randomcombine.py.swp
index 626ec86c5..1423504dc 100644
Binary files a/egs/librispeech/ASR/pruned_transducer_stateless_gtrans/.conformer_randomcombine.py.swp and b/egs/librispeech/ASR/pruned_transducer_stateless_gtrans/.conformer_randomcombine.py.swp differ
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_gtrans/conformer_randomcombine.py b/egs/librispeech/ASR/pruned_transducer_stateless_gtrans/conformer_randomcombine.py
index 530246877..defbdcb6e 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless_gtrans/conformer_randomcombine.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless_gtrans/conformer_randomcombine.py
@@ -136,10 +136,17 @@ class Conformer(EncoderInterface):
         )
         self._init_state: List[torch.Tensor] = [torch.empty(0)]
 
+        '''
         self.group_size = 12
         self.alpha = nn.Parameter(torch.rand(self.group_size))
         self.sigmoid = nn.Sigmoid()
         self.layer_norm = nn.LayerNorm(512)
+        '''
+        self.group_num = group_num
+        self.group_layer_num = int(self.encoder_layers // self.group_num)
+        self.alpha = nn.Parameter(torch.rand(self.group_num))
+        self.sigmoid = nn.Sigmoid()
+        self.layer_norm = nn.LayerNorm(d_model)
 
     def forward(
         self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
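
The hunk above only touches the Conformer constructor: the hard-coded group_size = 12 block is commented out with a triple-quoted string and replaced by a configurable group_num, a per-group learnable alpha, a sigmoid gate, and a LayerNorm(d_model). The forward-pass usage of these parameters is not part of this hunk, so the following is a minimal runnable sketch of one plausible usage, in which each group of group_layer_num layers runs in sequence and the group's output is blended with the group's input through sigmoid(alpha[g]) before layer normalization. The class name GroupedEncoderSketch, the use of nn.TransformerEncoderLayer as a stand-in for the conformer encoder layers, and the exact gating rule are assumptions for illustration, not code taken from the patch.

    # Hedged sketch of how the new group-level parameters might be consumed.
    # The real forward() is not shown in this diff; the gating scheme below
    # (sigmoid(alpha[g]) blending a group's output with its input, followed
    # by LayerNorm) is an assumption, not the patch's actual code.
    import torch
    import torch.nn as nn


    class GroupedEncoderSketch(nn.Module):
        def __init__(
            self, d_model: int = 512, encoder_layers: int = 12, group_num: int = 3
        ):
            super().__init__()
            self.group_num = group_num
            self.group_layer_num = int(encoder_layers // group_num)
            # One learnable mixing weight per group, matching torch.rand(group_num)
            # in the patch.
            self.alpha = nn.Parameter(torch.rand(group_num))
            self.sigmoid = nn.Sigmoid()
            self.layer_norm = nn.LayerNorm(d_model)
            # Hypothetical stand-in for the conformer encoder layers.
            self.layers = nn.ModuleList(
                nn.TransformerEncoderLayer(d_model, nhead=8)
                for _ in range(encoder_layers)
            )

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # x: (seq_len, batch, d_model)
            for g in range(self.group_num):
                group_in = x
                # Run the g-th group of consecutive layers.
                for i in range(g * self.group_layer_num, (g + 1) * self.group_layer_num):
                    x = self.layers[i](x)
                # Assumed combination rule: gate the group output against its input.
                w = self.sigmoid(self.alpha[g])
                x = self.layer_norm(w * x + (1.0 - w) * group_in)
            return x


    if __name__ == "__main__":
        model = GroupedEncoderSketch()
        out = model(torch.randn(50, 2, 512))
        print(out.shape)  # torch.Size([50, 2, 512])

Keeping a single scalar alpha per group (rather than per layer) matches the shape torch.rand(group_num) introduced by the patch, so the number of added parameters equals group_num regardless of encoder depth.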