fixed missing cut_id error in decode_dataset (#549)

* fixed import when quantization is None

Signed-off-by: shanguanma <nanr9544@gmail.com>

* fixed missing cut_id error in decode_dataset

Signed-off-by: shanguanma <nanr9544@gmail.com>

* fixed lines containing more than one "#"

Signed-off-by: shanguanma <nanr9544@gmail.com>

* fixed code style

Signed-off-by: shanguanma <nanr9544@gmail.com>

Signed-off-by: shanguanma <nanr9544@gmail.com>
Co-authored-by: shanguanma <nanr9544@gmail.com>
Duo Ma committed 2022-08-25 10:54:21 +08:00 (via GitHub)
parent 626a26fc2a
commit 0967cf5b38
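
The first bullet ("fixed import when quantization is None") refers to guarding an optional dependency. Below is a minimal sketch of that pattern, assuming a module named quantization as in the commit message; the actual import path and entry point in icefall may differ.

# Hedged sketch: import the optional `quantization` package lazily and fail
# with a clear message only when the feature is actually used. The module
# name comes from the commit message; `Quantizer` here is hypothetical.
try:
    import quantization
except ModuleNotFoundError:
    quantization = None


def get_quantizer():
    if quantization is None:
        raise ImportError(
            "The quantization package is not installed; "
            "install it before extracting codebook indexes."
        )
    return quantization.Quantizer()  # hypothetical entry point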


@@ -81,18 +81,17 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
+        # hyps is a list, every element is decode result of a sentence.
         hyps = hubert_model.ctc_greedy_search(batch)
         texts = batch["supervisions"]["text"]
-        assert len(hyps) == len(texts)
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         this_batch = []
-        for hyp_text, ref_text in zip(hyps, texts):
+        assert len(hyps) == len(texts)
+        for cut_id, hyp_text, ref_text in zip(cut_ids, hyps, texts):
             ref_words = ref_text.split()
             hyp_words = hyp_text.split()
-            this_batch.append((ref_words, hyp_words))
+            this_batch.append((cut_id, ref_words, hyp_words))
         results["ctc_greedy_search"].extend(this_batch)
         num_cuts += len(texts)
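
The (cut_id, ref_words, hyp_words) triples collected above are what icefall's downstream scoring helpers (store_transcripts, write_error_stats) key results by, which is presumably why dropping cut_id raised the error this commit fixes. Note that batch["supervisions"]["cut"] is typically only populated when the lhotse dataset is built with return_cuts=True. Below is a minimal, hypothetical consumer sketch, not icefall's actual store_transcripts.

# Hypothetical sketch of a consumer for the (cut_id, ref_words, hyp_words)
# triples in `results`; icefall's own store_transcripts/write_error_stats
# are the real consumers and their signatures may differ.
from pathlib import Path
from typing import Iterable, List, Tuple

Triple = Tuple[str, List[str], List[str]]


def save_results_sketch(path: Path, triples: Iterable[Triple]) -> None:
    """Write one ref/hyp line pair per cut, keyed by cut id."""
    with path.open("w") as f:
        for cut_id, ref_words, hyp_words in triples:
            print(f"{cut_id}\tref={' '.join(ref_words)}", file=f)
            print(f"{cut_id}\thyp={' '.join(hyp_words)}", file=f)

# Usage, e.g.:
#   save_results_sketch(Path("recogs-test.txt"), results["ctc_greedy_search"])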