Minor fixes

This commit is contained in:
pkufool 2023-06-21 18:13:24 +08:00
parent a7d0588827
commit 63e53bad59
2 changed files with 4 additions and 7 deletions

View File

@ -588,7 +588,7 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
texts = [list(str(text)) for text in texts]
texts = [list("".join(text.split())) for text in texts]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(

View File

@ -520,7 +520,7 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
texts = [list(str(text)) for text in texts]
texts = [list("".join(text.split())) for text in texts]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
@ -793,12 +793,9 @@ def main():
test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)
test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
test_dl = [dev_dl, test_net_dl, test_meeting_dl]
test_dls = [dev_dl, test_net_dl, test_meeting_dl]
test_sets = ["TEST_NET"]
test_dl = [test_net_dl]
for test_set, test_dl in zip(test_sets, test_dl):
for test_set, test_dl in zip(test_sets, test_dls):
results_dict = decode_dataset(
dl=test_dl,
params=params,