From 460ae4cb97ce1fe684009c5e2587b587717cc016 Mon Sep 17 00:00:00 2001
From: luomingshuang <739314837@qq.com>
Date: Thu, 14 Apr 2022 20:35:01 +0800
Subject: [PATCH] add webdataset for data loading

---
 .../asr_datamodule.py                         | 13 +++-
 .../ASR/pruned_transducer_stateless/decode.py | 64 ++++++++++++++++---
 2 files changed, 66 insertions(+), 11 deletions(-)

diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless/asr_datamodule.py b/egs/wenetspeech/ASR/pruned_transducer_stateless/asr_datamodule.py
index 92c0d8c13..9941dc158 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless/asr_datamodule.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless/asr_datamodule.py
@@ -34,12 +34,14 @@ from lhotse.cut import Cut
 from lhotse.dataset import (
     CutConcatenate,
     CutMix,
+    BucketingSampler,
     DynamicBucketingSampler,
     K2SpeechRecognitionDataset,
     PrecomputedFeatures,
     SingleCutSampler,
     SpecAugment,
 )
+from lhotse.dataset.webdataset import export_to_webdataset
 from lhotse.dataset.input_strategies import OnTheFlyFeatures
 from torch.utils.data import DataLoader
 
@@ -361,10 +363,15 @@ class WenetSpeechAsrDataModule:
         sampler = DynamicBucketingSampler(
             cuts, max_duration=self.args.max_duration, shuffle=False
         )
-        test_dl = DataLoader(
-            test,
-            batch_size=None,
+
+        from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
+        test_iter_dataset = IterableDatasetWrapper(
+            dataset=test,
             sampler=sampler,
+        )
+        test_dl = DataLoader(
+            test_iter_dataset,
+            batch_size=None,
             num_workers=self.args.num_workers,
         )
         return test_dl
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless/decode.py
index 744da4db4..9d5016346 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless/decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless/decode.py
@@ -206,7 +206,7 @@ def get_parser():
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
-        default=3,
+        default=1,
         help="""Maximum number of symbols per frame.
         Used only when --decoding_method is greedy_search""",
     )
@@ -322,10 +322,12 @@ def decode_one_batch(
 
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
-
+    import time
+    st1 = time.time()
     encoder_out, encoder_out_lens = model.encoder(
         x=feature, x_lens=feature_lens
     )
+    ed1 = time.time()
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -344,12 +346,15 @@ def decode_one_batch(
     elif (
         params.decoding_method == "greedy_search"
         and params.max_sym_per_frame == 1
     ):
+        st2 = time.time()
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
         )
+        ed2 = time.time()
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+        ed3 = time.time()
     else:
         batch_size = encoder_out.size(0)
@@ -433,6 +438,8 @@ def decode_dataset(
     else:
         log_interval = 2
 
+    import time
+    ed = time.time()
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
@@ -443,7 +450,8 @@ def decode_dataset(
         texts = [pinyin(text) for text in texts]
         for i in range(len(texts)):
             texts[i] = [token[0] for token in texts[i]]
-
+        st = time.time()
+        print(f"loading data time: {st - ed}")
         hyps_dict = decode_one_batch(
             params=params,
             model=model,
@@ -451,6 +459,7 @@ def decode_dataset(
             batch=batch,
             decoding_graph=decoding_graph,
         )
+        ed = time.time()
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
@@ -460,13 +469,14 @@ def decode_dataset(
             results[name].extend(this_batch)
 
         num_cuts += len(texts)
-
+
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
             logging.info(
                 f"batch {batch_str}, cuts processed until now is {num_cuts}"
             )
+
     return results
 
 
@@ -584,13 +594,51 @@ def main():
 
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
+
+    # Note: Please use "pip install webdataset==0.1.103"
+    # to install webdataset.
+    import os
+    import glob
+    from lhotse import CutSet
+    from lhotse.dataset.webdataset import export_to_webdataset
 
     wenetspeech = WenetSpeechAsrDataModule(args)
 
-    test_net_cuts = wenetspeech.test_net_cuts()
-    test_meeting_cuts = wenetspeech.test_meeting_cuts()
-    test_net_dl = wenetspeech.valid_dataloaders(test_net_cuts)
-    test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)
+    test_net = "test_net"
+    test_meet = "test_meet"
+    if not os.path.exists(f"{test_net}/shared-0.tar"):
+        test_net_cuts = wenetspeech.test_net_cuts()
+        export_to_webdataset(
+            test_net_cuts,
+            output_path=f"{test_net}/shared-%d.tar",
+            shard_size=300,
+        )
+    if not os.path.exists(f"{test_meet}/shared-0.tar"):
+        test_meeting_cuts = wenetspeech.test_meeting_cuts()
+        export_to_webdataset(
+            test_meeting_cuts,
+            output_path=f"{test_meet}/shared-%d.tar",
+            shard_size=300,
+        )
+
+    test_net_shards = [str(path) for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))]
+    cuts_test_net_webdataset = CutSet.from_webdataset(
+        test_net_shards,
+        split_by_worker=True,
+        split_by_node=True,
+        shuffle_shards=True,
+    )
+
+    test_meet_shards = [str(path) for path in sorted(glob.glob(os.path.join(test_meet, "shared-*.tar")))]
+    cuts_test_meet_webdataset = CutSet.from_webdataset(
+        test_meet_shards,
+        split_by_worker=True,
+        split_by_node=True,
+        shuffle_shards=True,
+    )
+
+    test_net_dl = wenetspeech.test_dataloaders(cuts_test_net_webdataset)
+    test_meeting_dl = wenetspeech.test_dataloaders(cuts_test_meet_webdataset)
 
     test_sets = ["TEST_NET", "TEST_MEETING"]
     test_dl = [test_net_dl, test_meeting_dl]