Modify the joiner network for torchaudio's RNN-T loss.

This commit is contained in:
Fangjun Kuang 2022-04-14 11:37:48 +08:00
parent 1ca7f35a1c
commit 38279d4b24

View File

@ -30,8 +30,6 @@ class Joiner(nn.Module):
self, self,
encoder_out: torch.Tensor, encoder_out: torch.Tensor,
decoder_out: torch.Tensor, decoder_out: torch.Tensor,
encoder_out_len: torch.Tensor,
decoder_out_len: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Args: Args:
@ -39,40 +37,17 @@ class Joiner(nn.Module):
Output from the encoder. Its shape is (N, T, self.input_dim). Output from the encoder. Its shape is (N, T, self.input_dim).
decoder_out: decoder_out:
Output from the decoder. Its shape is (N, U, self.input_dim). Output from the decoder. Its shape is (N, U, self.input_dim).
encoder_out_len:
A 1-D tensor of shape (N,) containing valid number of frames
before padding in `encoder_out`.
decoder_out_len:
A 1-D tensor of shape (N,) containing valid number of frames
before padding in `decoder_out`.
Returns: Returns:
Return a tensor of shape (sum_all_TU, self.output_dim). Return a tensor of shape (N, T, U, self.output_dim).
""" """
assert encoder_out.ndim == decoder_out.ndim == 3 assert encoder_out.ndim == decoder_out.ndim == 3
assert encoder_out.size(0) == decoder_out.size(0) assert encoder_out.size(0) == decoder_out.size(0)
assert encoder_out.size(2) == self.input_dim assert encoder_out.size(2) == self.input_dim
assert decoder_out.size(2) == self.input_dim assert decoder_out.size(2) == self.input_dim
N = encoder_out.size(0) encoder_out = encoder_out.unsqueeze(2) # (N, T, 1, C)
decoder_out = decoder_out.unsqueeze(1) # (N, 1, U, C)
encoder_out_len = encoder_out_len.tolist() x = encoder_out + decoder_out # (N, T, U, C)
decoder_out_len = decoder_out_len.tolist()
encoder_out_list = [
encoder_out[i, : encoder_out_len[i], :] for i in range(N)
]
decoder_out_list = [
decoder_out[i, : decoder_out_len[i], :] for i in range(N)
]
x = [
e.unsqueeze(1) + d.unsqueeze(0)
for e, d in zip(encoder_out_list, decoder_out_list)
]
x = [p.reshape(-1, self.input_dim) for p in x]
x = torch.cat(x)
activations = torch.tanh(x) activations = torch.tanh(x)