Update decode.py

jinzr 2023-11-09 11:50:39 +08:00
parent 16499a5ef6
commit 73e1237c2d

decode.py
@@ -125,7 +125,7 @@ from lhotse.cut import Cut
 from multi_dataset import MultiDataset
 from train import add_model_arguments, get_model, get_params
 
-from icefall import ContextGraph, LmScorer, NgramLm
+from icefall import ContextGraph, LmScorer, NgramLm, tokenize_by_CJK_char
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
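
The helper imported here drives every change below. As a rough sketch of its behavior (the real implementation lives in icefall's utils module; the Unicode ranges below are an assumed simplification of its pattern), it wraps each CJK character in spaces so that a later .split() yields one token per CJK character while leaving non-CJK words whole:

    import re

    # Hedged sketch of icefall's tokenize_by_CJK_char, not the verbatim source:
    # surround every CJK character with spaces so that .split() later yields
    # per-character tokens for CJK text and per-word tokens for everything else.
    _CJK = re.compile(r"([\u2e80-\u9fff\uf900-\ufaff])")  # assumed subset of the real ranges

    def tokenize_by_CJK_char(line: str) -> str:
        pieces = _CJK.split(line.strip())
        return " ".join(p.strip() for p in pieces if p.strip())

    print(tokenize_by_CJK_char("你好 HELLO 世界"))  # -> 你 好 HELLO 世 界
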
@@ -462,7 +462,7 @@ def decode_one_batch(
             max_states=params.max_states,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+            hyps.append(tokenize_by_CJK_char(hyp).split())
     elif params.decoding_method == "fast_beam_search_nbest_LG":
         hyp_tokens = fast_beam_search_nbest_LG(
             model=model,
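
The same one-line substitution recurs in every decoding branch that follows. Its effect on a decoded hypothesis, assuming sp.decode returns unsegmented Chinese with embedded English (illustrative string only):

    hyp = "你好HELLO世界"                  # assumed sp.decode output
    hyp.split()                            # old: ['你好HELLO世界'], one giant token
    tokenize_by_CJK_char(hyp).split()      # new: ['你', '好', 'HELLO', '世', '界']

Scoring then operates on per-character tokens for Chinese, a CER-style metric, while English words are still scored as whole words.
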
@@ -490,7 +490,7 @@ def decode_one_batch(
             nbest_scale=params.nbest_scale,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+            hyps.append(tokenize_by_CJK_char(hyp).split())
     elif params.decoding_method == "fast_beam_search_nbest_oracle":
         hyp_tokens = fast_beam_search_nbest_oracle(
             model=model,
@@ -505,7 +505,7 @@ def decode_one_batch(
             nbest_scale=params.nbest_scale,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+            hyps.append(tokenize_by_CJK_char(hyp).split())
     elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
@@ -513,7 +513,7 @@ def decode_one_batch(
             encoder_out_lens=encoder_out_lens,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+            hyps.append(tokenize_by_CJK_char(hyp).split())
     elif params.decoding_method == "modified_beam_search":
         hyp_tokens = modified_beam_search(
             model=model,
@@ -523,7 +523,7 @@ def decode_one_batch(
             context_graph=context_graph,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+            hyps.append(tokenize_by_CJK_char(hyp).split())
     elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
         hyp_tokens = modified_beam_search_lm_shallow_fusion(
             model=model,
@@ -533,7 +533,7 @@ def decode_one_batch(
             LM=LM,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+            hyps.append(tokenize_by_CJK_char(hyp).split())
     elif params.decoding_method == "modified_beam_search_LODR":
         hyp_tokens = modified_beam_search_LODR(
             model=model,
@@ -546,7 +546,7 @@ def decode_one_batch(
             context_graph=context_graph,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+            hyps.append(tokenize_by_CJK_char(hyp).split())
     elif params.decoding_method == "modified_beam_search_lm_rescore":
         lm_scale_list = [0.01 * i for i in range(10, 50)]
         ans_dict = modified_beam_search_lm_rescore(
@@ -616,7 +616,7 @@ def decode_one_batch(
         ans = dict()
         assert ans_dict is not None
         for key, hyps in ans_dict.items():
-            hyps = [sp.decode(hyp).split() for hyp in hyps]
+            hyps = [tokenize_by_CJK_char(sp.decode(hyp)).split() for hyp in hyps]
             ans[f"{prefix}_{key}"] = hyps
         return ans
     else:
@@ -678,7 +678,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
-        texts = [list(str(text).replace(" ", "")) for text in texts]
+        texts = [tokenize_by_CJK_char(text).split() for text in texts]
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
 
         hyps_dict = decode_one_batch(
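
On the reference side, the old expression exploded everything into single characters, spelling English words letter by letter; the new one matches the hypothesis tokenization. A quick comparison on an illustrative mixed string:

    text = "你好 HELLO"
    list(str(text).replace(" ", ""))    # old: ['你', '好', 'H', 'E', 'L', 'L', 'O']
    tokenize_by_CJK_char(text).split()  # new: ['你', '好', 'HELLO']

With references and hypotheses tokenized the same way, the reported error rates compare like with like.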