Fix decoding.

This commit is contained in:
Fangjun Kuang 2022-01-23 08:51:53 +08:00
parent ce5670f39e
commit da98aa1b8e

View File

@ -51,6 +51,7 @@ from conformer import Conformer
from decoder import Decoder from decoder import Decoder
from joiner import Joiner from joiner import Joiner
from model import Transducer from model import Transducer
from transformer import Transformer
from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info from icefall.env import get_env_info
@ -128,6 +129,13 @@ def get_parser():
help="Maximum number of symbols per frame", help="Maximum number of symbols per frame",
) )
parser.add_argument(
"--encoder-type",
type=str,
default="conformer",
help="Type of the encoder. Valid values are: conformer and transformer",
)
return parser return parser
@ -150,8 +158,16 @@ def get_params() -> AttributeDict:
def get_encoder_model(params: AttributeDict): def get_encoder_model(params: AttributeDict):
# TODO: We can add an option to switch between Conformer and Transformer if params.encoder_type == "conformer":
encoder = Conformer( Encoder = Conformer
elif params.encoder_type == "transformer":
Encoder = Transformer
else:
raise ValueError(
f"Unsupported encoder type: {params.encoder_type}"
"\nPlease use conformer or transformer"
)
encoder = Encoder(
num_features=params.feature_dim, num_features=params.feature_dim,
output_dim=params.encoder_out_dim, output_dim=params.encoder_out_dim,
subsampling_factor=params.subsampling_factor, subsampling_factor=params.subsampling_factor,