From 22e9d837f8dcc6d27c9ec1929a81cd70609081c2 Mon Sep 17 00:00:00 2001 From: pkufool Date: Fri, 22 Sep 2023 12:38:23 +0800 Subject: [PATCH] Fix train.py --- .../ASR/zipformer/asr_datamodule.py | 28 +++--------- egs/libriheavy/ASR/zipformer/train.py | 45 ++++++++++++++----- 2 files changed, 39 insertions(+), 34 deletions(-) diff --git a/egs/libriheavy/ASR/zipformer/asr_datamodule.py b/egs/libriheavy/ASR/zipformer/asr_datamodule.py index 9db018468..1a6d833a6 100644 --- a/egs/libriheavy/ASR/zipformer/asr_datamodule.py +++ b/egs/libriheavy/ASR/zipformer/asr_datamodule.py @@ -95,16 +95,6 @@ class LibriHeavyAsrDataModule: """, ) - group.add_argument( - "--with-punctuation", - type=str2bool, - default=False, - help="""True to train the model on transcription with punctuation, - False to train the model on normalized transcription - (upper without punctuation). - """, - ) - group.add_argument( "--manifest-dir", type=Path, @@ -413,35 +403,29 @@ class LibriHeavyAsrDataModule: @lru_cache() def train_small_cuts(self) -> CutSet: logging.info("About to get small subset cuts") - basename = f"libriheavy_{'punc_' if self.args.with_punctuation else ''}cuts_small.jsonl.gz" - return load_manifest_lazy(self.args.manifest_dir / basename) + return load_manifest_lazy(self.args.manifest_dir / "libriheavy_cuts_small.jsonl.gz") @lru_cache() def train_medium_cuts(self) -> CutSet: logging.info("About to get medium subset cuts") - basename = f"libriheavy_{'punc_' if self.args.with_punctuation else ''}cuts_medium.jsonl.gz" - return load_manifest_lazy(self.args.manifest_dir / basename) + return load_manifest_lazy(self.args.manifest_dir / "libriheavy_cuts_medium.jsonl.gz") @lru_cache() def train_large_cuts(self) -> CutSet: logging.info("About to get large subset cuts") - basename = f"libriheavy_{'punc_' if self.args.with_punctuation else ''}cuts_large.jsonl.gz" - return load_manifest_lazy(self.args.manifest_dir / basename) + return load_manifest_lazy(self.args.manifest_dir / 
"libriheavy_cuts_large.jsonl.gz") @lru_cache() def dev_cuts(self) -> CutSet: logging.info("About to get dev cuts") - basename = f"libriheavy_{'punc_' if self.args.with_punctuation else ''}cuts_dev.jsonl.gz" - return load_manifest_lazy(self.args.manifest_dir / basename) + return load_manifest_lazy(self.args.manifest_dir / "libriheavy_cuts_dev.jsonl.gz") @lru_cache() def test_clean_cuts(self) -> CutSet: logging.info("About to get the test-clean cuts") - basename = f"libriheavy_{'punc_' if self.args.with_punctuation else ''}cuts_test-clean.jsonl.gz" - return load_manifest_lazy(self.args.manifest_dir / basename) + return load_manifest_lazy(self.args.manifest_dir / "libriheavy_cuts_test_clean.jsonl.gz") @lru_cache() def test_other_cuts(self) -> CutSet: logging.info("About to get the test-other cuts") - basename = f"libriheavy_{'punc_' if self.args.with_punctuation else ''}cuts_test-other.jsonl.gz" - return load_manifest_lazy(self.args.manifest_dir / basename) + return load_manifest_lazy(self.args.manifest_dir / "libriheavy_cuts_test_other.jsonl.gz") diff --git a/egs/libriheavy/ASR/zipformer/train.py b/egs/libriheavy/ASR/zipformer/train.py index 1d8587ee3..4c6ee976f 100644 --- a/egs/libriheavy/ASR/zipformer/train.py +++ b/egs/libriheavy/ASR/zipformer/train.py @@ -261,6 +261,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): ) + def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -470,6 +471,13 @@ def get_parser(): help="Whether to use half precision training.", ) + parser.add_argument( + "--train-with-punctuation", + type=str2bool, + default=False, + help="If True, the training text will include casing and punctuation.", + ) + add_model_arguments(parser) return parser @@ -1178,9 +1186,20 @@ def run(rank, world_size, args): if params.inf_check: register_inf_check_hooks(model) - libriheavy = LibriHeavyAsrDataModule(args) - train_cuts = libriheavy.train_cuts() + def normalize_text(c: Cut): + text = 
c.supervisions[0].text + if params.train_with_punctuation: + table = str.maketrans("’‘,。;?!():-《》、“”【】", "'',.;?!(): <>/\"\"[]") + text = text.translate(table) + else: + text = text.replace("‘", "'") + text = text.replace("’", "'") + tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'") + text_list = [x.upper() if x in tokens else " " for x in text] + text = " ".join("".join(text_list).split()).strip() + c.supervisions[0].text = text + return c def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds @@ -1204,7 +1223,7 @@ def run(rank, world_size, args): # In ./zipformer.py, the conv module uses the following expression # for subsampling T = ((c.num_frames - 7) // 2 + 1) // 2 - tokens = sp.encode(c.supervisions[0].texts[0], out_type=str) + tokens = sp.encode(c.supervisions[0].text, out_type=str) if T < len(tokens): logging.warning( @@ -1219,6 +1238,15 @@ def run(rank, world_size, args): return True + libriheavy = LibriHeavyAsrDataModule(args) + + train_cuts = libriheavy.train_small_cuts() + if params.subset == "M" or params.subset == "L": + train_cuts += libriheavy.train_medium_cuts() + if params.subset == "L": + train_cuts += libriheavy.train_large_cuts() + train_cuts = train_cuts.map(normalize_text) + train_cuts = train_cuts.filter(remove_short_and_long_utt) if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: @@ -1232,15 +1260,8 @@ def run(rank, world_size, args): train_cuts, sampler_state_dict=sampler_state_dict ) - def add_texts(c: Cut): - text = c.supervisions[0].text - c.supervisions[0].texts = [text] - return c - - valid_cuts = libriheavy.librispeech_dev_clean_cuts() - valid_cuts += libriheavy.librispeech_dev_other_cuts() - - valid_cuts = valid_cuts.map(add_texts) + valid_cuts = libriheavy.dev_cuts() + valid_cuts = valid_cuts.map(normalize_text) valid_dl = libriheavy.valid_dataloaders(valid_cuts)