mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
from local
This commit is contained in:
parent
9c3bcd0bf6
commit
d1fffb9c5e
Binary file not shown.
@ -219,20 +219,17 @@ class Interformer(nn.Module):
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
encoder:
|
||||
pt_encoder:
|
||||
It is the transcription network in the paper. Its accepts
|
||||
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
|
||||
It returns two tensors: `logits` of shape (N, T, encoder_dm) and
|
||||
`logit_lens` of shape (N,).
|
||||
decoder:
|
||||
It is the prediction network in the paper. Its input shape
|
||||
is (N, U) and its output shape is (N, U, decoder_dim).
|
||||
It should contain one attribute: `blank_id`.
|
||||
joiner:
|
||||
It has two inputs with shapes: (N, T, encoder_dim) and
|
||||
(N, U, decoder_dim).
|
||||
Its output shape is (N, T, U, vocab_size). Note that its output
|
||||
contains unnormalized probs, i.e., not processed by log-softmax.
|
||||
inter_encoder:
|
||||
It is the transcription network in the paper. Its accepts
|
||||
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
|
||||
It returns two tensors: `logits` of shape (N, T, encoder_dm) and
|
||||
`logit_lens` of shape (N,).
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
assert isinstance(encoder, EncoderInterface), type(encoder)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user