From 63e53bad5963397bbe6fd7ffdaa7d64a792cf9f4 Mon Sep 17 00:00:00 2001 From: pkufool Date: Wed, 21 Jun 2023 18:13:24 +0800 Subject: [PATCH] Minor fixes --- .../ASR/pruned_transducer_stateless5/decode.py | 2 +- egs/wenetspeech/ASR/zipformer/decode.py | 9 +++------ 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py index dc431578c..36b8a4b67 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py @@ -588,7 +588,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] - texts = [list(str(text)) for text in texts] + texts = [list("".join(text.split())) for text in texts] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( diff --git a/egs/wenetspeech/ASR/zipformer/decode.py b/egs/wenetspeech/ASR/zipformer/decode.py index 9a752b7ca..0fbc8244b 100755 --- a/egs/wenetspeech/ASR/zipformer/decode.py +++ b/egs/wenetspeech/ASR/zipformer/decode.py @@ -520,7 +520,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] - texts = [list(str(text)) for text in texts] + texts = [list("".join(text.split())) for text in texts] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( @@ -793,12 +793,9 @@ def main(): test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts) test_sets = ["DEV", "TEST_NET", "TEST_MEETING"] - test_dl = [dev_dl, test_net_dl, test_meeting_dl] + test_dls = [dev_dl, test_net_dl, test_meeting_dl] - test_sets = ["TEST_NET"] - test_dl = [test_net_dl] - - for test_set, test_dl in zip(test_sets, test_dl): + for test_set, test_dl in zip(test_sets, test_dls): results_dict = decode_dataset( dl=test_dl, params=params,