diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py index 6f7231f4b..e2ec2e1cf 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py @@ -87,10 +87,17 @@ class Conformer(EncoderInterface): layer_dropout, cnn_module_kernel, ) + # aux_layers from 1/3 self.encoder = ConformerEncoder( encoder_layer, num_encoder_layers, - aux_layers=list(range(0, num_encoder_layers - 1, aux_layer_period)), + aux_layers=list( + range( + num_encoder_layers // 3, + num_encoder_layers - 1, + aux_layer_period, + ) + ), ) def forward( @@ -296,10 +303,10 @@ class ConformerEncoder(nn.Module): assert num_layers - 1 not in aux_layers self.aux_layers = set(aux_layers + [num_layers - 1]) - num_channels = encoder_layer.norm_final.num_channels + # num_channels = encoder_layer.norm_final.num_channels self.combiner = RandomCombine( num_inputs=len(self.aux_layers), - num_channels=num_channels, + # num_channels=num_channels, final_weight=0.5, pure_prob=0.333, stddev=2.0, @@ -1073,7 +1080,7 @@ class RandomCombine(nn.Module): def __init__( self, num_inputs: int, - num_channels: int, + # num_channels: int, final_weight: float = 0.5, pure_prob: float = 0.5, stddev: float = 2.0, @@ -1116,12 +1123,12 @@ class RandomCombine(nn.Module): assert 0 < final_weight < 1, final_weight assert num_inputs >= 1 - self.linear = nn.ModuleList( - [ - nn.Linear(num_channels, num_channels, bias=True) - for _ in range(num_inputs - 1) - ] - ) + # self.linear = nn.ModuleList( + # [ + # nn.Linear(num_channels, num_channels, bias=True) + # for _ in range(num_inputs - 1) + # ] + # ) self.num_inputs = num_inputs self.final_weight = final_weight @@ -1135,12 +1142,13 @@ class RandomCombine(nn.Module): .log() .item() ) - self._reset_parameters() - def _reset_parameters(self): - for i in range(len(self.linear)): - nn.init.eye_(self.linear[i].weight) - nn.init.constant_(self.linear[i].bias, 0.0) + # self._reset_parameters() + + # def _reset_parameters(self): + # for i in range(len(self.linear)): + # nn.init.eye_(self.linear[i].weight) + # nn.init.constant_(self.linear[i].bias, 0.0) def forward(self, inputs: List[Tensor]) -> Tensor: """Forward function. @@ -1163,7 +1171,8 @@ class RandomCombine(nn.Module): mod_inputs = [] for i in range(num_inputs - 1): - mod_inputs.append(self.linear[i](inputs[i])) + # mod_inputs.append(self.linear[i](inputs[i])) + mod_inputs.append(inputs[i]) mod_inputs.append(inputs[num_inputs - 1]) ndim = inputs[0].ndim