add tts stage

This commit is contained in:
root 2025-05-23 01:53:05 -07:00
parent dd858f0cd1
commit e6e1f3fa4f
2 changed files with 65 additions and 3 deletions

View File

@ -24,7 +24,7 @@ from pathlib import Path
from typing import Any, Dict, Optional
import torch
from datasets import interleave_datasets, load_dataset
from datasets import interleave_datasets, load_dataset, Audio, Features, Value, Sequence
from lhotse import (
CutSet,
WhisperFbank,
@ -49,7 +49,9 @@ from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from utils import get_local_rank, str2bool
import io
import wave
import random
class _SeedWorkers:
def __init__(self, seed: int):
@ -752,3 +754,60 @@ class AsrDataModule:
28539,
],
)
@lru_cache()
def train_cuts_emilia_en(self) -> CutSet:
logging.info("About to get train cuts")
data_path = "/lustre/fsw/general_sa/yuekaiz/s2s" + "/emilia_en"
# if self.args.huggingface_dataset_path_or_name is not None:
# data_path = self.args.huggingface_dataset_path_or_name + "/emilia_en"
# else:
# data_path = "yuekai/emilia_en"
emilia_en_data = load_dataset(
data_path, split="train", streaming=True
)
def update_wav_path(example):
sampling_rate = 16000 # From current_features
duration = 1 # seconds, arbitrary duration for random audio
num_channels = 1 # mono
sample_width = 2 # 2 bytes = 16-bit audio
num_frames = int(duration * sampling_rate)
# Generate random bytes for the PCM data part
# This will be random noise, but structurally valid for a WAV file
pcm_data = bytes([random.randint(0, 255) for _ in range(num_frames * num_channels * sample_width)])
# Create a WAV file in memory
audio_buffer = io.BytesIO()
with wave.open(audio_buffer, 'wb') as wf:
wf.setnchannels(num_channels)
wf.setsampwidth(sample_width)
wf.setframerate(sampling_rate)
wf.writeframes(pcm_data) # writeframes expects bytes
example["wav"] = audio_buffer.getvalue()
return example
emilia_en_data = emilia_en_data.map(update_wav_path)
current_features = Features({
'id': Value('string'),
'text': Value('string'),
'duration': Value('float'),
'language': Value('string'),
'dnsmos': Value('float'),
'speech_token': Sequence(Value('int32')),
'wav': Audio(sampling_rate=16000)
})
emilia_en_data = emilia_en_data.rename_column("code", "speech_token")
emilia_en_data = emilia_en_data.cast(current_features)
emilia_en_train_cuts = CutSet.from_huggingface_dataset(
emilia_en_data, # Adjusted from instruct_s2s_train
audio_key="wav",
text_key="text",
)
return emilia_en_train_cuts

View File

@ -1032,6 +1032,9 @@ def run(rank, world_size, args):
elif params.dataset == "gigaspeech":
train_cuts = data_module.train_cuts_gigaspeech()
valid_cuts = data_module.valid_cuts_ultravox()
elif params.dataset == "emilia_en":
train_cuts = data_module.train_cuts_emilia_en()
valid_cuts = data_module.valid_cuts_emilia_en()
else:
raise ValueError(f"Unknown dataset: {params.dataset}")