mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
fixed no cut_id error in decode_dataset (#549)
* fixed import quantization is none Signed-off-by: shanguanma <nanr9544@gmail.com> * fixed no cut_id error in decode_dataset Signed-off-by: shanguanma <nanr9544@gmail.com> * fixed more than one "#" Signed-off-by: shanguanma <nanr9544@gmail.com> * fixed code style Signed-off-by: shanguanma <nanr9544@gmail.com> Signed-off-by: shanguanma <nanr9544@gmail.com> Co-authored-by: shanguanma <nanr9544@gmail.com>
This commit is contained in:
parent
626a26fc2a
commit
0967cf5b38
@ -81,18 +81,17 @@ def decode_dataset(
|
|||||||
|
|
||||||
results = defaultdict(list)
|
results = defaultdict(list)
|
||||||
for batch_idx, batch in enumerate(dl):
|
for batch_idx, batch in enumerate(dl):
|
||||||
|
# hyps is a list, every element is decode result of a sentence.
|
||||||
hyps = hubert_model.ctc_greedy_search(batch)
|
hyps = hubert_model.ctc_greedy_search(batch)
|
||||||
|
|
||||||
texts = batch["supervisions"]["text"]
|
texts = batch["supervisions"]["text"]
|
||||||
assert len(hyps) == len(texts)
|
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||||
this_batch = []
|
this_batch = []
|
||||||
|
assert len(hyps) == len(texts)
|
||||||
for hyp_text, ref_text in zip(hyps, texts):
|
for cut_id, hyp_text, ref_text in zip(cut_ids, hyps, texts):
|
||||||
ref_words = ref_text.split()
|
ref_words = ref_text.split()
|
||||||
hyp_words = hyp_text.split()
|
hyp_words = hyp_text.split()
|
||||||
this_batch.append((ref_words, hyp_words))
|
this_batch.append((cut_id, ref_words, hyp_words))
|
||||||
|
|
||||||
results["ctc_greedy_search"].extend(this_batch)
|
results["ctc_greedy_search"].extend(this_batch)
|
||||||
|
|
||||||
num_cuts += len(texts)
|
num_cuts += len(texts)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user