Fix decoding.

This commit is contained in:
Fangjun Kuang 2022-01-23 08:51:53 +08:00
parent ce5670f39e
commit da98aa1b8e

View File

@ -51,6 +51,7 @@ from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from model import Transducer
from transformer import Transformer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
@ -128,6 +129,13 @@ def get_parser():
help="Maximum number of symbols per frame",
)
parser.add_argument(
"--encoder-type",
type=str,
default="conformer",
help="Type of the encoder. Valid values are: conformer and transformer",
)
return parser
@ -150,8 +158,16 @@ def get_params() -> AttributeDict:
def get_encoder_model(params: AttributeDict):
# TODO: We can add an option to switch between Conformer and Transformer
encoder = Conformer(
if params.encoder_type == "conformer":
Encoder = Conformer
elif params.encoder_type == "transformer":
Encoder = Transformer
else:
raise ValueError(
f"Unsupported encoder type: {params.encoder_type}"
"\nPlease use conformer or transformer"
)
encoder = Encoder(
num_features=params.feature_dim,
output_dim=params.encoder_out_dim,
subsampling_factor=params.subsampling_factor,