add whisper fbank for other dataset

This commit is contained in:
Yuekai Zhang 2024-01-19 15:39:43 +08:00
parent e43c4da91d
commit 315175a362
7 changed files with 120 additions and 24 deletions

View File

@ -19,9 +19,10 @@
import logging
from pathlib import Path
import argparse
import torch
from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
@ -30,8 +31,27 @@ from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWri
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
def compute_fbank_kespeech_dev_test():
parser.add_argument(
"--num-mel-bins",
type=int,
default=80,
help="""The number of mel bins for Fbank""",
)
parser.add_argument(
"--whisper-fbank",
type=str2bool,
default=False,
help="Use WhisperFbank instead of Fbank. Default: False.",
)
return parser
def compute_fbank_kespeech_dev_test(args):
in_out_dir = Path("data/fbank/kespeech")
# number of workers in dataloader
num_workers = 42
@ -48,7 +68,10 @@ def compute_fbank_kespeech_dev_test():
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
if args.whisper_fbank:
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device=device))
else:
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
logging.info(f"device: {device}")
@ -86,7 +109,11 @@ def main():
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_kespeech_dev_test()
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
compute_fbank_kespeech_dev_test(args)
if __name__ == "__main__":

View File

@ -25,6 +25,8 @@ from pathlib import Path
import torch
from lhotse import (
CutSet,
WhisperFbank,
WhisperFbankConfig,
KaldifeatFbank,
KaldifeatFbankConfig,
LilcomChunkyWriter,
@ -88,6 +90,20 @@ def get_parser():
default=-1,
help="Stop processing pieces until this number (exclusive).",
)
parser.add_argument(
"--num-mel-bins",
type=int,
default=80,
help="""The number of mel bins for Fbank""",
)
parser.add_argument(
"--whisper-fbank",
type=str2bool,
default=False,
help="Use WhisperFbank instead of Fbank. Default: False.",
)
return parser
@ -111,7 +127,10 @@ def compute_fbank_kespeech_splits(args):
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
if args.whisper_fbank:
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda'))
else:
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
logging.info(f"device: {device}")
set_audio_duration_mismatch_tolerance(0.01) # 10ms tolerance

View File

@ -30,7 +30,7 @@ import os
from pathlib import Path
import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
@ -42,8 +42,27 @@ from icefall.utils import get_executor
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
def compute_fbank_magicdata(num_mel_bins: int = 80, speed_perturb: bool = False):
parser.add_argument(
"--num-mel-bins",
type=int,
default=80,
help="""The number of mel bins for Fbank""",
)
parser.add_argument(
"--whisper-fbank",
type=str2bool,
default=False,
help="Use WhisperFbank instead of Fbank. Default: False.",
)
return parser
def compute_fbank_magicdata(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False):
src_dir = Path("data/manifests/magicdata")
output_dir = Path("data/fbank")
num_jobs = min(30, os.cpu_count())
@ -65,8 +84,11 @@ def compute_fbank_magicdata(num_mel_bins: int = 80, speed_perturb: bool = False)
list(manifests.keys()),
dataset_parts,
)
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
if args.whisper_fbank:
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda"))
else:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once.
for partition, m in manifests.items():
@ -118,5 +140,5 @@ if __name__ == "__main__":
args = get_args()
compute_fbank_magicdata(
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank
)

View File

@ -30,7 +30,7 @@ import os
from pathlib import Path
import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
@ -43,7 +43,7 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_primewords(num_mel_bins: int = 80, speed_perturb: bool = False):
def compute_fbank_primewords(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False):
src_dir = Path("data/manifests/primewords")
output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count())
@ -65,8 +65,11 @@ def compute_fbank_primewords(num_mel_bins: int = 80, speed_perturb: bool = False
list(manifests.keys()),
dataset_parts,
)
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
if whisper_fbank:
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda'))
else:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once.
for partition, m in manifests.items():
@ -108,6 +111,13 @@ def get_args():
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
)
parser.add_argument(
"--whisper-fbank",
type=str2bool,
default=False,
help="Use WhisperFbank instead of Fbank. Default: False.",
)
return parser.parse_args()
@ -118,5 +128,5 @@ if __name__ == "__main__":
args = get_args()
compute_fbank_primewords(
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank
)

View File

@ -30,7 +30,7 @@ import os
from pathlib import Path
import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
@ -43,7 +43,7 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_stcmds(num_mel_bins: int = 80, speed_perturb: bool = False):
def compute_fbank_stcmds(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False):
src_dir = Path("data/manifests/stcmds")
output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count())
@ -66,7 +66,10 @@ def compute_fbank_stcmds(num_mel_bins: int = 80, speed_perturb: bool = False):
dataset_parts,
)
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
if whisper_fbank:
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda'))
else:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once.
for partition, m in manifests.items():
@ -107,6 +110,12 @@ def get_args():
default=False,
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
)
parser.add_argument(
"--whisper-fbank",
type=str2bool,
default=False,
help="Use WhisperFbank instead of Fbank. Default: False.",
)
return parser.parse_args()
@ -117,5 +126,5 @@ if __name__ == "__main__":
args = get_args()
compute_fbank_stcmds(
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank
)

View File

@ -30,7 +30,7 @@ import os
from pathlib import Path
import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
@ -43,7 +43,7 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_thchs30(num_mel_bins: int = 80, speed_perturb: bool = False):
def compute_fbank_thchs30(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False):
src_dir = Path("data/manifests/thchs30")
output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count())
@ -70,7 +70,10 @@ def compute_fbank_thchs30(num_mel_bins: int = 80, speed_perturb: bool = False):
dataset_parts,
)
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
if whisper_fbank:
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda'))
else:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once.
for partition, m in manifests.items():
@ -113,6 +116,12 @@ def get_args():
default=False,
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
)
parser.add_argument(
"--whisper-fbank",
type=str2bool,
default=False,
help="Use WhisperFbank instead of Fbank. Default: False.",
)
return parser.parse_args()
@ -123,5 +132,5 @@ if __name__ == "__main__":
args = get_args()
compute_fbank_thchs30(
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank
)

View File

@ -18,7 +18,7 @@
import logging
from pathlib import Path
import argparse
import torch
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter