mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
from local
This commit is contained in:
parent
9c3bcd0bf6
commit
d1fffb9c5e
Binary file not shown.
@ -219,20 +219,17 @@ class Interformer(nn.Module):
|
|||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
encoder:
|
pt_encoder:
|
||||||
It is the transcription network in the paper. Its accepts
|
It is the transcription network in the paper. Its accepts
|
||||||
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
|
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
|
||||||
It returns two tensors: `logits` of shape (N, T, encoder_dm) and
|
It returns two tensors: `logits` of shape (N, T, encoder_dm) and
|
||||||
`logit_lens` of shape (N,).
|
`logit_lens` of shape (N,).
|
||||||
decoder:
|
inter_encoder:
|
||||||
It is the prediction network in the paper. Its input shape
|
It is the transcription network in the paper. Its accepts
|
||||||
is (N, U) and its output shape is (N, U, decoder_dim).
|
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
|
||||||
It should contain one attribute: `blank_id`.
|
It returns two tensors: `logits` of shape (N, T, encoder_dm) and
|
||||||
joiner:
|
`logit_lens` of shape (N,).
|
||||||
It has two inputs with shapes: (N, T, encoder_dim) and
|
|
||||||
(N, U, decoder_dim).
|
|
||||||
Its output shape is (N, T, U, vocab_size). Note that its output
|
|
||||||
contains unnormalized probs, i.e., not processed by log-softmax.
|
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert isinstance(encoder, EncoderInterface), type(encoder)
|
assert isinstance(encoder, EncoderInterface), type(encoder)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user