From c6b2c3b038eefd6984095a480b9866a2b6f494e1 Mon Sep 17 00:00:00 2001 From: luomingshuang <739314837@qq.com> Date: Mon, 18 Apr 2022 12:33:59 +0800 Subject: [PATCH] webdataset dataload for dev --- .../asr_datamodule.py | 19 +++++++++----- .../pruned_transducer_stateless2/decode.py | 25 +++++++++++++++++-- 2 files changed, 36 insertions(+), 8 deletions(-) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index 780e9a1dd..744db8109 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -360,15 +360,22 @@ class WenetSpeechAsrDataModule: valid_sampler = DynamicBucketingSampler( cuts_valid, max_duration=self.args.max_duration, - buffer_size=30000, + rank=0, + world_size=1, shuffle=False, ) logging.info("About to create dev dataloader") - valid_dl = DataLoader( - validate, + + from lhotse.dataset.iterable_dataset import IterableDatasetWrapper + + dev_iter_dataset = IterableDatasetWrapper( + dataset=validate, sampler=valid_sampler, + ) + valid_dl = DataLoader( + dev_iter_dataset, batch_size=None, - num_workers=2, + num_workers=self.args.num_workers, persistent_workers=False, ) @@ -410,13 +417,13 @@ class WenetSpeechAsrDataModule: logging.info("use lazy cuts") cuts_train = CutSet.from_jsonl_lazy( self.args.manifest_dir - / "cuts_L_50_pieces.jsonl.gz" + / "cuts_L.jsonl.gz" # use cuts_L_50_pieces.jsonl.gz for original experiments ) else: cuts_train = CutSet.from_file( self.args.manifest_dir - / "cuts_L_50_pieces.jsonl.gz" + / "cuts_L.jsonl.gz" # use cuts_L_50_pieces.jsonl.gz for original experiments ) return cuts_train diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py index ab99d087d..49be13d43 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py @@ -531,9 +531,18 @@ def main(): wenetspeech = WenetSpeechAsrDataModule(args) + dev = "dev" test_net = "test_net" test_meet = "test_meet" + if not os.path.exists(f"{dev}/shared-0.tar"): + dev_cuts = wenetspeech.valid_cuts() + export_to_webdataset( + dev_cuts, + output_path=f"{dev}/shared-%d.tar", + shard_size=300, + ) + if not os.path.exists(f"{test_net}/shared-0.tar"): test_net_cuts = wenetspeech.test_net_cuts() export_to_webdataset( @@ -550,6 +559,17 @@ def main(): shard_size=300, ) + dev_shards = [ + str(path) + for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar"))) + ] + cuts_dev_webdataset = CutSet.from_webdataset( + dev_shards, + split_by_worker=True, + split_by_node=True, + shuffle_shards=True, + ) + test_net_shards = [ str(path) for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar"))) @@ -572,11 +592,12 @@ def main(): shuffle_shards=True, ) + dev_dl = wenetspeech.valid_dataloaders(cuts_dev_webdataset) test_net_dl = wenetspeech.test_dataloaders(cuts_test_net_webdataset) test_meeting_dl = wenetspeech.test_dataloaders(cuts_test_meet_webdataset) - test_sets = ["TEST_NET", "TEST_MEETING"] - test_dl = [test_net_dl, test_meeting_dl] + test_sets = ["DEV", "TEST_NET", "TEST_MEETING"] + test_dl = [dev_dl, test_net_dl, test_meeting_dl] for test_set, test_dl in zip(test_sets, test_dl): results_dict = decode_dataset(