From a37408f663eb7392e6a7ea7937ff39be9c94501f Mon Sep 17 00:00:00 2001 From: jinzr Date: Thu, 9 Nov 2023 11:57:49 +0800 Subject: [PATCH] Revert "Update decode.py" This reverts commit 73e1237c2d5842ab0b0d3b5ab474c948fd8ff019. --- egs/multi_zh-hans/ASR/zipformer/decode.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/egs/multi_zh-hans/ASR/zipformer/decode.py b/egs/multi_zh-hans/ASR/zipformer/decode.py index acb70e388..2d3510fc1 100755 --- a/egs/multi_zh-hans/ASR/zipformer/decode.py +++ b/egs/multi_zh-hans/ASR/zipformer/decode.py @@ -125,7 +125,7 @@ from lhotse.cut import Cut from multi_dataset import MultiDataset from train import add_model_arguments, get_model, get_params -from icefall import ContextGraph, LmScorer, NgramLm, tokenize_by_CJK_char +from icefall import ContextGraph, LmScorer, NgramLm from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, @@ -462,7 +462,7 @@ def decode_one_batch( max_states=params.max_states, ) for hyp in sp.decode(hyp_tokens): - hyps.append(tokenize_by_CJK_char(hyp).split()) + hyps.append(hyp.split()) elif params.decoding_method == "fast_beam_search_nbest_LG": hyp_tokens = fast_beam_search_nbest_LG( model=model, @@ -490,7 +490,7 @@ def decode_one_batch( nbest_scale=params.nbest_scale, ) for hyp in sp.decode(hyp_tokens): - hyps.append(tokenize_by_CJK_char(hyp).split()) + hyps.append(hyp.split()) elif params.decoding_method == "fast_beam_search_nbest_oracle": hyp_tokens = fast_beam_search_nbest_oracle( model=model, @@ -505,7 +505,7 @@ def decode_one_batch( nbest_scale=params.nbest_scale, ) for hyp in sp.decode(hyp_tokens): - hyps.append(tokenize_by_CJK_char(hyp).split()) + hyps.append(hyp.split()) elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, @@ -513,7 +513,7 @@ def decode_one_batch( encoder_out_lens=encoder_out_lens, ) for hyp in sp.decode(hyp_tokens): - hyps.append(tokenize_by_CJK_char(hyp).split()) + hyps.append(hyp.split()) elif params.decoding_method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -523,7 +523,7 @@ def decode_one_batch( context_graph=context_graph, ) for hyp in sp.decode(hyp_tokens): - hyps.append(tokenize_by_CJK_char(hyp).split()) + hyps.append(hyp.split()) elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": hyp_tokens = modified_beam_search_lm_shallow_fusion( model=model, @@ -533,7 +533,7 @@ def decode_one_batch( LM=LM, ) for hyp in sp.decode(hyp_tokens): - hyps.append(tokenize_by_CJK_char(hyp).split()) + hyps.append(hyp.split()) elif params.decoding_method == "modified_beam_search_LODR": hyp_tokens = modified_beam_search_LODR( model=model, @@ -546,7 +546,7 @@ def decode_one_batch( context_graph=context_graph, ) for hyp in sp.decode(hyp_tokens): - hyps.append(tokenize_by_CJK_char(hyp).split()) + hyps.append(hyp.split()) elif params.decoding_method == "modified_beam_search_lm_rescore": lm_scale_list = [0.01 * i for i in range(10, 50)] ans_dict = modified_beam_search_lm_rescore( @@ -616,7 +616,7 @@ def decode_one_batch( ans = dict() assert ans_dict is not None for key, hyps in ans_dict.items(): - hyps = [tokenize_by_CJK_char(sp.decode(hyp)).split() for hyp in hyps] + hyps = [sp.decode(hyp).split() for hyp in hyps] ans[f"{prefix}_{key}"] = hyps return ans else: @@ -678,7 +678,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] - texts = [tokenize_by_CJK_char(text).split() for text in texts] + texts = [list(str(text).replace(" ", "")) for text in texts] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch(