from local

dohe0342 2023-01-09 19:21:21 +09:00
parent 9c3bcd0bf6
commit d1fffb9c5e
2 changed files with 7 additions and 10 deletions

@@ -219,20 +219,17 @@ class Interformer(nn.Module):
):
"""
Args:
encoder:
pt_encoder:
It is the transcription network in the paper. It accepts
two inputs: `x` of shape (N, T, encoder_dim) and `x_lens` of shape (N,).
It returns two tensors: `logits` of shape (N, T, encoder_dim) and
`logit_lens` of shape (N,).
decoder:
It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, decoder_dim).
It should contain one attribute: `blank_id`.
joiner:
It has two inputs with shapes: (N, T, encoder_dim) and
(N, U, decoder_dim).
Its output shape is (N, T, U, vocab_size). Note that its output
contains unnormalized scores (logits), i.e., not processed by log-softmax.
inter_encoder:
It is the transcription network in the paper. It accepts
two inputs: `x` of shape (N, T, encoder_dim) and `x_lens` of shape (N,).
It returns two tensors: `logits` of shape (N, T, encoder_dim) and
`logit_lens` of shape (N,).
"""
super().__init__()
assert isinstance(encoder, EncoderInterface), type(encoder)