Add Whisper fbank support for other datasets

This commit is contained in:
Yuekai Zhang 2024-01-19 15:39:43 +08:00
parent e43c4da91d
commit 315175a362
7 changed files with 120 additions and 24 deletions

View File

@ -19,9 +19,10 @@
import logging import logging
from pathlib import Path from pathlib import Path
import argparse
import torch import torch
from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter
# Torch's multithreaded behavior needs to be disabled or # Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down. # it wastes a lot of CPU and slow things down.
@ -30,8 +31,27 @@ from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWri
torch.set_num_threads(1) torch.set_num_threads(1)
torch.set_num_interop_threads(1) torch.set_num_interop_threads(1)
def get_parser():
    """Build the command-line parser for this fbank extraction script.

    Registered options:
      --num-mel-bins   int, default 80. Number of mel bins.
      --whisper-fbank  bool, default False. Selects WhisperFbank instead of
                       the default fbank extractor.

    Returns:
      The configured ``argparse.ArgumentParser``. Nothing is parsed here;
      the caller is expected to invoke ``parse_args()`` on the result.
    """
    # `str2bool` is not among the imports visible at the top of this script,
    # so bring it into scope here; otherwise building the parser raises
    # NameError on the `--whisper-fbank` option.
    from icefall.utils import str2bool

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--num-mel-bins",
        type=int,
        default=80,
        help="""The number of mel bins for Fbank""",
    )

    parser.add_argument(
        "--whisper-fbank",
        type=str2bool,
        default=False,
        help="Use WhisperFbank instead of Fbank. Default: False.",
    )

    return parser
def compute_fbank_kespeech_dev_test(args):
in_out_dir = Path("data/fbank/kespeech") in_out_dir = Path("data/fbank/kespeech")
# number of workers in dataloader # number of workers in dataloader
num_workers = 42 num_workers = 42
@ -48,6 +68,9 @@ def compute_fbank_kespeech_dev_test():
device = torch.device("cpu") device = torch.device("cpu")
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda", 0) device = torch.device("cuda", 0)
if args.whisper_fbank:
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device=device))
else:
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
logging.info(f"device: {device}") logging.info(f"device: {device}")
@ -86,7 +109,11 @@ def main():
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_kespeech_dev_test() parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
compute_fbank_kespeech_dev_test(args)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -25,6 +25,8 @@ from pathlib import Path
import torch import torch
from lhotse import ( from lhotse import (
CutSet, CutSet,
WhisperFbank,
WhisperFbankConfig,
KaldifeatFbank, KaldifeatFbank,
KaldifeatFbankConfig, KaldifeatFbankConfig,
LilcomChunkyWriter, LilcomChunkyWriter,
@ -88,6 +90,20 @@ def get_parser():
default=-1, default=-1,
help="Stop processing pieces until this number (exclusive).", help="Stop processing pieces until this number (exclusive).",
) )
parser.add_argument(
"--num-mel-bins",
type=int,
default=80,
help="""The number of mel bins for Fbank""",
)
parser.add_argument(
"--whisper-fbank",
type=str2bool,
default=False,
help="Use WhisperFbank instead of Fbank. Default: False.",
)
return parser return parser
@ -111,6 +127,9 @@ def compute_fbank_kespeech_splits(args):
device = torch.device("cpu") device = torch.device("cpu")
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda", 0) device = torch.device("cuda", 0)
if args.whisper_fbank:
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda'))
else:
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
logging.info(f"device: {device}") logging.info(f"device: {device}")

View File

@ -30,7 +30,7 @@ import os
from pathlib import Path from pathlib import Path
import torch import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor from icefall.utils import get_executor
@ -42,8 +42,27 @@ from icefall.utils import get_executor
torch.set_num_threads(1) torch.set_num_threads(1)
torch.set_num_interop_threads(1) torch.set_num_interop_threads(1)
def get_parser():
    """Build a command-line parser exposing the fbank feature options.

    Registered options:
      --num-mel-bins   int, default 80. Number of mel bins.
      --whisper-fbank  bool, default False. Selects WhisperFbank instead of
                       Fbank.

    Returns:
      The configured ``argparse.ArgumentParser``. Nothing is parsed here;
      the caller is expected to invoke ``parse_args()`` on the result.
    """
    # NOTE(review): the visible entry point of this script calls get_args(),
    # not get_parser() -- confirm that --whisper-fbank is also registered
    # there, otherwise `args.whisper_fbank` will not exist at runtime.

    # `str2bool` is not among the imports visible at the top of this script
    # (only `get_executor` is imported from icefall.utils), so import it
    # locally to avoid a NameError when the parser is built.
    from icefall.utils import str2bool

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--num-mel-bins",
        type=int,
        default=80,
        help="""The number of mel bins for Fbank""",
    )

    parser.add_argument(
        "--whisper-fbank",
        type=str2bool,
        default=False,
        help="Use WhisperFbank instead of Fbank. Default: False.",
    )

    return parser
def compute_fbank_magicdata(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False):
src_dir = Path("data/manifests/magicdata") src_dir = Path("data/manifests/magicdata")
output_dir = Path("data/fbank") output_dir = Path("data/fbank")
num_jobs = min(30, os.cpu_count()) num_jobs = min(30, os.cpu_count())
@ -66,6 +85,9 @@ def compute_fbank_magicdata(num_mel_bins: int = 80, speed_perturb: bool = False)
dataset_parts, dataset_parts,
) )
if args.whisper_fbank:
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda"))
else:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once. with get_executor() as ex: # Initialize the executor only once.
@ -118,5 +140,5 @@ if __name__ == "__main__":
args = get_args() args = get_args()
compute_fbank_magicdata( compute_fbank_magicdata(
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank
) )

View File

@ -30,7 +30,7 @@ import os
from pathlib import Path from pathlib import Path
import torch import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor from icefall.utils import get_executor
@ -43,7 +43,7 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1) torch.set_num_interop_threads(1)
def compute_fbank_primewords(num_mel_bins: int = 80, speed_perturb: bool = False): def compute_fbank_primewords(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False):
src_dir = Path("data/manifests/primewords") src_dir = Path("data/manifests/primewords")
output_dir = Path("data/fbank") output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count()) num_jobs = min(15, os.cpu_count())
@ -66,6 +66,9 @@ def compute_fbank_primewords(num_mel_bins: int = 80, speed_perturb: bool = False
dataset_parts, dataset_parts,
) )
if whisper_fbank:
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda'))
else:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once. with get_executor() as ex: # Initialize the executor only once.
@ -108,6 +111,13 @@ def get_args():
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
) )
parser.add_argument(
"--whisper-fbank",
type=str2bool,
default=False,
help="Use WhisperFbank instead of Fbank. Default: False.",
)
return parser.parse_args() return parser.parse_args()
@ -118,5 +128,5 @@ if __name__ == "__main__":
args = get_args() args = get_args()
compute_fbank_primewords( compute_fbank_primewords(
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank
) )

View File

@ -30,7 +30,7 @@ import os
from pathlib import Path from pathlib import Path
import torch import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor from icefall.utils import get_executor
@ -43,7 +43,7 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1) torch.set_num_interop_threads(1)
def compute_fbank_stcmds(num_mel_bins: int = 80, speed_perturb: bool = False): def compute_fbank_stcmds(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False):
src_dir = Path("data/manifests/stcmds") src_dir = Path("data/manifests/stcmds")
output_dir = Path("data/fbank") output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count()) num_jobs = min(15, os.cpu_count())
@ -66,6 +66,9 @@ def compute_fbank_stcmds(num_mel_bins: int = 80, speed_perturb: bool = False):
dataset_parts, dataset_parts,
) )
if whisper_fbank:
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda'))
else:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once. with get_executor() as ex: # Initialize the executor only once.
@ -107,6 +110,12 @@ def get_args():
default=False, default=False,
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
) )
parser.add_argument(
"--whisper-fbank",
type=str2bool,
default=False,
help="Use WhisperFbank instead of Fbank. Default: False.",
)
return parser.parse_args() return parser.parse_args()
@ -117,5 +126,5 @@ if __name__ == "__main__":
args = get_args() args = get_args()
compute_fbank_stcmds( compute_fbank_stcmds(
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank
) )

View File

@ -30,7 +30,7 @@ import os
from pathlib import Path from pathlib import Path
import torch import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor from icefall.utils import get_executor
@ -43,7 +43,7 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1) torch.set_num_interop_threads(1)
def compute_fbank_thchs30(num_mel_bins: int = 80, speed_perturb: bool = False): def compute_fbank_thchs30(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False):
src_dir = Path("data/manifests/thchs30") src_dir = Path("data/manifests/thchs30")
output_dir = Path("data/fbank") output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count()) num_jobs = min(15, os.cpu_count())
@ -70,6 +70,9 @@ def compute_fbank_thchs30(num_mel_bins: int = 80, speed_perturb: bool = False):
dataset_parts, dataset_parts,
) )
if whisper_fbank:
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda'))
else:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once. with get_executor() as ex: # Initialize the executor only once.
@ -113,6 +116,12 @@ def get_args():
default=False, default=False,
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
) )
parser.add_argument(
"--whisper-fbank",
type=str2bool,
default=False,
help="Use WhisperFbank instead of Fbank. Default: False.",
)
return parser.parse_args() return parser.parse_args()
@ -123,5 +132,5 @@ if __name__ == "__main__":
args = get_args() args = get_args()
compute_fbank_thchs30( compute_fbank_thchs30(
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank
) )

View File

@ -18,7 +18,7 @@
import logging import logging
from pathlib import Path from pathlib import Path
import argparse
import torch import torch
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter