From c8cb425e51cfaa6d75c285b83e730dca29fc5f77 Mon Sep 17 00:00:00 2001
From: luomingshuang <739314837@qq.com>
Date: Wed, 8 Jun 2022 15:35:53 +0800
Subject: [PATCH] combine the training data and decode without webdataset

---
 .../asr_datamodule.py                         | 52 +++++++++----------
 .../pruned_transducer_stateless5/decode.py    | 45 +++++-----------
 .../ASR/pruned_transducer_stateless5/train.py | 14 ++---
 3 files changed, 47 insertions(+), 64 deletions(-)

diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py
index 4e3cb87d1..9dcd6fa4b 100644
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py
@@ -209,13 +209,6 @@ class Aishell4AsrDataModule:
             help="AudioSamples or PrecomputedFeatures",
         )
 
-        group.add_argument(
-            "--training-subset",
-            type=str,
-            default="L",
-            help="The training subset for using",
-        )
-
     def train_dataloaders(
         self,
         cuts_train: CutSet,
@@ -379,18 +372,14 @@ class Aishell4AsrDataModule:
         valid_sampler = DynamicBucketingSampler(
             cuts_valid,
             max_duration=self.args.max_duration,
+            rank=0,
+            world_size=1,
             shuffle=False,
         )
         logging.info("About to create dev dataloader")
-
-        from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
-
-        dev_iter_dataset = IterableDatasetWrapper(
-            dataset=validate,
-            sampler=valid_sampler,
-        )
         valid_dl = DataLoader(
-            dev_iter_dataset,
+            validate,
+            sampler=valid_sampler,
             batch_size=None,
             num_workers=self.args.num_workers,
             persistent_workers=False,
@@ -409,27 +398,38 @@ class Aishell4AsrDataModule:
         sampler = DynamicBucketingSampler(
             cuts,
             max_duration=self.args.max_duration,
+            rank=0,
+            world_size=1,
             shuffle=False,
         )
-        from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
-
-        test_iter_dataset = IterableDatasetWrapper(
-            dataset=test,
-            sampler=sampler,
-        )
+        logging.info("About to create test dataloader")
         test_dl = DataLoader(
-            test_iter_dataset,
+            test,
             batch_size=None,
+            sampler=sampler,
             num_workers=self.args.num_workers,
         )
         return test_dl
 
     @lru_cache()
-    def train_cuts(self) -> CutSet:
-        logging.info("About to get train cuts")
+    def train_S_cuts(self) -> CutSet:
+        logging.info("About to get S train cuts")
         return load_manifest_lazy(
-            self.args.manifest_dir
-            / f"aishell4_cuts_train_{self.args.training_subset}.jsonl.gz"
+            self.args.manifest_dir / "aishell4_cuts_train_S.jsonl.gz"
+        )
+
+    @lru_cache()
+    def train_M_cuts(self) -> CutSet:
+        logging.info("About to get M train cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "aishell4_cuts_train_M.jsonl.gz"
+        )
+
+    @lru_cache()
+    def train_L_cuts(self) -> CutSet:
+        logging.info("About to get L train cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "aishell4_cuts_train_L.jsonl.gz"
         )
 
     @lru_cache()
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py
index d52132136..619534519 100755
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py
@@ -74,6 +74,8 @@ from beam_search import (
     greedy_search_batch,
     modified_beam_search,
 )
+from lhotse.cut import Cut
+from local.text_normalize import text_normalize
 from train import add_model_arguments, get_params, get_transducer_model
 
 from icefall.checkpoint import (
@@ -380,6 +382,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        texts = [list(str(text).replace(" ", "")) for text in texts]
 
         hyps_dict = decode_one_batch(
             params=params,
@@ -393,8 +396,7 @@ def decode_dataset(
         this_batch = []
         assert len(hyps) == len(texts)
         for hyp_words, ref_text in zip(hyps, texts):
-            ref_words = ref_text.split()
-            this_batch.append((ref_words, hyp_words))
+            this_batch.append((ref_text, hyp_words))
 
         results[name].extend(this_batch)
 
@@ -597,38 +599,17 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
-    # Note: Please use "pip install webdataset==0.1.103"
-    # for installing the webdataset.
-    import glob
-    import os
-
-    from lhotse import CutSet
-    from lhotse.dataset.webdataset import export_to_webdataset
+    def text_normalize_for_cut(c: Cut):
+        # Text normalize for each sample
+        text = c.supervisions[0].text
+        text = text.strip("\n").strip("\t")
+        c.supervisions[0].text = text_normalize(text)
+        return c
 
     aishell4 = Aishell4AsrDataModule(args)
-
-    test = "test"
-    if not os.path.exists(f"{test}/shared-0.tar"):
-        os.makedirs(test)
-        test_cuts = aishell4.test_cuts()
-        export_to_webdataset(
-            test_cuts,
-            output_path=f"{test}/shared-%d.tar",
-            shard_size=300,
-        )
-
-    test_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(test, "shared-*.tar")))
-    ]
-    cuts_test_webdataset = CutSet.from_webdataset(
-        test_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
-
-    test_dl = aishell4.test_dataloaders(cuts_test_webdataset)
+    test_cuts = aishell4.test_cuts()
+    test_cuts = test_cuts.map(text_normalize_for_cut)
+    test_dl = aishell4.test_dataloaders(test_cuts)
 
     test_sets = ["test"]
     test_dl = [test_dl]
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py
index e306ce68b..c2cf5aa66 100755
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py
@@ -389,14 +389,14 @@ def get_params() -> AttributeDict:
             "best_train_epoch": -1,
             "best_valid_epoch": -1,
             "batch_idx_train": 0,
-            "log_interval": 1,
-            "reset_interval": 200,
-            "valid_interval": 3000,  # For the 100h subset, use 800
+            "log_interval": 50,
+            "reset_interval": 100,
+            "valid_interval": 200,
             # parameters for conformer
             "feature_dim": 80,
             "subsampling_factor": 4,
             # parameters for Noam
-            "model_warm_step": 3000,  # arg given to model, not for lrate
+            "model_warm_step": 50,  # arg given to model, not for lrate
             "env_info": get_env_info(),
         }
     )
@@ -942,8 +942,10 @@ def run(rank, world_size, args):
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
     aishell4 = Aishell4AsrDataModule(args)
-
-    train_cuts = aishell4.train_cuts()
+    # Combine all of the training data
+    train_cuts = aishell4.train_S_cuts()
+    train_cuts += aishell4.train_M_cuts()
+    train_cuts += aishell4.train_L_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
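
For context on the combined training set, here is a minimal standalone
sketch (not part of the patch): `train_cuts += ...` works because CutSets
opened with lhotse's `load_manifest_lazy` support lazy concatenation with
`+`. It mirrors what run() in train.py now does; the `data/fbank` manifest
directory is an assumption (the recipe's usual layout), adjust to your setup.

    from pathlib import Path

    from lhotse import load_manifest_lazy

    # Assumption: manifests live in the recipe's default fbank directory.
    manifest_dir = Path("data/fbank")

    # Chain the S, M and L subsets lazily, as run() in train.py now does.
    # `+=` on lazy CutSets builds an iterator chain; no manifest is read
    # fully into memory until the set is iterated (e.g. by the sampler).
    train_cuts = load_manifest_lazy(manifest_dir / "aishell4_cuts_train_S.jsonl.gz")
    train_cuts += load_manifest_lazy(manifest_dir / "aishell4_cuts_train_M.jsonl.gz")
    train_cuts += load_manifest_lazy(manifest_dir / "aishell4_cuts_train_L.jsonl.gz")

    # Peek at a few cuts to confirm the chain yields data from all manifests.
    for i, cut in enumerate(train_cuts):
        print(cut.id, round(cut.duration, 2))
        if i == 2:
            break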