diff --git a/egs/speech_llm/SPEECH2SPEECH/local/compute_whisper_fbank.py b/egs/speech_llm/SPEECH2SPEECH/local/compute_whisper_fbank.py index f67324ba3..58d7cf3d6 100755 --- a/egs/speech_llm/SPEECH2SPEECH/local/compute_whisper_fbank.py +++ b/egs/speech_llm/SPEECH2SPEECH/local/compute_whisper_fbank.py @@ -35,6 +35,7 @@ from pathlib import Path import torch from datasets import load_dataset from lhotse import CutSet, LilcomChunkyWriter, WhisperFbank, WhisperFbankConfig +from vocalnet_lhotse_cutset import LazyCustomDatasetIterator from icefall.utils import str2bool @@ -105,9 +106,50 @@ def get_parser(): default="belle", help="""The dataset prefix to use when saving the features""", ) + parser.add_argument( + "--json-file-path", + type=str, + default=None, + help="The path to the json file containing the vocalnet data", + ) + parser.add_argument( + "--drop-recordings", + type=str2bool, + default=True, + help="Drop recordings. Default: False.", + ) + parser.add_argument( + "--subset", + type=str, + default=None, + help="The subset to use from the Huggingface dataset", + ) + parser.add_argument( + "--split", + type=str, + default="train", + help="The split to use from the Huggingface dataset", + ) return parser +def remove_short_and_long_utt(c): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 50.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + return True + + def compute_fbank(args): in_out_dir = Path(args.out_dir) in_out_dir.mkdir(parents=True, exist_ok=True) @@ -130,11 +172,14 @@ def compute_fbank(args): logging.info(f"device: {device}") dataset = load_dataset( - args.huggingface_dataset_path_or_name, streaming=True, split="train" + args.huggingface_dataset_path_or_name, + args.subset, + streaming=True, + split=args.split, ) num_shards = dataset.num_shards num_digits = 5 - for i in range(num_shards): + for i in range(252, num_shards): shard = dataset.shard(num_shards, i) # shard = shard.take(10) # for testing logging.info( @@ -147,6 +192,64 @@ def compute_fbank(args): shard, audio_key=args.audio_key, text_key=args.text_key ) + cut_set = cut_set.filter(remove_short_and_long_utt) + if args.resample_to_16kHz: + cut_set = cut_set.resample(16000) + if args.speed_perturb: + cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + + logging.info("Computing features") + cut_set = cut_set.compute_and_store_features_batch( + extractor=extractor, + storage_path=f"{in_out_dir}/feats_{idx}_{args.subset}", + num_workers=num_workers, + batch_duration=batch_duration, + storage_type=LilcomChunkyWriter, + overwrite=True, + ) + # cut_set = cut_set.trim_to_supervisions( + # keep_overlapping=False, min_duration=None + # ) + cuts_path = f"{in_out_dir}/cuts_{args.prefix}.{idx}.{args.subset}.jsonl.gz" + logging.info(f"Saving to {cuts_path}") + # see https://github.com/lhotse-speech/lhotse/issues/1125 + if args.drop_recordings: + cut_set.drop_recordings().to_file(cuts_path) + else: + cut_set.to_file(cuts_path) + + +def compute_fbank_vocalnet(args): + in_out_dir = Path(args.out_dir) + in_out_dir.mkdir(parents=True, exist_ok=True) + # number of workers in dataloader + num_workers = 4 + + # number of seconds in a batch + batch_duration = 10 + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + if args.whisper_fbank: + extractor = WhisperFbank( + WhisperFbankConfig(num_filters=args.num_mel_bins, device=device) + ) + else: + raise NotImplementedError("Only WhisperFbank is implemented.") + + logging.info(f"device: {device}") + + num_shards = 50 + num_digits = 5 + for i in range(num_shards): + logging.info(f"Processing shard {i}") + idx = f"{i}".zfill(num_digits) + cut_set = CutSet( + LazyCustomDatasetIterator( + json_file_path=args.json_file_path, shard_id=i, num_shards=num_shards + ) + ) cut_set = cut_set.trim_to_supervisions( keep_overlapping=False, min_duration=None ) @@ -168,7 +271,7 @@ def compute_fbank(args): cuts_path = f"{in_out_dir}/cuts_{args.prefix}.{idx}.jsonl.gz" logging.info(f"Saving to {cuts_path}") # see https://github.com/lhotse-speech/lhotse/issues/1125 - cut_set.drop_recordings().to_file(cuts_path) + cut_set.to_file(cuts_path) def main(): @@ -178,8 +281,10 @@ def main(): parser = get_parser() args = parser.parse_args() logging.info(vars(args)) - - compute_fbank(args) + if args.json_file_path is not None: + compute_fbank_vocalnet(args) + else: + compute_fbank(args) if __name__ == "__main__": diff --git a/egs/speech_llm/SPEECH2SPEECH/local/vocalnet_lhotse_cutset.py b/egs/speech_llm/SPEECH2SPEECH/local/vocalnet_lhotse_cutset.py new file mode 100644 index 000000000..f7519fbfe --- /dev/null +++ b/egs/speech_llm/SPEECH2SPEECH/local/vocalnet_lhotse_cutset.py @@ -0,0 +1,99 @@ +# https://huggingface.co/datasets/VocalNet/UltraChat-vocalnet/blob/main/UltraChat.json +# https://huggingface.co/datasets/VocalNet/VoiceAssistant-430K-vocalnet/blob/main/VoiceAssistant-430K.json +import json +import os + +import numpy as np +from lhotse import CutSet +from lhotse.audio import Recording +from lhotse.supervision import SupervisionSegment + + +class LazyCustomDatasetIterator: + """ + Thin wrapper on top of HF datasets objects that allows to interact with them through a Lhotse CutSet. + It can be initialized with an existing HF dataset, or args/kwargs passed on to ``datasets.load_dataset()``. + Use ``audio_key``, ``text_key``, ``lang_key`` and ``gender_key`` options to indicate which keys in dict examples + returned from HF Dataset should be looked up for audio, transcript, language, and gender respectively. + The remaining keys in HF dataset examples will be stored inside ``cut.custom`` dictionary. + Example with existing HF dataset:: + >>> import datasets + ... dataset = datasets.load_dataset("mozilla-foundation/common_voice_11_0", "hi", split="test") + ... dataset = dataset.map(some_transform) + ... cuts_it = LazyHFDatasetIterator(dataset) + ... for cut in cuts_it: + ... pass + Example providing HF dataset init args/kwargs:: + >>> import datasets + ... cuts_it = LazyHFDatasetIterator("mozilla-foundation/common_voice_11_0", "hi", split="test") + ... for cut in cuts_it: + ... pass + """ + + def __init__(self, json_file_path: str, shard_id: int = 0, num_shards: int = 100): + self.json_file_path = json_file_path + self.shard_id = shard_id + self.num_shards = num_shards + + def __iter__(self): + + with open(self.json_file_path, "r", encoding="utf-8") as f: + list_data_dict = json.load(f) + list_data_dict = list_data_dict[self.shard_id :: self.num_shards] + for item in list_data_dict: + custom_data = item.copy() + json_file_parent_of_parent_dir = os.path.dirname( + os.path.dirname(self.json_file_path) + ) + units_path = os.path.join( + json_file_parent_of_parent_dir, custom_data["units"] + ) + speech_token_dict = np.load(units_path, allow_pickle=True).item() + speech_token = speech_token_dict["speech_token"].squeeze(0).tolist() + speech_token_len = speech_token_dict["speech_token_len"] + + assert len(speech_token) == speech_token_len + custom_data["speech_token"] = speech_token + audio_path = custom_data.pop("speech", None) + audio_path = os.path.join(json_file_parent_of_parent_dir, audio_path) + item_id = item.get("id") + recording = Recording.from_file(path=audio_path, recording_id=item_id) + + conversations = item.get("conversations") + assert isinstance(conversations, list) and len(conversations) == 2 + for conv in conversations: + if isinstance(conv, dict) and conv.get("from") == "gpt": + gpt_text = conv.get("value") + break + assert gpt_text is not None + + supervision = SupervisionSegment( + id=item_id, + recording_id=recording.id, + start=0.0, # Assuming the supervision covers the entire recording + duration=recording.duration, + text=gpt_text, + ) + + cut = recording.to_cut() + # cut.id will be the same as recording.id + + cut.supervisions = [supervision] + # custom_data contains the original item's fields, minus "speech". + # So, "id", "conversations", "units", etc., are preserved here. + custom_data.pop("conversations") + custom_data.pop("units") + cut.custom = custom_data + + yield cut + + +if __name__ == "__main__": + json_file_path = ( + "/workspace/slam/VoiceAssistant-430K-vocalnet/VoiceAssistant-430K.json" + ) + cut_set = CutSet(LazyCustomDatasetIterator(json_file_path=json_file_path)) + + for cut in cut_set: + print(cut) + input() diff --git a/egs/speech_llm/SPEECH2SPEECH/prepare.sh b/egs/speech_llm/SPEECH2SPEECH/prepare.sh index 42c9b4eaa..cff7a45fa 100644 --- a/egs/speech_llm/SPEECH2SPEECH/prepare.sh +++ b/egs/speech_llm/SPEECH2SPEECH/prepare.sh @@ -120,3 +120,56 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then --asr-model-dir local/sherpa-onnx-paraformer-zh-2023-09-14 \ --use-lora True --token2wav-path /workspace/CosyVoice-300M-SFT --share fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "stage 1: Compute fbank feature from huggingface" + # CUDA_VISIBLE_DEVICES=0 python3 local/compute_whisper_fbank.py \ + # --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \ + # --out-dir data/fbank_voice_assistant \ + # --huggingface-dataset-path-or-name worstchan/VoiceAssistant-400K-SLAM-Omni \ + # --audio-key question_audio --text-key answer \ + # --prefix voice_assistant + CUDA_VISIBLE_DEVICES=0 python3 local/compute_whisper_fbank.py \ + --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \ + --out-dir data/fbank_voice_assistant_cosy2 \ + --json-file-path /workspace/slam/VoiceAssistant-430K-vocalnet/VoiceAssistant-430K.json \ + --prefix voice_assistant +fi + +if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then + log "stage 7: Compute fbank feature from huggingface" + # CUDA_VISIBLE_DEVICES=1 python3 local/compute_whisper_fbank.py \ + # --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \ + # --out-dir data/fbank_ultrachat \ + # --huggingface-dataset-path-or-name worstchan/UltraChat-300K-SLAM-Omni \ + # --audio-key question_audio --text-key answer \ + # --prefix ultrachat + CUDA_VISIBLE_DEVICES=1 python3 local/compute_whisper_fbank.py \ + --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \ + --out-dir data/fbank_ultrachat_cosy2 \ + --json-file-path /workspace/slam/UltraChat-vocalnet/UltraChat.json \ + --prefix ultrachat +fi + +if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then + log "stage 8: Compute fbank feature from huggingface" + + CUDA_VISIBLE_DEVICES=1 python3 local/compute_whisper_fbank.py \ + --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \ + --out-dir data/fbank_gigaspeech \ + --huggingface-dataset-path-or-name speechcolab/gigaspeech \ + --subset test --split test \ + --audio-key audio --text-key text \ + --prefix gigaspeech +fi + +if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then + log "stage 9: Compute fbank feature from huggingface" + CUDA_VISIBLE_DEVICES=0 python3 local/compute_whisper_fbank.py \ + --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb True \ + --out-dir data/fbank_gigaspeech \ + --huggingface-dataset-path-or-name speechcolab/gigaspeech \ + --subset xl --split train \ + --audio-key audio --text-key text \ + --prefix gigaspeech +fi