diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py
index 2babb94bc..66334688b 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py
@@ -306,7 +306,6 @@ class ConformerEncoder(nn.Module):
         num_channels = encoder_layer.norm_final.num_channels
         self.combiner = RandomCombine(
             num_inputs=len(self.aux_layers),
-            num_channels=num_channels,
             final_weight=0.5,
             pure_prob=0.333,
             stddev=2.0,
@@ -1081,7 +1080,6 @@ class RandomCombine(nn.Module):
     def __init__(
         self,
         num_inputs: int,
-        num_channels: int,
         final_weight: float = 0.5,
         pure_prob: float = 0.5,
         stddev: float = 2.0,
@@ -1092,8 +1090,6 @@ class RandomCombine(nn.Module):
             The number of tensor inputs, which equals the number of layers'
             outputs that are fed into this module.  E.g. in an 18-layer neural
             net if we output layers 16, 12, 18, num_inputs would be 3.
-          num_channels:
-            The number of channels on the input, e.g. 512.
           final_weight:
             The amount of weight or probability we assign to the
             final layer when randomly choosing layers or when choosing
@@ -1124,13 +1120,6 @@ class RandomCombine(nn.Module):
         assert 0 < final_weight < 1, final_weight
         assert num_inputs >= 1
 
-        self.linear = nn.ModuleList(
-            [
-                nn.Linear(num_channels, num_channels, bias=True)
-                for _ in range(num_inputs - 1)
-            ]
-        )
-
         self.num_inputs = num_inputs
         self.final_weight = final_weight
         self.pure_prob = pure_prob
@@ -1143,12 +1132,7 @@ class RandomCombine(nn.Module):
             .log()
             .item()
         )
-        self._reset_parameters()
 
-    def _reset_parameters(self):
-        for i in range(len(self.linear)):
-            nn.init.eye_(self.linear[i].weight)
-            nn.init.constant_(self.linear[i].bias, 0.0)
 
     def forward(self, inputs: List[Tensor]) -> Tensor:
         """Forward function.
@@ -1169,14 +1153,9 @@
         num_channels = inputs[0].shape[-1]
         num_frames = inputs[0].numel() // num_channels
 
-        mod_inputs = []
-        for i in range(num_inputs - 1):
-            mod_inputs.append(self.linear[i](inputs[i]))
-        mod_inputs.append(inputs[num_inputs - 1])
-
         ndim = inputs[0].ndim
         # stacked_inputs: (num_frames, num_channels, num_inputs)
-        stacked_inputs = torch.stack(mod_inputs, dim=ndim).reshape(
+        stacked_inputs = torch.stack(inputs, dim=ndim).reshape(
             (num_frames, num_channels, num_inputs)
         )
 
@@ -1290,7 +1269,6 @@ def _test_random_combine(final_weight: float, pure_prob: float, stddev: float):
     num_channels = 50
     m = RandomCombine(
         num_inputs=num_inputs,
-        num_channels=num_channels,
        final_weight=final_weight,
        pure_prob=pure_prob,
        stddev=stddev,
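
After this change RandomCombine no longer takes a num_channels argument and applies no per-layer nn.Linear before stacking the auxiliary-layer outputs; the inputs are stacked and combined directly. A minimal usage sketch of the new constructor signature (shapes and values are illustrative, not taken from the patch; the import path assumes the recipe's conformer.py is importable):

    import torch
    from conformer import RandomCombine  # assumed import path for this recipe

    combiner = RandomCombine(
        num_inputs=3,       # e.g. outputs taken from three encoder layers
        final_weight=0.5,
        pure_prob=0.333,
        stddev=2.0,
    )

    # Three layer outputs of shape (seq_len, batch, num_channels); with this
    # patch they are stacked directly, with no per-input linear projection.
    inputs = [torch.randn(100, 8, 512) for _ in range(3)]
    out = combiner(inputs)  # should have the same shape as each input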