mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 16:14:17 +00:00
minor updates
This commit is contained in:
parent
8467eb7c26
commit
2f1c36013d
@@ -470,7 +470,6 @@ def decode_one_batch(
|
|||||||
f"Unsupported decoding method: {params.decoding_method}"
|
f"Unsupported decoding method: {params.decoding_method}"
|
||||||
)
|
)
|
||||||
hyps.append(smart_byte_decode(sp.decode(hyp)).split())
|
hyps.append(smart_byte_decode(sp.decode(hyp)).split())
|
||||||
|
|
||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
return {"greedy_search": hyps}
|
return {"greedy_search": hyps}
|
||||||
elif "fast_beam_search" in params.decoding_method:
|
elif "fast_beam_search" in params.decoding_method:
|
||||||
@@ -535,7 +534,9 @@ def decode_dataset(
|
|||||||
results = defaultdict(list)
|
results = defaultdict(list)
|
||||||
for batch_idx, batch in enumerate(dl):
|
for batch_idx, batch in enumerate(dl):
|
||||||
texts = batch["supervisions"]["text"]
|
texts = batch["supervisions"]["text"]
|
||||||
texts = [list(str(text).replace(" ", "")) for text in texts]
|
texts = [tokenize_by_CJK_char(str(text)).split() for text in texts]
|
||||||
|
# print(texts)
|
||||||
|
# exit()
|
||||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||||
|
|
||||||
hyps_dict = decode_one_batch(
|
hyps_dict = decode_one_batch(
|
||||||
@@ -551,8 +552,7 @@ def decode_dataset(
|
|||||||
this_batch = []
|
this_batch = []
|
||||||
assert len(hyps) == len(texts)
|
assert len(hyps) == len(texts)
|
||||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||||
hyp_text = "".join(hyp_words)
|
this_batch.append((cut_id, ref_text, hyp_words))
|
||||||
this_batch.append((cut_id, ref_text, hyp_text))
|
|
||||||
|
|
||||||
results[name].extend(this_batch)
|
results[name].extend(this_batch)
|
||||||
|
|
||||||
|
@@ -1186,7 +1186,7 @@ def run(rank, world_size, args):
|
|||||||
# You should use ../local/display_manifest_statistics.py to get
|
# You should use ../local/display_manifest_statistics.py to get
|
||||||
# an utterance duration distribution for your dataset to select
|
# an utterance duration distribution for your dataset to select
|
||||||
# the threshold
|
# the threshold
|
||||||
if c.duration < 1.0 or c.duration > 12.0:
|
if c.duration < 1.0:
|
||||||
logging.warning(
|
logging.warning(
|
||||||
f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
|
f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user