Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-07 16:14:17 +00:00)
Modify the joiner network for torchaudio's RNN-T loss.
parent 1ca7f35a1c
commit 38279d4b24
@@ -30,8 +30,6 @@ class Joiner(nn.Module):
         self,
         encoder_out: torch.Tensor,
         decoder_out: torch.Tensor,
-        encoder_out_len: torch.Tensor,
-        decoder_out_len: torch.Tensor,
     ) -> torch.Tensor:
         """
         Args:
@@ -39,40 +37,17 @@ class Joiner(nn.Module):
             Output from the encoder. Its shape is (N, T, self.input_dim).
           decoder_out:
             Output from the decoder. Its shape is (N, U, self.input_dim).
-          encoder_out_len:
-            A 1-D tensor of shape (N,) containing valid number of frames
-            before padding in `encoder_out`.
-          decoder_out_len:
-            A 1-D tensor of shape (N,) containing valid number of frames
-            before padding in `decoder_out`.
         Returns:
-          Return a tensor of shape (sum_all_TU, self.output_dim).
+          Return a tensor of shape (N, T, U, self.output_dim).
         """
         assert encoder_out.ndim == decoder_out.ndim == 3
         assert encoder_out.size(0) == decoder_out.size(0)
         assert encoder_out.size(2) == self.input_dim
         assert decoder_out.size(2) == self.input_dim
 
-        N = encoder_out.size(0)
-
-        encoder_out_len = encoder_out_len.tolist()
-        decoder_out_len = decoder_out_len.tolist()
-
-        encoder_out_list = [
-            encoder_out[i, : encoder_out_len[i], :] for i in range(N)
-        ]
-
-        decoder_out_list = [
-            decoder_out[i, : decoder_out_len[i], :] for i in range(N)
-        ]
-
-        x = [
-            e.unsqueeze(1) + d.unsqueeze(0)
-            for e, d in zip(encoder_out_list, decoder_out_list)
-        ]
-
-        x = [p.reshape(-1, self.input_dim) for p in x]
-        x = torch.cat(x)
+        encoder_out = encoder_out.unsqueeze(2)  # (N, T, 1, C)
+        decoder_out = decoder_out.unsqueeze(1)  # (N, 1, U, C)
+
+        x = encoder_out + decoder_out  # (N, T, U, C)
 
         activations = torch.tanh(x)
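
For context, a minimal sketch (not part of the commit) of the shape contract this change targets: torchaudio.functional.rnnt_loss consumes padded 4-D joiner logits of shape (N, T, U, vocab_size) together with per-utterance lengths, so the joiner no longer needs encoder_out_len / decoder_out_len to slice away padding itself. All tensors below are random stand-ins; output_linear is an assumed stand-in for the joiner's final projection, which lies outside the hunks shown above.

import torch
import torchaudio

# Dummy sizes; C is the joiner input dim, vocab_size its output dim.
N, T, U, C, vocab_size = 2, 50, 11, 4, 6  # U = max target length + 1

encoder_out = torch.randn(N, T, C)  # padded encoder output
decoder_out = torch.randn(N, U, C)  # padded decoder (prediction network) output

# What the modified joiner computes: a broadcast add instead of
# per-utterance slicing and concatenation.
x = encoder_out.unsqueeze(2) + decoder_out.unsqueeze(1)  # (N, T, U, C)
activations = torch.tanh(x)

# The projection to the vocabulary is not shown in the hunks above; a plain
# Linear layer stands in for it here.
output_linear = torch.nn.Linear(C, vocab_size)
logits = output_linear(activations)  # (N, T, U, vocab_size)

# torchaudio's RNN-T loss takes the padded 4-D logits directly, together
# with per-utterance lengths.
targets = torch.randint(1, vocab_size, (N, U - 1), dtype=torch.int32)
logit_lengths = torch.tensor([T, T - 10], dtype=torch.int32)
target_lengths = torch.tensor([U - 1, U - 3], dtype=torch.int32)

loss = torchaudio.functional.rnnt_loss(
    logits=logits,
    targets=targets,
    logit_lengths=logit_lengths,
    target_lengths=target_lengths,
    blank=0,  # blank id 0, as conventionally used in icefall
    reduction="mean",
)
print(loss)

The broadcast form trades the memory savings of the old per-utterance concatenation for the padded layout that torchaudio's loss expects, which is the point of this commit.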
|