mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
add whisper fbank for other dataset
This commit is contained in:
parent
e43c4da91d
commit
315175a362
@ -19,9 +19,10 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import argparse
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter
|
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter
|
||||||
|
|
||||||
# Torch's multithreaded behavior needs to be disabled or
|
# Torch's multithreaded behavior needs to be disabled or
|
||||||
# it wastes a lot of CPU and slow things down.
|
# it wastes a lot of CPU and slow things down.
|
||||||
@ -30,8 +31,27 @@ from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWri
|
|||||||
torch.set_num_threads(1)
|
torch.set_num_threads(1)
|
||||||
torch.set_num_interop_threads(1)
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
def compute_fbank_kespeech_dev_test():
|
parser.add_argument(
|
||||||
|
"--num-mel-bins",
|
||||||
|
type=int,
|
||||||
|
default=80,
|
||||||
|
help="""The number of mel bins for Fbank""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--whisper-fbank",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="Use WhisperFbank instead of Fbank. Default: False.",
|
||||||
|
)
|
||||||
|
return parser
|
||||||
|
|
||||||
|
def compute_fbank_kespeech_dev_test(args):
|
||||||
in_out_dir = Path("data/fbank/kespeech")
|
in_out_dir = Path("data/fbank/kespeech")
|
||||||
# number of workers in dataloader
|
# number of workers in dataloader
|
||||||
num_workers = 42
|
num_workers = 42
|
||||||
@ -48,7 +68,10 @@ def compute_fbank_kespeech_dev_test():
|
|||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda", 0)
|
device = torch.device("cuda", 0)
|
||||||
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
|
if args.whisper_fbank:
|
||||||
|
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device=device))
|
||||||
|
else:
|
||||||
|
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
|
||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
@ -86,7 +109,11 @@ def main():
|
|||||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
|
||||||
compute_fbank_kespeech_dev_test()
|
parser = get_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
logging.info(vars(args))
|
||||||
|
|
||||||
|
compute_fbank_kespeech_dev_test(args)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -25,6 +25,8 @@ from pathlib import Path
|
|||||||
import torch
|
import torch
|
||||||
from lhotse import (
|
from lhotse import (
|
||||||
CutSet,
|
CutSet,
|
||||||
|
WhisperFbank,
|
||||||
|
WhisperFbankConfig,
|
||||||
KaldifeatFbank,
|
KaldifeatFbank,
|
||||||
KaldifeatFbankConfig,
|
KaldifeatFbankConfig,
|
||||||
LilcomChunkyWriter,
|
LilcomChunkyWriter,
|
||||||
@ -88,6 +90,20 @@ def get_parser():
|
|||||||
default=-1,
|
default=-1,
|
||||||
help="Stop processing pieces until this number (exclusive).",
|
help="Stop processing pieces until this number (exclusive).",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-mel-bins",
|
||||||
|
type=int,
|
||||||
|
default=80,
|
||||||
|
help="""The number of mel bins for Fbank""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--whisper-fbank",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="Use WhisperFbank instead of Fbank. Default: False.",
|
||||||
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -111,7 +127,10 @@ def compute_fbank_kespeech_splits(args):
|
|||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda", 0)
|
device = torch.device("cuda", 0)
|
||||||
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
|
if args.whisper_fbank:
|
||||||
|
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda'))
|
||||||
|
else:
|
||||||
|
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
set_audio_duration_mismatch_tolerance(0.01) # 10ms tolerance
|
set_audio_duration_mismatch_tolerance(0.01) # 10ms tolerance
|
||||||
|
@ -30,7 +30,7 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
|
||||||
from lhotse.recipes.utils import read_manifests_if_cached
|
from lhotse.recipes.utils import read_manifests_if_cached
|
||||||
|
|
||||||
from icefall.utils import get_executor
|
from icefall.utils import get_executor
|
||||||
@ -42,8 +42,27 @@ from icefall.utils import get_executor
|
|||||||
torch.set_num_threads(1)
|
torch.set_num_threads(1)
|
||||||
torch.set_num_interop_threads(1)
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
def compute_fbank_magicdata(num_mel_bins: int = 80, speed_perturb: bool = False):
|
parser.add_argument(
|
||||||
|
"--num-mel-bins",
|
||||||
|
type=int,
|
||||||
|
default=80,
|
||||||
|
help="""The number of mel bins for Fbank""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--whisper-fbank",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="Use WhisperFbank instead of Fbank. Default: False.",
|
||||||
|
)
|
||||||
|
return parser
|
||||||
|
|
||||||
|
def compute_fbank_magicdata(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False):
|
||||||
src_dir = Path("data/manifests/magicdata")
|
src_dir = Path("data/manifests/magicdata")
|
||||||
output_dir = Path("data/fbank")
|
output_dir = Path("data/fbank")
|
||||||
num_jobs = min(30, os.cpu_count())
|
num_jobs = min(30, os.cpu_count())
|
||||||
@ -66,7 +85,10 @@ def compute_fbank_magicdata(num_mel_bins: int = 80, speed_perturb: bool = False)
|
|||||||
dataset_parts,
|
dataset_parts,
|
||||||
)
|
)
|
||||||
|
|
||||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
if args.whisper_fbank:
|
||||||
|
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda"))
|
||||||
|
else:
|
||||||
|
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||||
|
|
||||||
with get_executor() as ex: # Initialize the executor only once.
|
with get_executor() as ex: # Initialize the executor only once.
|
||||||
for partition, m in manifests.items():
|
for partition, m in manifests.items():
|
||||||
@ -118,5 +140,5 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
args = get_args()
|
args = get_args()
|
||||||
compute_fbank_magicdata(
|
compute_fbank_magicdata(
|
||||||
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb
|
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank
|
||||||
)
|
)
|
||||||
|
@ -30,7 +30,7 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
|
||||||
from lhotse.recipes.utils import read_manifests_if_cached
|
from lhotse.recipes.utils import read_manifests_if_cached
|
||||||
|
|
||||||
from icefall.utils import get_executor
|
from icefall.utils import get_executor
|
||||||
@ -43,7 +43,7 @@ torch.set_num_threads(1)
|
|||||||
torch.set_num_interop_threads(1)
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
|
||||||
def compute_fbank_primewords(num_mel_bins: int = 80, speed_perturb: bool = False):
|
def compute_fbank_primewords(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False):
|
||||||
src_dir = Path("data/manifests/primewords")
|
src_dir = Path("data/manifests/primewords")
|
||||||
output_dir = Path("data/fbank")
|
output_dir = Path("data/fbank")
|
||||||
num_jobs = min(15, os.cpu_count())
|
num_jobs = min(15, os.cpu_count())
|
||||||
@ -66,7 +66,10 @@ def compute_fbank_primewords(num_mel_bins: int = 80, speed_perturb: bool = False
|
|||||||
dataset_parts,
|
dataset_parts,
|
||||||
)
|
)
|
||||||
|
|
||||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
if whisper_fbank:
|
||||||
|
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda'))
|
||||||
|
else:
|
||||||
|
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||||
|
|
||||||
with get_executor() as ex: # Initialize the executor only once.
|
with get_executor() as ex: # Initialize the executor only once.
|
||||||
for partition, m in manifests.items():
|
for partition, m in manifests.items():
|
||||||
@ -108,6 +111,13 @@ def get_args():
|
|||||||
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
|
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--whisper-fbank",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="Use WhisperFbank instead of Fbank. Default: False.",
|
||||||
|
)
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@ -118,5 +128,5 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
args = get_args()
|
args = get_args()
|
||||||
compute_fbank_primewords(
|
compute_fbank_primewords(
|
||||||
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb
|
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank
|
||||||
)
|
)
|
||||||
|
@ -30,7 +30,7 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
|
||||||
from lhotse.recipes.utils import read_manifests_if_cached
|
from lhotse.recipes.utils import read_manifests_if_cached
|
||||||
|
|
||||||
from icefall.utils import get_executor
|
from icefall.utils import get_executor
|
||||||
@ -43,7 +43,7 @@ torch.set_num_threads(1)
|
|||||||
torch.set_num_interop_threads(1)
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
|
||||||
def compute_fbank_stcmds(num_mel_bins: int = 80, speed_perturb: bool = False):
|
def compute_fbank_stcmds(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False):
|
||||||
src_dir = Path("data/manifests/stcmds")
|
src_dir = Path("data/manifests/stcmds")
|
||||||
output_dir = Path("data/fbank")
|
output_dir = Path("data/fbank")
|
||||||
num_jobs = min(15, os.cpu_count())
|
num_jobs = min(15, os.cpu_count())
|
||||||
@ -66,7 +66,10 @@ def compute_fbank_stcmds(num_mel_bins: int = 80, speed_perturb: bool = False):
|
|||||||
dataset_parts,
|
dataset_parts,
|
||||||
)
|
)
|
||||||
|
|
||||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
if whisper_fbank:
|
||||||
|
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda'))
|
||||||
|
else:
|
||||||
|
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||||
|
|
||||||
with get_executor() as ex: # Initialize the executor only once.
|
with get_executor() as ex: # Initialize the executor only once.
|
||||||
for partition, m in manifests.items():
|
for partition, m in manifests.items():
|
||||||
@ -107,6 +110,12 @@ def get_args():
|
|||||||
default=False,
|
default=False,
|
||||||
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
|
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--whisper-fbank",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="Use WhisperFbank instead of Fbank. Default: False.",
|
||||||
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@ -117,5 +126,5 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
args = get_args()
|
args = get_args()
|
||||||
compute_fbank_stcmds(
|
compute_fbank_stcmds(
|
||||||
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb
|
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank
|
||||||
)
|
)
|
||||||
|
@ -30,7 +30,7 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
|
||||||
from lhotse.recipes.utils import read_manifests_if_cached
|
from lhotse.recipes.utils import read_manifests_if_cached
|
||||||
|
|
||||||
from icefall.utils import get_executor
|
from icefall.utils import get_executor
|
||||||
@ -43,7 +43,7 @@ torch.set_num_threads(1)
|
|||||||
torch.set_num_interop_threads(1)
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
|
||||||
def compute_fbank_thchs30(num_mel_bins: int = 80, speed_perturb: bool = False):
|
def compute_fbank_thchs30(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False):
|
||||||
src_dir = Path("data/manifests/thchs30")
|
src_dir = Path("data/manifests/thchs30")
|
||||||
output_dir = Path("data/fbank")
|
output_dir = Path("data/fbank")
|
||||||
num_jobs = min(15, os.cpu_count())
|
num_jobs = min(15, os.cpu_count())
|
||||||
@ -70,7 +70,10 @@ def compute_fbank_thchs30(num_mel_bins: int = 80, speed_perturb: bool = False):
|
|||||||
dataset_parts,
|
dataset_parts,
|
||||||
)
|
)
|
||||||
|
|
||||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
if whisper_fbank:
|
||||||
|
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda'))
|
||||||
|
else:
|
||||||
|
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||||
|
|
||||||
with get_executor() as ex: # Initialize the executor only once.
|
with get_executor() as ex: # Initialize the executor only once.
|
||||||
for partition, m in manifests.items():
|
for partition, m in manifests.items():
|
||||||
@ -113,6 +116,12 @@ def get_args():
|
|||||||
default=False,
|
default=False,
|
||||||
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
|
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--whisper-fbank",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="Use WhisperFbank instead of Fbank. Default: False.",
|
||||||
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@ -123,5 +132,5 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
args = get_args()
|
args = get_args()
|
||||||
compute_fbank_thchs30(
|
compute_fbank_thchs30(
|
||||||
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb
|
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank
|
||||||
)
|
)
|
||||||
|
@ -18,7 +18,7 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import argparse
|
||||||
import torch
|
import torch
|
||||||
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter
|
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user