diff --git a/egs/multi_zh_en/ASR/zipformer/decode.py b/egs/multi_zh_en/ASR/zipformer/decode.py index 6d31178ed..337e5a540 100755 --- a/egs/multi_zh_en/ASR/zipformer/decode.py +++ b/egs/multi_zh_en/ASR/zipformer/decode.py @@ -470,7 +470,6 @@ def decode_one_batch( f"Unsupported decoding method: {params.decoding_method}" ) hyps.append(smart_byte_decode(sp.decode(hyp)).split()) - if params.decoding_method == "greedy_search": return {"greedy_search": hyps} elif "fast_beam_search" in params.decoding_method: @@ -535,7 +534,9 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] - texts = [list(str(text).replace(" ", "")) for text in texts] + texts = [tokenize_by_CJK_char(str(text)).split() for text in texts] + # print(texts) + # exit() cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( @@ -551,8 +552,7 @@ def decode_dataset( this_batch = [] assert len(hyps) == len(texts) for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - hyp_text = "".join(hyp_words) - this_batch.append((cut_id, ref_text, hyp_text)) + this_batch.append((cut_id, ref_text, hyp_words)) results[name].extend(this_batch) diff --git a/egs/multi_zh_en/ASR/zipformer/train.py b/egs/multi_zh_en/ASR/zipformer/train.py index 0853291dc..8318fc1ee 100755 --- a/egs/multi_zh_en/ASR/zipformer/train.py +++ b/egs/multi_zh_en/ASR/zipformer/train.py @@ -1186,7 +1186,7 @@ def run(rank, world_size, args): # You should use ../local/display_manifest_statistics.py to get # an utterance duration distribution for your dataset to select # the threshold - if c.duration < 1.0 or c.duration > 12.0: + if c.duration < 1.0: logging.warning( f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" )