from local

This commit is contained in:
dohe0342 2023-01-09 19:21:21 +09:00
parent 9c3bcd0bf6
commit d1fffb9c5e
2 changed files with 7 additions and 10 deletions

View File

@ -219,20 +219,17 @@ class Interformer(nn.Module):
): ):
""" """
Args: Args:
encoder: pt_encoder:
It is the transcription network in the paper. Its accepts It is the transcription network in the paper. Its accepts
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
It returns two tensors: `logits` of shape (N, T, encoder_dm) and It returns two tensors: `logits` of shape (N, T, encoder_dm) and
`logit_lens` of shape (N,). `logit_lens` of shape (N,).
decoder: inter_encoder:
It is the prediction network in the paper. Its input shape It is the transcription network in the paper. Its accepts
is (N, U) and its output shape is (N, U, decoder_dim). two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
It should contain one attribute: `blank_id`. It returns two tensors: `logits` of shape (N, T, encoder_dm) and
joiner: `logit_lens` of shape (N,).
It has two inputs with shapes: (N, T, encoder_dim) and
(N, U, decoder_dim).
Its output shape is (N, T, U, vocab_size). Note that its output
contains unnormalized probs, i.e., not processed by log-softmax.
""" """
super().__init__() super().__init__()
assert isinstance(encoder, EncoderInterface), type(encoder) assert isinstance(encoder, EncoderInterface), type(encoder)