From b8a34ae160fd5bf8f8c356609ef01b6f35ca5899 Mon Sep 17 00:00:00 2001 From: zr_jin <60612200+JinZr@users.noreply.github.com> Date: Thu, 20 Jul 2023 16:53:04 +0800 Subject: [PATCH] disable speed perturbation by default --- .../local/compute_fbank_aidatatang_200zh.py | 22 ++++++++++++++----- .../ASR/local/compute_fbank_aishell.py | 22 ++++++++++++++----- .../ASR/local/compute_fbank_aishell2.py | 21 +++++++++++++----- .../ASR/local/compute_fbank_aishell4.py | 22 ++++++++++++++----- .../ASR/local/compute_fbank_alimeeting.py | 21 +++++++++++++----- .../ASR/local/preprocess_wenetspeech.py | 22 ++++++++++++++++--- 6 files changed, 100 insertions(+), 30 deletions(-) diff --git a/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py b/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py index 387c14acf..b0c8f8434 100755 --- a/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py +++ b/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py @@ -42,7 +42,7 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80): +def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80, speed_perturb: bool = False): src_dir = Path("data/manifests/aidatatang_200zh") output_dir = Path("data/fbank") num_jobs = min(15, os.cpu_count()) @@ -86,9 +86,12 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80): supervisions=m["supervisions"], ) if "train" in partition: - cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) - ) + if speed_perturb: + cut_set = ( + cut_set + + cut_set.perturb_speed(0.9) + + cut_set.perturb_speed(1.1) + ) cut_set = cut_set.compute_and_store_features( extractor=extractor, storage_path=f"{output_dir}/{prefix}_feats_{partition}", @@ -109,7 +112,12 @@ def get_args(): default=80, help="""The number of mel bins for Fbank""", ) - + parser.add_argument( + "--speed-perturb", + type=bool, + default=False, + help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", + ) return parser.parse_args() @@ -119,4 +127,6 @@ if __name__ == "__main__": logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() - compute_fbank_aidatatang_200zh(num_mel_bins=args.num_mel_bins) + compute_fbank_aidatatang_200zh( + num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb + ) diff --git a/egs/aishell/ASR/local/compute_fbank_aishell.py b/egs/aishell/ASR/local/compute_fbank_aishell.py index 115ca1031..c8e4f8fe0 100755 --- a/egs/aishell/ASR/local/compute_fbank_aishell.py +++ b/egs/aishell/ASR/local/compute_fbank_aishell.py @@ -42,7 +42,7 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def compute_fbank_aishell(num_mel_bins: int = 80): +def compute_fbank_aishell(num_mel_bins: int = 80, speed_perturb: bool = False): src_dir = Path("data/manifests") output_dir = Path("data/fbank") num_jobs = min(15, os.cpu_count()) @@ -82,9 +82,12 @@ def compute_fbank_aishell(num_mel_bins: int = 80): supervisions=m["supervisions"], ) if "train" in partition: - cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) - ) + if speed_perturb: + cut_set = ( + cut_set + + cut_set.perturb_speed(0.9) + + cut_set.perturb_speed(1.1) + ) cut_set = cut_set.compute_and_store_features( extractor=extractor, storage_path=f"{output_dir}/{prefix}_feats_{partition}", @@ -104,7 +107,12 @@ def get_args(): default=80, help="""The number of mel bins for Fbank""", ) - + parser.add_argument( + "--speed-perturb", + type=bool, + default=False, + help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", + ) return parser.parse_args() @@ -114,4 +122,6 @@ if __name__ == "__main__": logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() - compute_fbank_aishell(num_mel_bins=args.num_mel_bins) + compute_fbank_aishell( + num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb + ) diff --git a/egs/aishell2/ASR/local/compute_fbank_aishell2.py b/egs/aishell2/ASR/local/compute_fbank_aishell2.py index ec0c584ca..8e73711b0 100755 --- a/egs/aishell2/ASR/local/compute_fbank_aishell2.py +++ b/egs/aishell2/ASR/local/compute_fbank_aishell2.py @@ -42,7 +42,7 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def compute_fbank_aishell2(num_mel_bins: int = 80): +def compute_fbank_aishell2(num_mel_bins: int = 80, speed_perturb: bool = False): src_dir = Path("data/manifests") output_dir = Path("data/fbank") num_jobs = min(15, os.cpu_count()) @@ -82,9 +82,12 @@ def compute_fbank_aishell2(num_mel_bins: int = 80): supervisions=m["supervisions"], ) if "train" in partition: - cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) - ) + if speed_perturb: + cut_set = ( + cut_set + + cut_set.perturb_speed(0.9) + + cut_set.perturb_speed(1.1) + ) cut_set = cut_set.compute_and_store_features( extractor=extractor, storage_path=f"{output_dir}/{prefix}_feats_{partition}", @@ -104,6 +107,12 @@ def get_args(): default=80, help="""The number of mel bins for Fbank""", ) + parser.add_argument( + "--speed-perturb", + type=bool, + default=False, + help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", + ) return parser.parse_args() @@ -114,4 +123,6 @@ if __name__ == "__main__": logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() - compute_fbank_aishell2(num_mel_bins=args.num_mel_bins) + compute_fbank_aishell2( + num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb + ) diff --git a/egs/aishell4/ASR/local/compute_fbank_aishell4.py b/egs/aishell4/ASR/local/compute_fbank_aishell4.py index 400c406f0..6e5289cf8 100755 --- a/egs/aishell4/ASR/local/compute_fbank_aishell4.py +++ b/egs/aishell4/ASR/local/compute_fbank_aishell4.py @@ -42,7 +42,7 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def compute_fbank_aishell4(num_mel_bins: int = 80): +def compute_fbank_aishell4(num_mel_bins: int = 80, speed_perturb: bool = False): src_dir = Path("data/manifests/aishell4") output_dir = Path("data/fbank") num_jobs = min(15, os.cpu_count()) @@ -84,9 +84,13 @@ def compute_fbank_aishell4(num_mel_bins: int = 80): supervisions=m["supervisions"], ) if "train" in partition: - cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) - ) + if speed_perturb: + cut_set = ( + cut_set + + cut_set.perturb_speed(0.9) + + cut_set.perturb_speed(1.1) + ) + cut_set = cut_set.compute_and_store_features( extractor=extractor, storage_path=f"{output_dir}/{prefix}_feats_{partition}", @@ -113,6 +117,12 @@ def get_args(): default=80, help="""The number of mel bins for Fbank""", ) + parser.add_argument( + "--speed-perturb", + type=bool, + default=False, + help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", + ) return parser.parse_args() @@ -123,4 +133,6 @@ if __name__ == "__main__": logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() - compute_fbank_aishell4(num_mel_bins=args.num_mel_bins) + compute_fbank_aishell4( + num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb + ) diff --git a/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py b/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py index 96115a230..b51b4e6c6 100755 --- a/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py +++ b/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py @@ -42,7 +42,7 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def compute_fbank_alimeeting(num_mel_bins: int = 80): +def compute_fbank_alimeeting(num_mel_bins: int = 80, speed_perturb: bool = False): src_dir = Path("data/manifests/alimeeting") output_dir = Path("data/fbank") num_jobs = min(15, os.cpu_count()) @@ -83,9 +83,12 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80): supervisions=m["supervisions"], ) if "train" in partition: - cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) - ) + if speed_perturb: + cut_set = ( + cut_set + + cut_set.perturb_speed(0.9) + + cut_set.perturb_speed(1.1) + ) cur_num_jobs = num_jobs if ex is None else 80 cur_num_jobs = min(cur_num_jobs, len(cut_set)) @@ -114,6 +117,12 @@ def get_args(): default=80, help="""The number of mel bins for Fbank""", ) + parser.add_argument( + "--speed-perturb", + type=bool, + default=False, + help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", + ) return parser.parse_args() @@ -124,4 +133,6 @@ if __name__ == "__main__": logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() - compute_fbank_alimeeting(num_mel_bins=args.num_mel_bins) + compute_fbank_alimeeting( + num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb + ) diff --git a/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py b/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py index 93ce750f8..1bd210f07 100755 --- a/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py +++ b/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py @@ -16,6 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import argparse import logging import re from pathlib import Path @@ -45,7 +46,7 @@ def has_no_oov( return oov_pattern.search(sup.text) is None -def preprocess_wenet_speech(): +def preprocess_wenet_speech(speed_perturb: bool = False): src_dir = Path("data/manifests") output_dir = Path("data/fbank") output_dir.mkdir(exist_ok=True) @@ -115,15 +116,30 @@ def preprocess_wenet_speech(): f"Speed perturb for {partition} with factors 0.9 and 1.1 " "(Perturbing may take 8 minutes and saving may take 20 minutes)" ) - cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + if speed_perturb: + cut_set = ( + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + ) logging.info(f"Saving to {raw_cuts_path}") cut_set.to_file(raw_cuts_path) +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--speed-perturb", + type=bool, + default=False, + help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", + ) + return parser.parse_args() + + def main(): setup_logger(log_filename="./log-preprocess-wenetspeech") - preprocess_wenet_speech() + args = get_args() + preprocess_wenet_speech(speed_perturb=args.speed_perturb) logging.info("Done")