mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Fix decoding.
This commit is contained in:
parent
ce5670f39e
commit
da98aa1b8e
@ -51,6 +51,7 @@ from conformer import Conformer
|
||||
from decoder import Decoder
|
||||
from joiner import Joiner
|
||||
from model import Transducer
|
||||
from transformer import Transformer
|
||||
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
from icefall.env import get_env_info
|
||||
@ -128,6 +129,13 @@ def get_parser():
|
||||
help="Maximum number of symbols per frame",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-type",
|
||||
type=str,
|
||||
default="conformer",
|
||||
help="Type of the encoder. Valid values are: conformer and transformer",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -150,8 +158,16 @@ def get_params() -> AttributeDict:
|
||||
|
||||
|
||||
def get_encoder_model(params: AttributeDict):
|
||||
# TODO: We can add an option to switch between Conformer and Transformer
|
||||
encoder = Conformer(
|
||||
if params.encoder_type == "conformer":
|
||||
Encoder = Conformer
|
||||
elif params.encoder_type == "transformer":
|
||||
Encoder = Transformer
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported encoder type: {params.encoder_type}"
|
||||
"\nPlease use conformer or transformer"
|
||||
)
|
||||
encoder = Encoder(
|
||||
num_features=params.feature_dim,
|
||||
output_dim=params.encoder_out_dim,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
|
Loading…
x
Reference in New Issue
Block a user