diff --git a/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_dev_test.py b/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_dev_test.py index 2581ee42f..3863133c9 100755 --- a/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_dev_test.py +++ b/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_dev_test.py @@ -19,9 +19,10 @@ import logging from pathlib import Path +import argparse import torch -from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter +from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. @@ -30,8 +31,27 @@ from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWri torch.set_num_threads(1) torch.set_num_interop_threads(1) +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) -def compute_fbank_kespeech_dev_test(): + parser.add_argument( + "--num-mel-bins", + type=int, + default=80, + help="""The number of mel bins for Fbank""", + ) + + parser.add_argument( + "--whisper-fbank", + type=str2bool, + default=False, + help="Use WhisperFbank instead of Fbank. Default: False.", + ) + return parser + +def compute_fbank_kespeech_dev_test(args): in_out_dir = Path("data/fbank/kespeech") # number of workers in dataloader num_workers = 42 @@ -48,7 +68,10 @@ def compute_fbank_kespeech_dev_test(): device = torch.device("cpu") if torch.cuda.is_available(): device = torch.device("cuda", 0) - extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) + if args.whisper_fbank: + extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device=device)) + else: + extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) logging.info(f"device: {device}") @@ -86,7 +109,11 @@ def main(): formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) - compute_fbank_kespeech_dev_test() + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + compute_fbank_kespeech_dev_test(args) if __name__ == "__main__": diff --git a/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_splits.py b/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_splits.py index 8bfbc7b50..bef2e23d2 100755 --- a/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_splits.py +++ b/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_splits.py @@ -25,6 +25,8 @@ from pathlib import Path import torch from lhotse import ( CutSet, + WhisperFbank, + WhisperFbankConfig, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter, @@ -88,6 +90,20 @@ def get_parser(): default=-1, help="Stop processing pieces until this number (exclusive).", ) + + parser.add_argument( + "--num-mel-bins", + type=int, + default=80, + help="""The number of mel bins for Fbank""", + ) + + parser.add_argument( + "--whisper-fbank", + type=str2bool, + default=False, + help="Use WhisperFbank instead of Fbank. Default: False.", + ) return parser @@ -111,7 +127,10 @@ def compute_fbank_kespeech_splits(args): device = torch.device("cpu") if torch.cuda.is_available(): device = torch.device("cuda", 0) - extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) + if args.whisper_fbank: + extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda')) + else: + extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) logging.info(f"device: {device}") set_audio_duration_mismatch_tolerance(0.01) # 10ms tolerance diff --git a/egs/multi_zh-hans/ASR/local/compute_fbank_magicdata.py b/egs/multi_zh-hans/ASR/local/compute_fbank_magicdata.py index 5649d3815..ad08b9f19 100755 --- a/egs/multi_zh-hans/ASR/local/compute_fbank_magicdata.py +++ b/egs/multi_zh-hans/ASR/local/compute_fbank_magicdata.py @@ -30,7 +30,7 @@ import os from pathlib import Path import torch -from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter from lhotse.recipes.utils import read_manifests_if_cached from icefall.utils import get_executor @@ -42,8 +42,27 @@ from icefall.utils import get_executor torch.set_num_threads(1) torch.set_num_interop_threads(1) +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) -def compute_fbank_magicdata(num_mel_bins: int = 80, speed_perturb: bool = False): + parser.add_argument( + "--num-mel-bins", + type=int, + default=80, + help="""The number of mel bins for Fbank""", + ) + + parser.add_argument( + "--whisper-fbank", + type=str2bool, + default=False, + help="Use WhisperFbank instead of Fbank. Default: False.", + ) + return parser + +def compute_fbank_magicdata(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False): src_dir = Path("data/manifests/magicdata") output_dir = Path("data/fbank") num_jobs = min(30, os.cpu_count()) @@ -65,8 +84,11 @@ def compute_fbank_magicdata(num_mel_bins: int = 80, speed_perturb: bool = False) list(manifests.keys()), dataset_parts, ) - - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + + if args.whisper_fbank: + extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")) + else: + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) with get_executor() as ex: # Initialize the executor only once. for partition, m in manifests.items(): @@ -118,5 +140,5 @@ if __name__ == "__main__": args = get_args() compute_fbank_magicdata( - num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb + num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank ) diff --git a/egs/multi_zh-hans/ASR/local/compute_fbank_primewords.py b/egs/multi_zh-hans/ASR/local/compute_fbank_primewords.py index 303a16580..b60787fd8 100755 --- a/egs/multi_zh-hans/ASR/local/compute_fbank_primewords.py +++ b/egs/multi_zh-hans/ASR/local/compute_fbank_primewords.py @@ -30,7 +30,7 @@ import os from pathlib import Path import torch -from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter from lhotse.recipes.utils import read_manifests_if_cached from icefall.utils import get_executor @@ -43,7 +43,7 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def compute_fbank_primewords(num_mel_bins: int = 80, speed_perturb: bool = False): +def compute_fbank_primewords(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False): src_dir = Path("data/manifests/primewords") output_dir = Path("data/fbank") num_jobs = min(15, os.cpu_count()) @@ -65,8 +65,11 @@ def compute_fbank_primewords(num_mel_bins: int = 80, speed_perturb: bool = False list(manifests.keys()), dataset_parts, ) - - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + + if whisper_fbank: + extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda')) + else: + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) with get_executor() as ex: # Initialize the executor only once. for partition, m in manifests.items(): @@ -108,6 +111,13 @@ def get_args(): help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", ) + parser.add_argument( + "--whisper-fbank", + type=str2bool, + default=False, + help="Use WhisperFbank instead of Fbank. Default: False.", + ) + return parser.parse_args() @@ -118,5 +128,5 @@ if __name__ == "__main__": args = get_args() compute_fbank_primewords( - num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb + num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank ) diff --git a/egs/multi_zh-hans/ASR/local/compute_fbank_stcmds.py b/egs/multi_zh-hans/ASR/local/compute_fbank_stcmds.py index 730806954..f156f5fbb 100755 --- a/egs/multi_zh-hans/ASR/local/compute_fbank_stcmds.py +++ b/egs/multi_zh-hans/ASR/local/compute_fbank_stcmds.py @@ -30,7 +30,7 @@ import os from pathlib import Path import torch -from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter from lhotse.recipes.utils import read_manifests_if_cached from icefall.utils import get_executor @@ -43,7 +43,7 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def compute_fbank_stcmds(num_mel_bins: int = 80, speed_perturb: bool = False): +def compute_fbank_stcmds(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False): src_dir = Path("data/manifests/stcmds") output_dir = Path("data/fbank") num_jobs = min(15, os.cpu_count()) @@ -66,7 +66,10 @@ def compute_fbank_stcmds(num_mel_bins: int = 80, speed_perturb: bool = False): dataset_parts, ) - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + if whisper_fbank: + extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda')) + else: + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) with get_executor() as ex: # Initialize the executor only once. for partition, m in manifests.items(): @@ -107,6 +110,12 @@ def get_args(): default=False, help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", ) + parser.add_argument( + "--whisper-fbank", + type=str2bool, + default=False, + help="Use WhisperFbank instead of Fbank. Default: False.", + ) return parser.parse_args() @@ -117,5 +126,5 @@ if __name__ == "__main__": args = get_args() compute_fbank_stcmds( - num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb + num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank ) diff --git a/egs/multi_zh-hans/ASR/local/compute_fbank_thchs30.py b/egs/multi_zh-hans/ASR/local/compute_fbank_thchs30.py index 58bb8002a..5fc462e47 100755 --- a/egs/multi_zh-hans/ASR/local/compute_fbank_thchs30.py +++ b/egs/multi_zh-hans/ASR/local/compute_fbank_thchs30.py @@ -30,7 +30,7 @@ import os from pathlib import Path import torch -from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter from lhotse.recipes.utils import read_manifests_if_cached from icefall.utils import get_executor @@ -43,7 +43,7 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def compute_fbank_thchs30(num_mel_bins: int = 80, speed_perturb: bool = False): +def compute_fbank_thchs30(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False): src_dir = Path("data/manifests/thchs30") output_dir = Path("data/fbank") num_jobs = min(15, os.cpu_count()) @@ -70,7 +70,10 @@ def compute_fbank_thchs30(num_mel_bins: int = 80, speed_perturb: bool = False): dataset_parts, ) - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + if whisper_fbank: + extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda')) + else: + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) with get_executor() as ex: # Initialize the executor only once. for partition, m in manifests.items(): @@ -113,6 +116,12 @@ def get_args(): default=False, help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", ) + parser.add_argument( + "--whisper-fbank", + type=str2bool, + default=False, + help="Use WhisperFbank instead of Fbank. Default: False.", + ) return parser.parse_args() @@ -123,5 +132,5 @@ if __name__ == "__main__": args = get_args() compute_fbank_thchs30( - num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb + num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank ) diff --git a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py index 8f8986655..43a1febba 100755 --- a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py +++ b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py @@ -18,7 +18,7 @@ import logging from pathlib import Path - +import argparse import torch from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter