fixed missing cut_id error in decode_dataset (#549)

* fixed import when quantization is None

Signed-off-by: shanguanma <nanr9544@gmail.com>

* fixed missing cut_id error in decode_dataset

Signed-off-by: shanguanma <nanr9544@gmail.com>

* fixed handling of more than one "#"

Signed-off-by: shanguanma <nanr9544@gmail.com>

* fixed code style

Signed-off-by: shanguanma <nanr9544@gmail.com>

Signed-off-by: shanguanma <nanr9544@gmail.com>
Co-authored-by: shanguanma <nanr9544@gmail.com>
Duo Ma authored on 2022-08-25 10:54:21 +08:00; committed by GitHub
parent 626a26fc2a
commit 0967cf5b38

@@ -81,18 +81,17 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         # hyps is a list; each element is the decoding result of one sentence.
         hyps = hubert_model.ctc_greedy_search(batch)
         texts = batch["supervisions"]["text"]
         assert len(hyps) == len(texts)
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         this_batch = []
-        for hyp_text, ref_text in zip(hyps, texts):
-            assert len(hyps) == len(texts)
+        for cut_id, hyp_text, ref_text in zip(cut_ids, hyps, texts):
             ref_words = ref_text.split()
             hyp_words = hyp_text.split()
-            this_batch.append((ref_words, hyp_words))
+            this_batch.append((cut_id, ref_words, hyp_words))
         results["ctc_greedy_search"].extend(this_batch)
         num_cuts += len(texts)
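
For context, below is a minimal, self-contained sketch of the fixed loop. The batch, cut ids, and hypotheses are made-up stand-ins (SimpleNamespace objects instead of real lhotse cuts, a hard-coded list instead of hubert_model.ctc_greedy_search output); it only illustrates how each hypothesis/reference pair is now keyed by its cut id.

    from collections import defaultdict
    from types import SimpleNamespace

    # Hypothetical stand-ins for a real lhotse batch and for the output of
    # hubert_model.ctc_greedy_search (neither is part of this commit).
    batch = {
        "supervisions": {
            "text": ["hello world", "foo bar"],
            "cut": [SimpleNamespace(id="cut-0001"), SimpleNamespace(id="cut-0002")],
        }
    }
    hyps = ["hello word", "foo bar"]  # pretend decoding output, one string per cut

    results = defaultdict(list)
    texts = batch["supervisions"]["text"]
    assert len(hyps) == len(texts)

    # The fix: collect each cut's id and store it alongside the reference and
    # hypothesis words, so downstream scoring can report errors per utterance.
    cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
    this_batch = []
    for cut_id, hyp_text, ref_text in zip(cut_ids, hyps, texts):
        this_batch.append((cut_id, ref_text.split(), hyp_text.split()))
    results["ctc_greedy_search"].extend(this_batch)

    print(results["ctc_greedy_search"])
    # [('cut-0001', ['hello', 'world'], ['hello', 'word']),
    #  ('cut-0002', ['foo', 'bar'], ['foo', 'bar'])]

Without the cut id in each tuple, per-utterance scoring helpers cannot tell which cut a hypothesis belongs to, which is the error this commit fixes.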