mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-06 15:44:17 +00:00
minor updates
This commit is contained in:
parent
8467eb7c26
commit
2f1c36013d
@ -470,7 +470,6 @@ def decode_one_batch(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
hyps.append(smart_byte_decode(sp.decode(hyp)).split())
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
return {"greedy_search": hyps}
|
||||
elif "fast_beam_search" in params.decoding_method:
|
||||
@ -535,7 +534,9 @@ def decode_dataset(
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
texts = [list(str(text).replace(" ", "")) for text in texts]
|
||||
texts = [tokenize_by_CJK_char(str(text)).split() for text in texts]
|
||||
# print(texts)
|
||||
# exit()
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
@ -551,8 +552,7 @@ def decode_dataset(
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
hyp_text = "".join(hyp_words)
|
||||
this_batch.append((cut_id, ref_text, hyp_text))
|
||||
this_batch.append((cut_id, ref_text, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
|
@ -1186,7 +1186,7 @@ def run(rank, world_size, args):
|
||||
# You should use ../local/display_manifest_statistics.py to get
|
||||
# an utterance duration distribution for your dataset to select
|
||||
# the threshold
|
||||
if c.duration < 1.0 or c.duration > 12.0:
|
||||
if c.duration < 1.0:
|
||||
logging.warning(
|
||||
f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user