diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 03a47927f..528cc48f4 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -83,7 +83,7 @@ class Conformer(EncoderInterface): aux_layers=list(range(0, num_encoder_layers-1, aux_layer_period))) if output_dim == d_model: - self.encoder_output_layer = Identity() + self.encoder_output_layer = nn.Identity() else: self.encoder_output_layer = ScaledLinear(d_model, output_dim, initial_speed=0.5) @@ -936,10 +936,6 @@ class ConvolutionModule(nn.Module): return x.permute(2, 0, 1) -class Identity(torch.nn.Module): - def forward(self, x: Tensor) -> Tensor: - return x - class Conv2dSubsampling(nn.Module): """Convolutional 2D subsampling (to 1/4 length).