diff --git a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py index 1af08fee2..8f8986655 100755 --- a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py +++ b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py @@ -20,7 +20,7 @@ import logging from pathlib import Path import torch -from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter +from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. @@ -30,8 +30,27 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) torch.multiprocessing.set_sharing_strategy("file_system") +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) -def compute_fbank_wenetspeech_dev_test(): + parser.add_argument( + "--num-mel-bins", + type=int, + default=80, + help="""The number of mel bins for Fbank""", + ) + + parser.add_argument( + "--whisper-fbank", + type=str2bool, + default=False, + help="Use WhisperFbank instead of Fbank. Default: False.", + ) + return parser + +def compute_fbank_wenetspeech_dev_test(args): in_out_dir = Path("data/fbank") # number of workers in dataloader num_workers = 42 @@ -44,7 +63,10 @@ def compute_fbank_wenetspeech_dev_test(): device = torch.device("cpu") if torch.cuda.is_available(): device = torch.device("cuda", 0) - extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) + if args.whisper_fbank: + extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda')) + else: + extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) logging.info(f"device: {device}") @@ -82,7 +104,11 @@ def main(): formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) - compute_fbank_wenetspeech_dev_test() + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + compute_fbank_wenetspeech_dev_test(args) if __name__ == "__main__": diff --git a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py index a87801462..82fb0422c 100755 --- a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py +++ b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py @@ -24,6 +24,8 @@ from pathlib import Path import torch from lhotse import ( CutSet, + WhisperFbank, + WhisperFbankConfig, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter, @@ -87,6 +89,20 @@ def get_parser(): default=-1, help="Stop processing pieces until this number (excluded).", ) + + parser.add_argument( + "--num-mel-bins", + type=int, + default=80, + help="""The number of mel bins for Fbank""", + ) + + parser.add_argument( + "--whisper-fbank", + type=str2bool, + default=False, + help="Use WhisperFbank instead of Fbank. Default: False.", + ) return parser @@ -110,7 +126,12 @@ def compute_fbank_wenetspeech_splits(args): device = torch.device("cpu") if torch.cuda.is_available(): device = torch.device("cuda", 0) - extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) + if args.whisper_fbank: + extractor = WhisperFbank( + WhisperFbankConfig(num_filters=args.num_mel_bins, device=device) + ) + else: + extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) logging.info(f"device: {device}") set_audio_duration_mismatch_tolerance(0.01) # 10ms tolerance diff --git a/egs/wenetspeech/ASR/prepare.sh b/egs/wenetspeech/ASR/prepare.sh index b0525de60..c1a6c5835 100755 --- a/egs/wenetspeech/ASR/prepare.sh +++ b/egs/wenetspeech/ASR/prepare.sh @@ -182,6 +182,34 @@ if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then fi fi +whisper_mel_bins=80 +if [ $stage -le 129 ] && [ $stop_stage -ge 129 ]; then + log "Stage 129: compute whisper fbank for dev and test sets" + python3 ./local/compute_fbank_wenetspeech_dev_test.py --num-mel-bins ${whisper_mel_bins} --whisper-fbank true +fi +if [ $stage -le 130 ] && [ $stop_stage -ge 130 ]; then + log "Stage 130: Comute features for whisper training set" + + split_dir=data/fbank/L_split_${num_splits} + if [ ! -f $split_dir/.split_completed ]; then + lhotse split $num_splits ./data/fbank/cuts_L_raw.jsonl.gz $split_dir + touch $split_dir/.split_completed + fi + + python3 ./local/compute_fbank_wenetspeech_splits.py \ + --training-subset L \ + --num-workers 20 \ + --batch-duration 600 \ + --start 0 \ + --num-mel-bins ${whisper_mel_bins} --whisper-fbank true \ + --num-splits $num_splits + + if [ ! -f data/fbank/cuts_L.jsonl.gz ]; then + pieces=$(find data/fbank/L_split_1000 -name "cuts_L.*.jsonl.gz") + lhotse combine $pieces data/fbank/cuts_L.jsonl.gz + fi +fi + if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then log "Stage 14: Compute fbank for musan" mkdir -p data/fbank