diff --git a/egs/iwslt22_ta/ST/zipformer/decode.py b/egs/iwslt22_ta/ST/zipformer/decode.py index 265491f0c..b13a72f1b 100755 --- a/egs/iwslt22_ta/ST/zipformer/decode.py +++ b/egs/iwslt22_ta/ST/zipformer/decode.py @@ -72,7 +72,7 @@ from beam_search import ( greedy_search_batch, modified_beam_search, ) -from train_st import add_model_arguments, get_params, get_transducer_model +from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( average_checkpoints, @@ -529,7 +529,7 @@ def decode_dataset( word_table=word_table, batch=batch, ) - #breakpoint() + for name, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) @@ -540,7 +540,7 @@ def decode_dataset( this_batch.append((cut_id, ref_words, ref_words_tgt, hyp_words)) results[name].extend(this_batch) - #breakpoint() + num_cuts += len(texts) if batch_idx % log_interval == 0: