From e6e1f3fa4f70c7c7299d2e3074134085cbb9d78e Mon Sep 17 00:00:00 2001
From: root
Date: Fri, 23 May 2025 01:53:05 -0700
Subject: [PATCH] add tts stage

---
 .../SPEECH2SPEECH/qwen_omni/data_module.py    | 65 ++++++++++++++++++-
 .../SPEECH2SPEECH/qwen_omni/train.py          |  3 +
 2 files changed, 65 insertions(+), 3 deletions(-)

diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py
index a52f84b0c..457c3e107 100644
--- a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py
@@ -24,7 +24,7 @@ from pathlib import Path
 from typing import Any, Dict, Optional
 
 import torch
-from datasets import interleave_datasets, load_dataset
+from datasets import interleave_datasets, load_dataset, Audio, Features, Value, Sequence
 from lhotse import (
     CutSet,
     WhisperFbank,
@@ -49,7 +49,9 @@ from lhotse.dataset.input_strategies import (  # noqa F401 For AudioSamples
 from lhotse.utils import fix_random_seed
 from torch.utils.data import DataLoader
 from utils import get_local_rank, str2bool
-
+import io
+import wave
+import random
 
 class _SeedWorkers:
     def __init__(self, seed: int):
@@ -751,4 +753,61 @@
                 423_000,
                 28539,
             ],
-        )
\ No newline at end of file
+        )
+
+    @lru_cache()
+    def train_cuts_emilia_en(self) -> CutSet:
+        logging.info("About to get train cuts")
+        data_path = "/lustre/fsw/general_sa/yuekaiz/s2s" + "/emilia_en"
+        # if self.args.huggingface_dataset_path_or_name is not None:
+        #     data_path = self.args.huggingface_dataset_path_or_name + "/emilia_en"
+        # else:
+        #     data_path = "yuekai/emilia_en"
+
+        emilia_en_data = load_dataset(
+            data_path, split="train", streaming=True
+        )
+
+        def update_wav_path(example):
+            sampling_rate = 16000  # From current_features
+            duration = 1  # seconds, arbitrary duration for random audio
+            num_channels = 1  # mono
+            sample_width = 2  # 2 bytes = 16-bit audio
+
+            num_frames = int(duration * sampling_rate)
+
+            # Generate random bytes for the PCM data part
+            # This will be random noise, but structurally valid for a WAV file
+            pcm_data = bytes([random.randint(0, 255) for _ in range(num_frames * num_channels * sample_width)])
+
+            # Create a WAV file in memory
+            audio_buffer = io.BytesIO()
+            with wave.open(audio_buffer, 'wb') as wf:
+                wf.setnchannels(num_channels)
+                wf.setsampwidth(sample_width)
+                wf.setframerate(sampling_rate)
+                wf.writeframes(pcm_data)  # writeframes expects bytes
+
+            example["wav"] = audio_buffer.getvalue()
+            return example
+
+        emilia_en_data = emilia_en_data.map(update_wav_path)
+        current_features = Features({
+            'id': Value('string'),
+            'text': Value('string'),
+            'duration': Value('float'),
+            'language': Value('string'),
+            'dnsmos': Value('float'),
+            'speech_token': Sequence(Value('int32')),
+            'wav': Audio(sampling_rate=16000)
+
+        })
+        emilia_en_data = emilia_en_data.rename_column("code", "speech_token")
+        emilia_en_data = emilia_en_data.cast(current_features)
+
+        emilia_en_train_cuts = CutSet.from_huggingface_dataset(
+            emilia_en_data,  # Adjusted from instruct_s2s_train
+            audio_key="wav",
+            text_key="text",
+        )
+        return emilia_en_train_cuts
\ No newline at end of file
diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py
index 9554d85e4..87b8315f1 100755
--- a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py
@@ -1032,6 +1032,9 @@ def run(rank, world_size, args):
     elif params.dataset == "gigaspeech":
         train_cuts = data_module.train_cuts_gigaspeech()
         valid_cuts = data_module.valid_cuts_ultravox()
+    elif params.dataset == "emilia_en":
+        train_cuts = data_module.train_cuts_emilia_en()
+        valid_cuts = data_module.valid_cuts_emilia_en()
     else:
         raise ValueError(f"Unknown dataset: {params.dataset}")
 