From ffb0e7891d610fdd8f031a23518472b7911bf128 Mon Sep 17 00:00:00 2001 From: JinZr <60612200+JinZr@users.noreply.github.com> Date: Thu, 20 Jul 2023 16:36:12 +0800 Subject: [PATCH] disable speed perturbation by default --- .../ASR/local/compute_fbank_magicdata.py | 16 ++- .../ASR/local/compute_fbank_primewords.py | 16 ++- .../ASR/local/compute_fbank_stcmds.py | 17 ++- .../ASR/local/compute_fbank_thchs30.py | 17 ++- .../ASR/local/preprocess_kespeech.py | 30 +++- .../ASR/zipformer/multi_dataset.py | 128 ++++++++++++++++++ 6 files changed, 203 insertions(+), 21 deletions(-) create mode 100644 egs/multi_zh-hans/ASR/zipformer/multi_dataset.py diff --git a/egs/multi_zh-hans/ASR/local/compute_fbank_magicdata.py b/egs/multi_zh-hans/ASR/local/compute_fbank_magicdata.py index 22a928cf3..1ddd72377 100755 --- a/egs/multi_zh-hans/ASR/local/compute_fbank_magicdata.py +++ b/egs/multi_zh-hans/ASR/local/compute_fbank_magicdata.py @@ -43,7 +43,7 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def compute_fbank_thchs30(num_mel_bins: int = 80): +def compute_fbank_thchs30(num_mel_bins: int = 80, speed_perturb: bool = False): src_dir = Path("data/manifests/magicdata") output_dir = Path("data/fbank") num_jobs = min(30, os.cpu_count()) @@ -80,7 +80,9 @@ def compute_fbank_thchs30(num_mel_bins: int = 80): ) if "train" in partition: cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + (cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)) + if speed_perturb + else cut_set ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -101,6 +103,12 @@ def get_args(): default=80, help="""The number of mel bins for Fbank""", ) + parser.add_argument( + "--speed-perturb", + type=bool, + default=False, + help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", + ) return parser.parse_args() @@ -111,4 +119,6 @@ if __name__ == "__main__": logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() - compute_fbank_thchs30(num_mel_bins=args.num_mel_bins) + compute_fbank_thchs30( + num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb + ) diff --git a/egs/multi_zh-hans/ASR/local/compute_fbank_primewords.py b/egs/multi_zh-hans/ASR/local/compute_fbank_primewords.py index 0ea414548..d332e0067 100755 --- a/egs/multi_zh-hans/ASR/local/compute_fbank_primewords.py +++ b/egs/multi_zh-hans/ASR/local/compute_fbank_primewords.py @@ -43,7 +43,7 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def compute_fbank_thchs30(num_mel_bins: int = 80): +def compute_fbank_thchs30(num_mel_bins: int = 80, speed_perturb: bool = False): src_dir = Path("data/manifests/primewords") output_dir = Path("data/fbank") num_jobs = min(15, os.cpu_count()) @@ -80,7 +80,9 @@ def compute_fbank_thchs30(num_mel_bins: int = 80): ) if "train" in partition: cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + (cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)) + if speed_perturb + else cut_set ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -101,6 +103,12 @@ def get_args(): default=80, help="""The number of mel bins for Fbank""", ) + parser.add_argument( + "--speed-perturb", + type=bool, + default=False, + help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", + ) return parser.parse_args() @@ -111,4 +119,6 @@ if __name__ == "__main__": logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() - compute_fbank_thchs30(num_mel_bins=args.num_mel_bins) + compute_fbank_thchs30( + num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb + ) diff --git a/egs/multi_zh-hans/ASR/local/compute_fbank_stcmds.py b/egs/multi_zh-hans/ASR/local/compute_fbank_stcmds.py index 4083a4b9e..b5deddeab 100755 --- a/egs/multi_zh-hans/ASR/local/compute_fbank_stcmds.py +++ b/egs/multi_zh-hans/ASR/local/compute_fbank_stcmds.py @@ -43,7 +43,7 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def compute_fbank_thchs30(num_mel_bins: int = 80): +def compute_fbank_thchs30(num_mel_bins: int = 80, speed_perturb: bool = False): src_dir = Path("data/manifests/stcmds") output_dir = Path("data/fbank") num_jobs = min(15, os.cpu_count()) @@ -80,7 +80,9 @@ def compute_fbank_thchs30(num_mel_bins: int = 80): ) if "train" in partition: cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + (cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)) + if speed_perturb + else cut_set ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -101,7 +103,12 @@ def get_args(): default=80, help="""The number of mel bins for Fbank""", ) - + parser.add_argument( + "--speed-perturb", + type=bool, + default=False, + help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", + ) return parser.parse_args() @@ -111,4 +118,6 @@ if __name__ == "__main__": logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() - compute_fbank_thchs30(num_mel_bins=args.num_mel_bins) + compute_fbank_thchs30( + num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb + ) diff --git a/egs/multi_zh-hans/ASR/local/compute_fbank_thchs30.py b/egs/multi_zh-hans/ASR/local/compute_fbank_thchs30.py index b79232d91..58bb8002a 100755 --- a/egs/multi_zh-hans/ASR/local/compute_fbank_thchs30.py +++ b/egs/multi_zh-hans/ASR/local/compute_fbank_thchs30.py @@ -43,7 +43,7 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def compute_fbank_thchs30(num_mel_bins: int = 80): +def compute_fbank_thchs30(num_mel_bins: int = 80, speed_perturb: bool = False): src_dir = Path("data/manifests/thchs30") output_dir = Path("data/fbank") num_jobs = min(15, os.cpu_count()) @@ -84,7 +84,9 @@ def compute_fbank_thchs30(num_mel_bins: int = 80): ) if "train" in partition: cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + (cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)) + if speed_perturb + else cut_set ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -105,7 +107,12 @@ def get_args(): default=80, help="""The number of mel bins for Fbank""", ) - + parser.add_argument( + "--speed-perturb", + type=bool, + default=False, + help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", + ) return parser.parse_args() @@ -115,4 +122,6 @@ if __name__ == "__main__": logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() - compute_fbank_thchs30(num_mel_bins=args.num_mel_bins) + compute_fbank_thchs30( + num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb + ) diff --git a/egs/multi_zh-hans/ASR/local/preprocess_kespeech.py b/egs/multi_zh-hans/ASR/local/preprocess_kespeech.py index 00eed113f..5d871a5c6 100755 --- a/egs/multi_zh-hans/ASR/local/preprocess_kespeech.py +++ b/egs/multi_zh-hans/ASR/local/preprocess_kespeech.py @@ -16,6 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import argparse import logging import re from pathlib import Path @@ -45,7 +46,7 @@ def has_no_oov( return oov_pattern.search(sup.text) is None -def preprocess_kespeech(): +def preprocess_kespeech(speed_perturb: bool = False): src_dir = Path("data/manifests/kespeech") output_dir = Path("data/fbank/kespeech") output_dir.mkdir(exist_ok=True) @@ -114,19 +115,34 @@ def preprocess_kespeech(): "dev_phase2", "test", ]: - logging.info( - f"Speed perturb for {partition} with factors 0.9 and 1.1 " - "(Perturbing may take 8 minutes and saving may take 20 minutes)" - ) - cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + if speed_perturb: + logging.info( + f"Speed perturb for {partition} with factors 0.9 and 1.1 " + "(Perturbing may take 8 minutes and saving may take 20 minutes)" + ) + cut_set = ( + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + ) logging.info(f"Saving to {raw_cuts_path}") cut_set.to_file(raw_cuts_path) +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--speed-perturb", + type=bool, + default=False, + help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", + ) + return parser.parse_args() + + def main(): setup_logger(log_filename="./log-preprocess-kespeech") - preprocess_kespeech() + args = get_args() + preprocess_kespeech(speed_perturb=args.speed_perturb) logging.info("Done") diff --git a/egs/multi_zh-hans/ASR/zipformer/multi_dataset.py b/egs/multi_zh-hans/ASR/zipformer/multi_dataset.py new file mode 100644 index 000000000..856130ee4 --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/multi_dataset.py @@ -0,0 +1,128 @@ +# Copyright 2023 Xiaomi Corp. (authors: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import glob +import logging +import re +from pathlib import Path + +import lhotse +from lhotse import CutSet, load_manifest_lazy + + +class MultiDataset: + def __init__(self, fbank_dir: str): + """ + Args: + manifest_dir: + It is expected to contain the following files: + - aidatatang_cuts_train.jsonl.gz + - aishell_cuts_train.jsonl.gz + - aishell2_cuts_train.jsonl.gz + - aishell4_cuts_train_L.jsonl.gz + - aishell4_cuts_train_M.jsonl.gz + - aishell4_cuts_train_S.jsonl.gz + - alimeeting-far_cuts_train.jsonl.gz + - cuts_L.jsonl.gz + - cuts_M.jsonl.gz + - cuts_S.jsonl.gz + - primewords_cuts_train.jsonl.gz + - stcmds_cuts_train.jsonl.gz + - thchs_30_cuts_train.jsonl.gz + """ + self.fbank_dir = Path(fbank_dir) + + def train_cuts(self) -> CutSet: + logging.info("About to get multidataset train cuts") + + # LibriSpeech + logging.info("Loading LibriSpeech in lazy mode") + librispeech_cuts = load_manifest_lazy( + self.fbank_dir / "librispeech_cuts_train-all-shuf.jsonl.gz" + ) + + # GigaSpeech + filenames = glob.glob(f"{self.fbank_dir}/XL_split/cuts_XL.*.jsonl.gz") + + pattern = re.compile(r"cuts_XL.([0-9]+).jsonl.gz") + idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames) + idx_filenames = sorted(idx_filenames, key=lambda x: x[0]) + + sorted_filenames = [f[1] for f in idx_filenames] + + logging.info(f"Loading GigaSpeech {len(sorted_filenames)} splits in lazy mode") + + gigaspeech_cuts = lhotse.combine( + lhotse.load_manifest_lazy(p) for p in sorted_filenames + ) + + # CommonVoice + logging.info("Loading CommonVoice in lazy mode") + commonvoice_cuts = load_manifest_lazy( + self.fbank_dir / f"cv-en_cuts_train.jsonl.gz" + ) + + # LibriHeavy + logging.info("Loading LibriHeavy in lazy mode") + libriheavy_small_cuts = load_manifest_lazy( + self.fbank_dir / "libriheavy_cuts_train_small.jsonl.gz" + ) + libriheavy_medium_cuts = load_manifest_lazy( + self.fbank_dir / "libriheavy_cuts_train_medium.jsonl.gz" + ) + libriheavy_cuts = lhotse.combine(libriheavy_small_cuts, libriheavy_medium_cuts) + + return CutSet.mux( + librispeech_cuts, + gigaspeech_cuts, + commonvoice_cuts, + libriheavy_cuts, + weights=[ + len(librispeech_cuts), + len(gigaspeech_cuts), + len(commonvoice_cuts), + len(libriheavy_cuts), + ], + ) + + def test_cuts(self) -> CutSet: + logging.info("About to get multidataset test cuts") + + # GigaSpeech + logging.info("Loading GigaSpeech DEV in lazy mode") + gigaspeech_dev_cuts = load_manifest_lazy(self.fbank_dir / "cuts_DEV.jsonl.gz") + + logging.info("Loading GigaSpeech TEST in lazy mode") + gigaspeech_test_cuts = load_manifest_lazy(self.fbank_dir / "cuts_TEST.jsonl.gz") + + # CommonVoice + logging.info("Loading CommonVoice DEV in lazy mode") + commonvoice_dev_cuts = load_manifest_lazy( + self.fbank_dir / "cv-en_cuts_dev.jsonl.gz" + ) + + logging.info("Loading CommonVoice TEST in lazy mode") + commonvoice_test_cuts = load_manifest_lazy( + self.fbank_dir / "cv-en_cuts_test.jsonl.gz" + ) + + return [ + gigaspeech_dev_cuts, + gigaspeech_test_cuts, + commonvoice_dev_cuts, + commonvoice_test_cuts, + ]