diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py index 7f4d000fc..239234e2e 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py @@ -96,6 +96,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) +from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, DecodingResults, @@ -167,6 +168,13 @@ def get_parser(): help="Path to the BPE model", ) + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + parser.add_argument( "--decoding-method", type=str, @@ -286,6 +294,8 @@ def decode_one_batch( ) encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + if isinstance(encoder_out, list): + encoder_out = encoder_out[-1] # the last item is final output hyps = [] if params.decoding_method == "fast_beam_search": @@ -345,12 +355,10 @@ def decode_one_batch( res = DecodingResults(hyps=tokens, timestamps=timestamps) hyps, timestamps = parse_hyp_and_timestamp( - decoding_method=params.decoding_method, res=res, sp=sp, subsampling_factor=params.subsampling_factor, frame_shift_ms=params.frame_shift_ms, - word_table=word_table, ) if params.decoding_method == "greedy_search": @@ -533,6 +541,7 @@ def main(): args = parser.parse_args() args.exp_dir = Path(args.exp_dir) + import pdb; pdb.set_trace() params = get_params() params.update(vars(args)) @@ -669,6 +678,9 @@ def main(): else: decoding_graph = None + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}")