From 84090511f8939d04106e422a6a69e6fac3216bf2 Mon Sep 17 00:00:00 2001
From: luomingshuang <739314837@qq.com>
Date: Thu, 14 Apr 2022 21:45:00 +0800
Subject: [PATCH] add webdataset for dataload

---
 .../ASR/pruned_transducer_stateless/decode.py | 40 +++++++++----------
 1 file changed, 19 insertions(+), 21 deletions(-)

diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless/decode.py
index 9d5016346..da0ad9f8f 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless/decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless/decode.py
@@ -322,12 +322,10 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
-    import time
-    st1 = time.time()
+
     encoder_out, encoder_out_lens = model.encoder(
         x=feature, x_lens=feature_lens
     )
-    ed1 = time.time()
     hyps = []

     if params.decoding_method == "fast_beam_search":
@@ -346,15 +344,12 @@ def decode_one_batch(
         params.decoding_method == "greedy_search"
         and params.max_sym_per_frame == 1
     ):
-        st2 = time.time()
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
         )
-        ed2 = time.time()
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-        ed3 = time.time()
     else:
         batch_size = encoder_out.size(0)
@@ -438,8 +433,6 @@ def decode_dataset(
     else:
         log_interval = 2

-    import time
-    ed = time.time()
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
@@ -450,8 +443,6 @@ def decode_dataset(
         texts = [pinyin(text) for text in texts]
         for i in range(len(texts)):
             texts[i] = [token[0] for token in texts[i]]
-        st = time.time()
-        print(f"loading data time: {st - ed}")
         hyps_dict = decode_one_batch(
             params=params,
             model=model,
@@ -459,7 +450,6 @@ def decode_dataset(
             batch=batch,
             decoding_graph=decoding_graph,
         )
-        ed = time.time()
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
@@ -469,14 +459,14 @@ def decode_dataset(
             results[name].extend(this_batch)

         num_cuts += len(texts)
-
+
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
             logging.info(
                 f"batch {batch_str}, cuts processed until now is {num_cuts}"
             )
-
+
     return results
@@ -594,11 +584,12 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
-
-    # Note: Please use "pip install webdataset==0.1.103"
+
+    # Note: Please use "pip install webdataset==0.1.103"
     # for installing the webdataset.
-    import os
     import glob
+    import os
+
     from lhotse import CutSet
     from lhotse.dataset.webdataset import export_to_webdataset
@@ -606,14 +597,15 @@ def main():
     test_net = "test_net"
     test_meet = "test_meet"

-    if os.path.exists(f"{test_net}/shared-0.tar"):
+    if not os.path.exists(f"{test_net}/shared-0.tar"):
         test_net_cuts = wenetspeech.test_net_cuts()
         export_to_webdataset(
             test_net_cuts,
             output_path=f"{test_net}/shared-%d.tar",
             shard_size=300,
         )
-    if os.path.exists(f"{test_meet}/shared-0.tar"):
+
+    if not os.path.exists(f"{test_meet}/shared-0.tar"):
         test_meeting_cuts = wenetspeech.test_meeting_cuts()
         export_to_webdataset(
             test_meeting_cuts,
@@ -621,15 +613,21 @@
             shard_size=300,
         )

-    test_net_shards = [str(path) for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))]
+    test_net_shards = [
+        str(path)
+        for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
+    ]
     cuts_test_net_webdataset = CutSet.from_webdataset(
         test_net_shards,
         split_by_worker=True,
         split_by_node=True,
         shuffle_shards=True,
-    )
+    )

-    test_meet_shards = [str(path) for path in sorted(glob.glob(os.path.join(test_meet, "shared-*.tar")))]
+    test_meet_shards = [
+        str(path)
+        for path in sorted(glob.glob(os.path.join(test_meet, "shared-*.tar")))
+    ]
     cuts_test_meet_webdataset = CutSet.from_webdataset(
         test_meet_shards,
         split_by_worker=True,
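
For reference, the change above follows lhotse's WebDataset workflow for decoding: each test CutSet is exported once into tar shards, and a CutSet is then rebuilt from those shards so that decoding batches are read sequentially from the tar files rather than from many small files. Below is a minimal sketch of that round trip (not part of the patch); the directory name and the manifest path are placeholders, while shard_size=300 and the from_webdataset flags mirror the values used in the diff.

    import glob
    import os

    from lhotse import CutSet
    from lhotse.dataset.webdataset import export_to_webdataset

    out_dir = "test_net"  # placeholder shard directory
    if not os.path.exists(f"{out_dir}/shared-0.tar"):
        # In the recipe the cuts come from the WenetSpeech data module;
        # here a placeholder manifest path stands in for that.
        cuts = CutSet.from_file("path/to/cuts.jsonl.gz")
        # Write the cuts into WebDataset tar shards, 300 cuts per shard.
        export_to_webdataset(
            cuts,
            output_path=f"{out_dir}/shared-%d.tar",
            shard_size=300,
        )

    # Rebuild a CutSet that streams cuts from the shards at decode time.
    shards = [
        str(path)
        for path in sorted(glob.glob(os.path.join(out_dir, "shared-*.tar")))
    ]
    cuts_webdataset = CutSet.from_webdataset(
        shards,
        split_by_worker=True,  # each DataLoader worker reads distinct shards
        split_by_node=True,    # each distributed node reads distinct shards
        shuffle_shards=True,
    )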