mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
add whisper fbank for other dataset
This commit is contained in:
parent
e43c4da91d
commit
315175a362
@ -19,9 +19,10 @@
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter
|
||||
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slow things down.
|
||||
@ -30,8 +31,27 @@ from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWri
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
def compute_fbank_kespeech_dev_test():
|
||||
parser.add_argument(
|
||||
"--num-mel-bins",
|
||||
type=int,
|
||||
default=80,
|
||||
help="""The number of mel bins for Fbank""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--whisper-fbank",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Use WhisperFbank instead of Fbank. Default: False.",
|
||||
)
|
||||
return parser
|
||||
|
||||
def compute_fbank_kespeech_dev_test(args):
|
||||
in_out_dir = Path("data/fbank/kespeech")
|
||||
# number of workers in dataloader
|
||||
num_workers = 42
|
||||
@ -48,7 +68,10 @@ def compute_fbank_kespeech_dev_test():
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
|
||||
if args.whisper_fbank:
|
||||
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device=device))
|
||||
else:
|
||||
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
@ -86,7 +109,11 @@ def main():
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
compute_fbank_kespeech_dev_test()
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
compute_fbank_kespeech_dev_test(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -25,6 +25,8 @@ from pathlib import Path
|
||||
import torch
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
WhisperFbank,
|
||||
WhisperFbankConfig,
|
||||
KaldifeatFbank,
|
||||
KaldifeatFbankConfig,
|
||||
LilcomChunkyWriter,
|
||||
@ -88,6 +90,20 @@ def get_parser():
|
||||
default=-1,
|
||||
help="Stop processing pieces until this number (exclusive).",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-mel-bins",
|
||||
type=int,
|
||||
default=80,
|
||||
help="""The number of mel bins for Fbank""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--whisper-fbank",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Use WhisperFbank instead of Fbank. Default: False.",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
@ -111,7 +127,10 @@ def compute_fbank_kespeech_splits(args):
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
|
||||
if args.whisper_fbank:
|
||||
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda'))
|
||||
else:
|
||||
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
set_audio_duration_mismatch_tolerance(0.01) # 10ms tolerance
|
||||
|
@ -30,7 +30,7 @@ import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
||||
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import get_executor
|
||||
@ -42,8 +42,27 @@ from icefall.utils import get_executor
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
def compute_fbank_magicdata(num_mel_bins: int = 80, speed_perturb: bool = False):
|
||||
parser.add_argument(
|
||||
"--num-mel-bins",
|
||||
type=int,
|
||||
default=80,
|
||||
help="""The number of mel bins for Fbank""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--whisper-fbank",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Use WhisperFbank instead of Fbank. Default: False.",
|
||||
)
|
||||
return parser
|
||||
|
||||
def compute_fbank_magicdata(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False):
|
||||
src_dir = Path("data/manifests/magicdata")
|
||||
output_dir = Path("data/fbank")
|
||||
num_jobs = min(30, os.cpu_count())
|
||||
@ -65,8 +84,11 @@ def compute_fbank_magicdata(num_mel_bins: int = 80, speed_perturb: bool = False)
|
||||
list(manifests.keys()),
|
||||
dataset_parts,
|
||||
)
|
||||
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||
|
||||
if args.whisper_fbank:
|
||||
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda"))
|
||||
else:
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||
|
||||
with get_executor() as ex: # Initialize the executor only once.
|
||||
for partition, m in manifests.items():
|
||||
@ -118,5 +140,5 @@ if __name__ == "__main__":
|
||||
|
||||
args = get_args()
|
||||
compute_fbank_magicdata(
|
||||
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb
|
||||
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank
|
||||
)
|
||||
|
@ -30,7 +30,7 @@ import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
||||
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import get_executor
|
||||
@ -43,7 +43,7 @@ torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def compute_fbank_primewords(num_mel_bins: int = 80, speed_perturb: bool = False):
|
||||
def compute_fbank_primewords(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False):
|
||||
src_dir = Path("data/manifests/primewords")
|
||||
output_dir = Path("data/fbank")
|
||||
num_jobs = min(15, os.cpu_count())
|
||||
@ -65,8 +65,11 @@ def compute_fbank_primewords(num_mel_bins: int = 80, speed_perturb: bool = False
|
||||
list(manifests.keys()),
|
||||
dataset_parts,
|
||||
)
|
||||
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||
|
||||
if whisper_fbank:
|
||||
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda'))
|
||||
else:
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||
|
||||
with get_executor() as ex: # Initialize the executor only once.
|
||||
for partition, m in manifests.items():
|
||||
@ -108,6 +111,13 @@ def get_args():
|
||||
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--whisper-fbank",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Use WhisperFbank instead of Fbank. Default: False.",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -118,5 +128,5 @@ if __name__ == "__main__":
|
||||
|
||||
args = get_args()
|
||||
compute_fbank_primewords(
|
||||
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb
|
||||
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank
|
||||
)
|
||||
|
@ -30,7 +30,7 @@ import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
||||
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import get_executor
|
||||
@ -43,7 +43,7 @@ torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def compute_fbank_stcmds(num_mel_bins: int = 80, speed_perturb: bool = False):
|
||||
def compute_fbank_stcmds(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False):
|
||||
src_dir = Path("data/manifests/stcmds")
|
||||
output_dir = Path("data/fbank")
|
||||
num_jobs = min(15, os.cpu_count())
|
||||
@ -66,7 +66,10 @@ def compute_fbank_stcmds(num_mel_bins: int = 80, speed_perturb: bool = False):
|
||||
dataset_parts,
|
||||
)
|
||||
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||
if whisper_fbank:
|
||||
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda'))
|
||||
else:
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||
|
||||
with get_executor() as ex: # Initialize the executor only once.
|
||||
for partition, m in manifests.items():
|
||||
@ -107,6 +110,12 @@ def get_args():
|
||||
default=False,
|
||||
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--whisper-fbank",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Use WhisperFbank instead of Fbank. Default: False.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -117,5 +126,5 @@ if __name__ == "__main__":
|
||||
|
||||
args = get_args()
|
||||
compute_fbank_stcmds(
|
||||
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb
|
||||
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank
|
||||
)
|
||||
|
@ -30,7 +30,7 @@ import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
||||
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import get_executor
|
||||
@ -43,7 +43,7 @@ torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def compute_fbank_thchs30(num_mel_bins: int = 80, speed_perturb: bool = False):
|
||||
def compute_fbank_thchs30(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False):
|
||||
src_dir = Path("data/manifests/thchs30")
|
||||
output_dir = Path("data/fbank")
|
||||
num_jobs = min(15, os.cpu_count())
|
||||
@ -70,7 +70,10 @@ def compute_fbank_thchs30(num_mel_bins: int = 80, speed_perturb: bool = False):
|
||||
dataset_parts,
|
||||
)
|
||||
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||
if whisper_fbank:
|
||||
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda'))
|
||||
else:
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||
|
||||
with get_executor() as ex: # Initialize the executor only once.
|
||||
for partition, m in manifests.items():
|
||||
@ -113,6 +116,12 @@ def get_args():
|
||||
default=False,
|
||||
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--whisper-fbank",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Use WhisperFbank instead of Fbank. Default: False.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -123,5 +132,5 @@ if __name__ == "__main__":
|
||||
|
||||
args = get_args()
|
||||
compute_fbank_thchs30(
|
||||
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb
|
||||
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank
|
||||
)
|
||||
|
@ -18,7 +18,7 @@
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import argparse
|
||||
import torch
|
||||
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user