diff --git a/egs/librispeech/ASR/incremental_transf/.model.py.swp b/egs/librispeech/ASR/incremental_transf/.model.py.swp index 51adf84c4..bd115db11 100644 Binary files a/egs/librispeech/ASR/incremental_transf/.model.py.swp and b/egs/librispeech/ASR/incremental_transf/.model.py.swp differ diff --git a/egs/librispeech/ASR/incremental_transf/model.py b/egs/librispeech/ASR/incremental_transf/model.py index 6083d529b..36326a044 100644 --- a/egs/librispeech/ASR/incremental_transf/model.py +++ b/egs/librispeech/ASR/incremental_transf/model.py @@ -219,20 +219,17 @@ class Interformer(nn.Module): ): """ Args: - encoder: + pt_encoder: It is the transcription network in the paper. Its accepts two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). It returns two tensors: `logits` of shape (N, T, encoder_dm) and `logit_lens` of shape (N,). - decoder: - It is the prediction network in the paper. Its input shape - is (N, U) and its output shape is (N, U, decoder_dim). - It should contain one attribute: `blank_id`. - joiner: - It has two inputs with shapes: (N, T, encoder_dim) and - (N, U, decoder_dim). - Its output shape is (N, T, U, vocab_size). Note that its output - contains unnormalized probs, i.e., not processed by log-softmax. + inter_encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dm) and + `logit_lens` of shape (N,). + """ super().__init__() assert isinstance(encoder, EncoderInterface), type(encoder)