diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py index 64752b9a0..a1226f712 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py @@ -47,7 +47,7 @@ class Joiner(nn.Module): Return a tensor of shape (N, T, s_range, C). """ assert encoder_out.ndim == decoder_out.ndim == 4 - assert encoder_out.shape == decoder_out.shape + assert encoder_out.shape[:-1] == decoder_out.shape[:-1] logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)