diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index f33a687e8..780e9a1dd 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -383,12 +383,22 @@ class WenetSpeechAsrDataModule: return_cuts=self.args.return_cuts, ) sampler = DynamicBucketingSampler( - cuts, max_duration=self.args.max_duration, shuffle=False + cuts, + max_duration=self.args.max_duration, + rank=0, + world_size=1, + shuffle=False, + ) + + from lhotse.dataset.iterable_dataset import IterableDatasetWrapper + + test_iter_dataset = IterableDatasetWrapper( + dataset=test, + sampler=sampler, ) test_dl = DataLoader( - test, + test_iter_dataset, batch_size=None, - sampler=sampler, num_workers=self.args.num_workers, ) return test_dl diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py index e68417f0a..ab99d087d 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py @@ -129,10 +129,22 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--lang-dir", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_char", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--token-type", + type=str, + default="char", + help="""The type of token + It must be in ["char", "pinyin", "lazy_pinyin"] + """, ) parser.add_argument( @@ -268,8 +280,10 @@ def decode_one_batch( model=model, encoder_out=encoder_out, ) + # print(hyp_tokens) + # print(lexicon.token_table) for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens]) + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) elif params.decoding_method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -277,7 +291,7 @@ def decode_one_batch( beam=params.beam_size, ) for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens]) + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) else: batch_size = encoder_out.size(0) @@ -358,6 +372,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] + texts = [list(str(text)) for text in texts] hyps_dict = decode_one_batch( params=params, @@ -371,8 +386,7 @@ def decode_dataset( this_batch = [] assert len(hyps) == len(texts) for hyp_words, ref_text in zip(hyps, texts): - ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) + this_batch.append((ref_text, hyp_words)) results[name].extend(this_batch) @@ -507,12 +521,59 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - wenetspeech = WenetSpeechAsrDataModule(args) - test_net_cuts = wenetspeech.test_net_cuts() - test_meeting_cuts = wenetspeech.test_meeting_cuts() + # Note: Please use "pip install webdataset==0.1.103" + # for installing the webdataset. + import glob + import os - test_net_dl = wenetspeech.valid_dataloaders(test_net_cuts) - test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts) + from lhotse import CutSet + from lhotse.dataset.webdataset import export_to_webdataset + + wenetspeech = WenetSpeechAsrDataModule(args) + + test_net = "test_net" + test_meet = "test_meet" + + if not os.path.exists(f"{test_net}/shared-0.tar"): + test_net_cuts = wenetspeech.test_net_cuts() + export_to_webdataset( + test_net_cuts, + output_path=f"{test_net}/shared-%d.tar", + shard_size=300, + ) + + if not os.path.exists(f"{test_meet}/shared-0.tar"): + test_meeting_cuts = wenetspeech.test_meeting_cuts() + export_to_webdataset( + test_meeting_cuts, + output_path=f"{test_meet}/shared-%d.tar", + shard_size=300, + ) + + test_net_shards = [ + str(path) + for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar"))) + ] + cuts_test_net_webdataset = CutSet.from_webdataset( + test_net_shards, + split_by_worker=True, + split_by_node=True, + shuffle_shards=True, + ) + + test_meet_shards = [ + str(path) + for path in sorted(glob.glob(os.path.join(test_meet, "shared-*.tar"))) + ] + cuts_test_meet_webdataset = CutSet.from_webdataset( + test_meet_shards, + split_by_worker=True, + split_by_node=True, + shuffle_shards=True, + ) + + test_net_dl = wenetspeech.test_dataloaders(cuts_test_net_webdataset) + test_meeting_dl = wenetspeech.test_dataloaders(cuts_test_meet_webdataset) test_sets = ["TEST_NET", "TEST_MEETING"] test_dl = [test_net_dl, test_meeting_dl] diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py index 449850344..7475632a6 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py @@ -45,6 +45,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" import argparse import logging +import os import warnings from pathlib import Path from shutil import copyfile @@ -84,6 +85,8 @@ LRSchedulerType = Union[ torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler ] +os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + def get_parser(): parser = argparse.ArgumentParser( @@ -332,7 +335,7 @@ def get_params() -> AttributeDict: "batch_idx_train": 0, "log_interval": 50, "reset_interval": 200, - "valid_interval": 2000, + "valid_interval": 3000, # parameters for conformer "feature_dim": 80, "subsampling_factor": 4, @@ -867,6 +870,7 @@ def run(rank, world_size, args): wenetspeech = WenetSpeechAsrDataModule(args) train_cuts = wenetspeech.train_cuts() + valid_cuts = wenetspeech.valid_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds @@ -890,8 +894,8 @@ def run(rank, world_size, args): train_cuts = train_cuts.filter(remove_short_and_long_utt) if params.token_type == "pinyin": train_cuts = train_cuts.map(text_to_words) + # valid_cuts = valid_cuts.map(text_to_words) - valid_cuts = wenetspeech.valid_cuts() valid_dl = wenetspeech.valid_dataloaders(valid_cuts) if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: