Revert "Update decode.py"

This reverts commit 73e1237c2d5842ab0b0d3b5ab474c948fd8ff019.
This commit is contained in:
jinzr 2023-11-09 11:57:49 +08:00
parent 73e1237c2d
commit a37408f663

View File

@@ -125,7 +125,7 @@ from lhotse.cut import Cut
from multi_dataset import MultiDataset from multi_dataset import MultiDataset
from train import add_model_arguments, get_model, get_params from train import add_model_arguments, get_model, get_params
from icefall import ContextGraph, LmScorer, NgramLm, tokenize_by_CJK_char from icefall import ContextGraph, LmScorer, NgramLm
from icefall.checkpoint import ( from icefall.checkpoint import (
average_checkpoints, average_checkpoints,
average_checkpoints_with_averaged_model, average_checkpoints_with_averaged_model,
@@ -462,7 +462,7 @@ def decode_one_batch(
max_states=params.max_states, max_states=params.max_states,
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
hyps.append(tokenize_by_CJK_char(hyp).split()) hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_LG": elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG( hyp_tokens = fast_beam_search_nbest_LG(
model=model, model=model,
@@ -490,7 +490,7 @@ def decode_one_batch(
nbest_scale=params.nbest_scale, nbest_scale=params.nbest_scale,
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
hyps.append(tokenize_by_CJK_char(hyp).split()) hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_oracle": elif params.decoding_method == "fast_beam_search_nbest_oracle":
hyp_tokens = fast_beam_search_nbest_oracle( hyp_tokens = fast_beam_search_nbest_oracle(
model=model, model=model,
@@ -505,7 +505,7 @@ def decode_one_batch(
nbest_scale=params.nbest_scale, nbest_scale=params.nbest_scale,
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
hyps.append(tokenize_by_CJK_char(hyp).split()) hyps.append(hyp.split())
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
@@ -513,7 +513,7 @@ def decode_one_batch(
encoder_out_lens=encoder_out_lens, encoder_out_lens=encoder_out_lens,
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
hyps.append(tokenize_by_CJK_char(hyp).split()) hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search": elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search( hyp_tokens = modified_beam_search(
model=model, model=model,
@@ -523,7 +523,7 @@ def decode_one_batch(
context_graph=context_graph, context_graph=context_graph,
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
hyps.append(tokenize_by_CJK_char(hyp).split()) hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
hyp_tokens = modified_beam_search_lm_shallow_fusion( hyp_tokens = modified_beam_search_lm_shallow_fusion(
model=model, model=model,
@@ -533,7 +533,7 @@ def decode_one_batch(
LM=LM, LM=LM,
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
hyps.append(tokenize_by_CJK_char(hyp).split()) hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search_LODR": elif params.decoding_method == "modified_beam_search_LODR":
hyp_tokens = modified_beam_search_LODR( hyp_tokens = modified_beam_search_LODR(
model=model, model=model,
@@ -546,7 +546,7 @@ def decode_one_batch(
context_graph=context_graph, context_graph=context_graph,
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
hyps.append(tokenize_by_CJK_char(hyp).split()) hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search_lm_rescore": elif params.decoding_method == "modified_beam_search_lm_rescore":
lm_scale_list = [0.01 * i for i in range(10, 50)] lm_scale_list = [0.01 * i for i in range(10, 50)]
ans_dict = modified_beam_search_lm_rescore( ans_dict = modified_beam_search_lm_rescore(
@@ -616,7 +616,7 @@ def decode_one_batch(
ans = dict() ans = dict()
assert ans_dict is not None assert ans_dict is not None
for key, hyps in ans_dict.items(): for key, hyps in ans_dict.items():
hyps = [tokenize_by_CJK_char(sp.decode(hyp)).split() for hyp in hyps] hyps = [sp.decode(hyp).split() for hyp in hyps]
ans[f"{prefix}_{key}"] = hyps ans[f"{prefix}_{key}"] = hyps
return ans return ans
else: else:
@@ -678,7 +678,7 @@ def decode_dataset(
results = defaultdict(list) results = defaultdict(list)
for batch_idx, batch in enumerate(dl): for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"] texts = batch["supervisions"]["text"]
texts = [tokenize_by_CJK_char(text).split() for text in texts] texts = [list(str(text).replace(" ", "")) for text in texts]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch( hyps_dict = decode_one_batch(