mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 22:54:18 +00:00
Modify the joiner network for torchaudio's RNN-T loss.
This commit is contained in:
parent
1ca7f35a1c
commit
38279d4b24
@ -30,8 +30,6 @@ class Joiner(nn.Module):
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
decoder_out: torch.Tensor,
|
||||
encoder_out_len: torch.Tensor,
|
||||
decoder_out_len: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
@ -39,40 +37,17 @@ class Joiner(nn.Module):
|
||||
Output from the encoder. Its shape is (N, T, self.input_dim).
|
||||
decoder_out:
|
||||
Output from the decoder. Its shape is (N, U, self.input_dim).
|
||||
encoder_out_len:
|
||||
A 1-D tensor of shape (N,) containing valid number of frames
|
||||
before padding in `encoder_out`.
|
||||
decoder_out_len:
|
||||
A 1-D tensor of shape (N,) containing valid number of frames
|
||||
before padding in `decoder_out`.
|
||||
Returns:
|
||||
Return a tensor of shape (sum_all_TU, self.output_dim).
|
||||
Return a tensor of shape (N, T, U, self.output_dim).
|
||||
"""
|
||||
assert encoder_out.ndim == decoder_out.ndim == 3
|
||||
assert encoder_out.size(0) == decoder_out.size(0)
|
||||
assert encoder_out.size(2) == self.input_dim
|
||||
assert decoder_out.size(2) == self.input_dim
|
||||
|
||||
N = encoder_out.size(0)
|
||||
|
||||
encoder_out_len = encoder_out_len.tolist()
|
||||
decoder_out_len = decoder_out_len.tolist()
|
||||
|
||||
encoder_out_list = [
|
||||
encoder_out[i, : encoder_out_len[i], :] for i in range(N)
|
||||
]
|
||||
|
||||
decoder_out_list = [
|
||||
decoder_out[i, : decoder_out_len[i], :] for i in range(N)
|
||||
]
|
||||
|
||||
x = [
|
||||
e.unsqueeze(1) + d.unsqueeze(0)
|
||||
for e, d in zip(encoder_out_list, decoder_out_list)
|
||||
]
|
||||
|
||||
x = [p.reshape(-1, self.input_dim) for p in x]
|
||||
x = torch.cat(x)
|
||||
encoder_out = encoder_out.unsqueeze(2) # (N, T, 1, C)
|
||||
decoder_out = decoder_out.unsqueeze(1) # (N, 1, U, C)
|
||||
x = encoder_out + decoder_out # (N, T, U, C)
|
||||
|
||||
activations = torch.tanh(x)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user