add vocalnet en data

root 2025-05-08 06:29:46 +00:00
parent 08be51a91f
commit 2dd40b62ef
3 changed files with 262 additions and 5 deletions


@ -35,6 +35,7 @@ from pathlib import Path
import torch
from datasets import load_dataset
from lhotse import CutSet, LilcomChunkyWriter, WhisperFbank, WhisperFbankConfig
from vocalnet_lhotse_cutset import LazyCustomDatasetIterator
from icefall.utils import str2bool
@ -105,9 +106,50 @@ def get_parser():
default="belle",
help="""The dataset prefix to use when saving the features""",
)
parser.add_argument(
"--json-file-path",
type=str,
default=None,
help="The path to the json file containing the vocalnet data",
)
parser.add_argument(
"--drop-recordings",
type=str2bool,
default=True,
help="Drop recordings. Default: False.",
)
parser.add_argument(
"--subset",
type=str,
default=None,
help="The subset to use from the Huggingface dataset",
)
parser.add_argument(
"--split",
type=str,
default="train",
help="The split to use from the Huggingface dataset",
)
return parser
def remove_short_and_long_utt(c):
# Keep only utterances with duration between 1 second and 50 seconds
#
# Caution: There is a reason to select 50.0 here. Please see
# ../local/display_manifest_statistics.py
#
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
if c.duration < 1.0 or c.duration > 50.0:
# logging.warning(
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
# )
return False
return True
def compute_fbank(args):
in_out_dir = Path(args.out_dir)
in_out_dir.mkdir(parents=True, exist_ok=True)
@ -130,11 +172,14 @@ def compute_fbank(args):
logging.info(f"device: {device}")
dataset = load_dataset(
args.huggingface_dataset_path_or_name, streaming=True, split="train"
args.huggingface_dataset_path_or_name,
args.subset,
streaming=True,
split=args.split,
)
num_shards = dataset.num_shards
num_digits = 5
for i in range(num_shards):
# NOTE: starts from shard 252, presumably to resume a partially completed run.
for i in range(252, num_shards):
shard = dataset.shard(num_shards, i)
# shard = shard.take(10) # for testing
logging.info(
@ -147,6 +192,64 @@ def compute_fbank(args):
shard, audio_key=args.audio_key, text_key=args.text_key
)
cut_set = cut_set.filter(remove_short_and_long_utt)
if args.resample_to_16kHz:
cut_set = cut_set.resample(16000)
if args.speed_perturb:
cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
logging.info("Computing features")
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
storage_path=f"{in_out_dir}/feats_{idx}_{args.subset}",
num_workers=num_workers,
batch_duration=batch_duration,
storage_type=LilcomChunkyWriter,
overwrite=True,
)
# cut_set = cut_set.trim_to_supervisions(
# keep_overlapping=False, min_duration=None
# )
cuts_path = f"{in_out_dir}/cuts_{args.prefix}.{idx}.{args.subset}.jsonl.gz"
logging.info(f"Saving to {cuts_path}")
# see https://github.com/lhotse-speech/lhotse/issues/1125
if args.drop_recordings:
cut_set.drop_recordings().to_file(cuts_path)
else:
cut_set.to_file(cuts_path)
def compute_fbank_vocalnet(args):
in_out_dir = Path(args.out_dir)
in_out_dir.mkdir(parents=True, exist_ok=True)
# number of workers in dataloader
num_workers = 4
# number of seconds in a batch
batch_duration = 10
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
if args.whisper_fbank:
extractor = WhisperFbank(
WhisperFbankConfig(num_filters=args.num_mel_bins, device=device)
)
else:
raise NotImplementedError("Only WhisperFbank is implemented.")
logging.info(f"device: {device}")
num_shards = 50
num_digits = 5
for i in range(num_shards):
logging.info(f"Processing shard {i}")
idx = f"{i}".zfill(num_digits)
cut_set = CutSet(
LazyCustomDatasetIterator(
json_file_path=args.json_file_path, shard_id=i, num_shards=num_shards
)
)
cut_set = cut_set.trim_to_supervisions(
keep_overlapping=False, min_duration=None
)
@ -168,7 +271,7 @@ def compute_fbank(args):
cuts_path = f"{in_out_dir}/cuts_{args.prefix}.{idx}.jsonl.gz"
logging.info(f"Saving to {cuts_path}")
# see https://github.com/lhotse-speech/lhotse/issues/1125
cut_set.drop_recordings().to_file(cuts_path)
cut_set.to_file(cuts_path)
def main():
@ -178,8 +281,10 @@ def main():
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
compute_fbank(args)
if args.json_file_path is not None:
compute_fbank_vocalnet(args)
else:
compute_fbank(args)
if __name__ == "__main__":


@ -0,0 +1,99 @@
# https://huggingface.co/datasets/VocalNet/UltraChat-vocalnet/blob/main/UltraChat.json
# https://huggingface.co/datasets/VocalNet/VoiceAssistant-430K-vocalnet/blob/main/VoiceAssistant-430K.json
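# A minimal sketch of the entry layout this file assumes for those json manifests
# (field names are taken from the iterator below; the concrete values are illustrative only):
#
# [
#   {
#     "id": "voiceassistant_000001",
#     "speech": "audio/voiceassistant_000001.wav",   # audio path, relative to the json's parent-of-parent dir
#     "units": "units/voiceassistant_000001.npy",     # pickled dict with "speech_token" and "speech_token_len"
#     "conversations": [
#       {"from": "human", "value": "<question text>"},
#       {"from": "gpt", "value": "<answer text, used as the supervision>"}
#     ]
#   },
#   ...
# ]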
import json
import os
import numpy as np
from lhotse import CutSet
from lhotse.audio import Recording
from lhotse.supervision import SupervisionSegment
class LazyCustomDatasetIterator:
"""
Thin wrapper on top of HF datasets objects that allows to interact with them through a Lhotse CutSet.
It can be initialized with an existing HF dataset, or args/kwargs passed on to ``datasets.load_dataset()``.
Use ``audio_key``, ``text_key``, ``lang_key`` and ``gender_key`` options to indicate which keys in dict examples
returned from HF Dataset should be looked up for audio, transcript, language, and gender respectively.
The remaining keys in HF dataset examples will be stored inside ``cut.custom`` dictionary.
Example with existing HF dataset::
>>> import datasets
... dataset = datasets.load_dataset("mozilla-foundation/common_voice_11_0", "hi", split="test")
... dataset = dataset.map(some_transform)
... cuts_it = LazyHFDatasetIterator(dataset)
... for cut in cuts_it:
... pass
Example providing HF dataset init args/kwargs::
>>> import datasets
... cuts_it = LazyHFDatasetIterator("mozilla-foundation/common_voice_11_0", "hi", split="test")
... for cut in cuts_it:
... pass
"""
def __init__(self, json_file_path: str, shard_id: int = 0, num_shards: int = 100):
self.json_file_path = json_file_path
self.shard_id = shard_id
self.num_shards = num_shards
def __iter__(self):
with open(self.json_file_path, "r", encoding="utf-8") as f:
list_data_dict = json.load(f)
list_data_dict = list_data_dict[self.shard_id :: self.num_shards]
# The "speech" and "units" paths are relative to the parent of the json file's directory.
json_file_parent_of_parent_dir = os.path.dirname(
os.path.dirname(self.json_file_path)
)
for item in list_data_dict:
custom_data = item.copy()
units_path = os.path.join(
json_file_parent_of_parent_dir, custom_data["units"]
)
speech_token_dict = np.load(units_path, allow_pickle=True).item()
speech_token = speech_token_dict["speech_token"].squeeze(0).tolist()
speech_token_len = speech_token_dict["speech_token_len"]
assert len(speech_token) == speech_token_len
custom_data["speech_token"] = speech_token
audio_path = custom_data.pop("speech", None)
audio_path = os.path.join(json_file_parent_of_parent_dir, audio_path)
item_id = item.get("id")
recording = Recording.from_file(path=audio_path, recording_id=item_id)
conversations = item.get("conversations")
assert isinstance(conversations, list) and len(conversations) == 2
# Use the assistant ("gpt") turn of the two-turn conversation as the supervision text.
gpt_text = None
for conv in conversations:
if isinstance(conv, dict) and conv.get("from") == "gpt":
gpt_text = conv.get("value")
break
assert gpt_text is not None
supervision = SupervisionSegment(
id=item_id,
recording_id=recording.id,
start=0.0, # Assuming the supervision covers the entire recording
duration=recording.duration,
text=gpt_text,
)
cut = recording.to_cut()
# cut.id will be the same as recording.id
cut.supervisions = [supervision]
# custom_data contains the original item's fields, minus "speech".
# So, "id", "conversations", "units", etc., are preserved here.
custom_data.pop("conversations")
custom_data.pop("units")
cut.custom = custom_data
yield cut
if __name__ == "__main__":
json_file_path = (
"/workspace/slam/VoiceAssistant-430K-vocalnet/VoiceAssistant-430K.json"
)
cut_set = CutSet(LazyCustomDatasetIterator(json_file_path=json_file_path))
for cut in cut_set:
print(cut)
input()


@ -120,3 +120,56 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
--asr-model-dir local/sherpa-onnx-paraformer-zh-2023-09-14 \
--use-lora True --token2wav-path /workspace/CosyVoice-300M-SFT --share
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "stage 1: Compute fbank feature from huggingface"
# CUDA_VISIBLE_DEVICES=0 python3 local/compute_whisper_fbank.py \
# --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
# --out-dir data/fbank_voice_assistant \
# --huggingface-dataset-path-or-name worstchan/VoiceAssistant-400K-SLAM-Omni \
# --audio-key question_audio --text-key answer \
# --prefix voice_assistant
CUDA_VISIBLE_DEVICES=0 python3 local/compute_whisper_fbank.py \
--num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
--out-dir data/fbank_voice_assistant_cosy2 \
--json-file-path /workspace/slam/VoiceAssistant-430K-vocalnet/VoiceAssistant-430K.json \
--prefix voice_assistant
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
log "stage 7: Compute fbank feature from huggingface"
# CUDA_VISIBLE_DEVICES=1 python3 local/compute_whisper_fbank.py \
# --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
# --out-dir data/fbank_ultrachat \
# --huggingface-dataset-path-or-name worstchan/UltraChat-300K-SLAM-Omni \
# --audio-key question_audio --text-key answer \
# --prefix ultrachat
CUDA_VISIBLE_DEVICES=1 python3 local/compute_whisper_fbank.py \
--num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
--out-dir data/fbank_ultrachat_cosy2 \
--json-file-path /workspace/slam/UltraChat-vocalnet/UltraChat.json \
--prefix ultrachat
fi
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
log "stage 8: Compute fbank feature from huggingface"
CUDA_VISIBLE_DEVICES=1 python3 local/compute_whisper_fbank.py \
--num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
--out-dir data/fbank_gigaspeech \
--huggingface-dataset-path-or-name speechcolab/gigaspeech \
--subset test --split test \
--audio-key audio --text-key text \
--prefix gigaspeech
fi
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
log "stage 9: Compute fbank feature from huggingface"
CUDA_VISIBLE_DEVICES=0 python3 local/compute_whisper_fbank.py \
--num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb True \
--out-dir data/fbank_gigaspeech \
--huggingface-dataset-path-or-name speechcolab/gigaspeech \
--subset xl --split train \
--audio-key audio --text-key text \
--prefix gigaspeech
fi
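
A minimal sanity check for the manifests written by the stages above could look like the sketch below; the path follows the stage 6 defaults and the 00000 shard index, both illustrative, and the speech_token custom field only exists for the VocalNet json stages (6 and 7):

from lhotse import load_manifest_lazy

# Illustrative path: --out-dir data/fbank_voice_assistant_cosy2, --prefix voice_assistant, shard 00000.
cuts = load_manifest_lazy(
    "data/fbank_voice_assistant_cosy2/cuts_voice_assistant.00000.jsonl.gz"
)
for cut in cuts:
    assert cut.has_features, cut.id
    # "speech_token" was stored in cut.custom by LazyCustomDatasetIterator.
    print(cut.id, cut.duration, cut.supervisions[0].text, len(cut.custom["speech_token"]))
    break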