Modify the joiner network for torchaudio's RNN-T loss.

This commit is contained in:
Fangjun Kuang 2022-04-14 11:37:48 +08:00
parent 1ca7f35a1c
commit 38279d4b24

View File

@ -30,8 +30,6 @@ class Joiner(nn.Module):
self,
encoder_out: torch.Tensor,
decoder_out: torch.Tensor,
encoder_out_len: torch.Tensor,
decoder_out_len: torch.Tensor,
) -> torch.Tensor:
"""
Args:
@ -39,40 +37,17 @@ class Joiner(nn.Module):
Output from the encoder. Its shape is (N, T, self.input_dim).
decoder_out:
Output from the decoder. Its shape is (N, U, self.input_dim).
encoder_out_len:
A 1-D tensor of shape (N,) containing valid number of frames
before padding in `encoder_out`.
decoder_out_len:
A 1-D tensor of shape (N,) containing valid number of frames
before padding in `decoder_out`.
Returns:
Return a tensor of shape (sum_all_TU, self.output_dim).
Return a tensor of shape (N, T, U, self.output_dim).
"""
assert encoder_out.ndim == decoder_out.ndim == 3
assert encoder_out.size(0) == decoder_out.size(0)
assert encoder_out.size(2) == self.input_dim
assert decoder_out.size(2) == self.input_dim
N = encoder_out.size(0)
encoder_out_len = encoder_out_len.tolist()
decoder_out_len = decoder_out_len.tolist()
encoder_out_list = [
encoder_out[i, : encoder_out_len[i], :] for i in range(N)
]
decoder_out_list = [
decoder_out[i, : decoder_out_len[i], :] for i in range(N)
]
x = [
e.unsqueeze(1) + d.unsqueeze(0)
for e, d in zip(encoder_out_list, decoder_out_list)
]
x = [p.reshape(-1, self.input_dim) for p in x]
x = torch.cat(x)
encoder_out = encoder_out.unsqueeze(2) # (N, T, 1, C)
decoder_out = decoder_out.unsqueeze(1) # (N, 1, U, C)
x = encoder_out + decoder_out # (N, T, U, C)
activations = torch.tanh(x)