minor updates

This commit is contained in:
JinZr 2023-09-28 09:24:31 +08:00
parent 8467eb7c26
commit 2f1c36013d
2 changed files with 5 additions and 5 deletions

View File

@ -470,7 +470,6 @@ def decode_one_batch(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append(smart_byte_decode(sp.decode(hyp)).split())
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif "fast_beam_search" in params.decoding_method:
@ -535,7 +534,9 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
texts = [list(str(text).replace(" ", "")) for text in texts]
texts = [tokenize_by_CJK_char(str(text)).split() for text in texts]
# print(texts)
# exit()
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
@ -551,8 +552,7 @@ def decode_dataset(
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
hyp_text = "".join(hyp_words)
this_batch.append((cut_id, ref_text, hyp_text))
this_batch.append((cut_id, ref_text, hyp_words))
results[name].extend(this_batch)

View File

@ -1186,7 +1186,7 @@ def run(rank, world_size, args):
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
if c.duration < 1.0 or c.duration > 12.0:
if c.duration < 1.0:
logging.warning(
f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
)