diff --git a/egs/librispeech/ASR/transducer_stateless2/joiner.py b/egs/librispeech/ASR/transducer_stateless2/joiner.py index b0ba7fd83..e56ba859d 100644 --- a/egs/librispeech/ASR/transducer_stateless2/joiner.py +++ b/egs/librispeech/ASR/transducer_stateless2/joiner.py @@ -30,8 +30,6 @@ class Joiner(nn.Module): self, encoder_out: torch.Tensor, decoder_out: torch.Tensor, - encoder_out_len: torch.Tensor, - decoder_out_len: torch.Tensor, ) -> torch.Tensor: """ Args: @@ -39,40 +37,17 @@ class Joiner(nn.Module): Output from the encoder. Its shape is (N, T, self.input_dim). decoder_out: Output from the decoder. Its shape is (N, U, self.input_dim). - encoder_out_len: - A 1-D tensor of shape (N,) containing valid number of frames - before padding in `encoder_out`. - decoder_out_len: - A 1-D tensor of shape (N,) containing valid number of frames - before padding in `decoder_out`. Returns: - Return a tensor of shape (sum_all_TU, self.output_dim). + Return a tensor of shape (N, T, U, self.output_dim). """ assert encoder_out.ndim == decoder_out.ndim == 3 assert encoder_out.size(0) == decoder_out.size(0) assert encoder_out.size(2) == self.input_dim assert decoder_out.size(2) == self.input_dim - N = encoder_out.size(0) - - encoder_out_len = encoder_out_len.tolist() - decoder_out_len = decoder_out_len.tolist() - - encoder_out_list = [ - encoder_out[i, : encoder_out_len[i], :] for i in range(N) - ] - - decoder_out_list = [ - decoder_out[i, : decoder_out_len[i], :] for i in range(N) - ] - - x = [ - e.unsqueeze(1) + d.unsqueeze(0) - for e, d in zip(encoder_out_list, decoder_out_list) - ] - - x = [p.reshape(-1, self.input_dim) for p in x] - x = torch.cat(x) + encoder_out = encoder_out.unsqueeze(2) # (N, T, 1, C) + decoder_out = decoder_out.unsqueeze(1) # (N, 1, U, C) + x = encoder_out + decoder_out # (N, T, U, C) activations = torch.tanh(x)