Repository: https://github.com/k2-fsa/icefall.git (mirror)

commit e6e1f3fa4f
parent dd858f0cd1

    add tts stage

@@ -24,7 +24,7 @@ from pathlib import Path
 from typing import Any, Dict, Optional
 
 import torch
-from datasets import interleave_datasets, load_dataset
+from datasets import interleave_datasets, load_dataset, Audio, Features, Value, Sequence
 from lhotse import (
     CutSet,
     WhisperFbank,
@@ -49,7 +49,9 @@ from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
 from lhotse.utils import fix_random_seed
 from torch.utils.data import DataLoader
 from utils import get_local_rank, str2bool
 
+import io
+import wave
+import random
+
 class _SeedWorkers:
     def __init__(self, seed: int):
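
The three new imports (io, wave, random) are there to synthesize structurally valid WAV bytes in memory, which the train_cuts_emilia_en hunk below uses to fill the wav column. A minimal, self-contained sketch of that pattern with a read-back check; the helper name make_placeholder_wav is illustrative and not part of the commit:

    import io
    import random
    import wave

    def make_placeholder_wav(duration_s=1, sampling_rate=16000, num_channels=1, sample_width=2):
        """Return the bytes of a WAV file filled with random 16-bit PCM noise."""
        num_frames = int(duration_s * sampling_rate)
        pcm_data = bytes(
            random.randint(0, 255)
            for _ in range(num_frames * num_channels * sample_width)
        )
        buffer = io.BytesIO()
        with wave.open(buffer, "wb") as wf:
            wf.setnchannels(num_channels)
            wf.setsampwidth(sample_width)
            wf.setframerate(sampling_rate)
            wf.writeframes(pcm_data)
        return buffer.getvalue()

    # Round-trip check: the in-memory bytes parse as a valid WAV file.
    wav_bytes = make_placeholder_wav()
    with wave.open(io.BytesIO(wav_bytes), "rb") as wf:
        assert wf.getframerate() == 16000
        assert wf.getnframes() == 16000  # 1 second at 16 kHz, mono
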
@@ -751,4 +753,61 @@ class AsrDataModule:
                     423_000,
                     28539,
                 ],
             )
         )
+
+    @lru_cache()
+    def train_cuts_emilia_en(self) -> CutSet:
+        logging.info("About to get train cuts")
+        data_path = "/lustre/fsw/general_sa/yuekaiz/s2s" + "/emilia_en"
+        # if self.args.huggingface_dataset_path_or_name is not None:
+        #     data_path = self.args.huggingface_dataset_path_or_name + "/emilia_en"
+        # else:
+        #     data_path = "yuekai/emilia_en"
+
+        emilia_en_data = load_dataset(data_path, split="train", streaming=True)
+
+        def update_wav_path(example):
+            sampling_rate = 16000  # From current_features
+            duration = 1  # seconds, arbitrary duration for random audio
+            num_channels = 1  # mono
+            sample_width = 2  # 2 bytes = 16-bit audio
+
+            num_frames = int(duration * sampling_rate)
+
+            # Generate random bytes for the PCM data part.
+            # This will be random noise, but structurally valid for a WAV file.
+            pcm_data = bytes(
+                [random.randint(0, 255) for _ in range(num_frames * num_channels * sample_width)]
+            )
+
+            # Create a WAV file in memory.
+            audio_buffer = io.BytesIO()
+            with wave.open(audio_buffer, "wb") as wf:
+                wf.setnchannels(num_channels)
+                wf.setsampwidth(sample_width)
+                wf.setframerate(sampling_rate)
+                wf.writeframes(pcm_data)  # writeframes expects bytes
+
+            example["wav"] = audio_buffer.getvalue()
+            return example
+
+        emilia_en_data = emilia_en_data.map(update_wav_path)
+        current_features = Features(
+            {
+                "id": Value("string"),
+                "text": Value("string"),
+                "duration": Value("float"),
+                "language": Value("string"),
+                "dnsmos": Value("float"),
+                "speech_token": Sequence(Value("int32")),
+                "wav": Audio(sampling_rate=16000),
+            }
+        )
+        emilia_en_data = emilia_en_data.rename_column("code", "speech_token")
+        emilia_en_data = emilia_en_data.cast(current_features)
+
+        emilia_en_train_cuts = CutSet.from_huggingface_dataset(
+            emilia_en_data,  # Adjusted from instruct_s2s_train
+            audio_key="wav",
+            text_key="text",
+        )
+        return emilia_en_train_cuts
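
The method above stores raw WAV bytes in the wav column, renames code to speech_token, and casts the columns so that the Audio feature decodes the bytes lazily on access. A minimal sketch of that Features/Audio cast on a tiny in-memory datasets.Dataset; it assumes an audio decoding backend such as soundfile is installed and that wav_bytes is any valid WAV payload (for example the placeholder bytes above):

    from datasets import Audio, Dataset, Features, Sequence, Value

    features = Features(
        {
            "text": Value("string"),
            "speech_token": Sequence(Value("int32")),
            "wav": Audio(sampling_rate=16000),
        }
    )

    ds = Dataset.from_dict(
        {
            "text": ["hello"],
            "speech_token": [[1, 2, 3]],
            "wav": [wav_bytes],  # raw bytes are stored as {"bytes": ..., "path": None}
        }
    ).cast(features)

    # Decoding happens on access; with datasets 2.x/3.x this yields a dict
    # holding the numpy "array" and its "sampling_rate".
    sample = ds[0]["wav"]
    print(sample["sampling_rate"], sample["array"].shape)
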
@@ -1032,6 +1032,9 @@ def run(rank, world_size, args):
     elif params.dataset == "gigaspeech":
         train_cuts = data_module.train_cuts_gigaspeech()
         valid_cuts = data_module.valid_cuts_ultravox()
+    elif params.dataset == "emilia_en":
+        train_cuts = data_module.train_cuts_emilia_en()
+        valid_cuts = data_module.valid_cuts_emilia_en()
     else:
         raise ValueError(f"Unknown dataset: {params.dataset}")
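
For reference, a self-contained sketch of the dataset dispatch that the hunk above extends; DataModuleStub and its return values are placeholders for illustration only, not the recipe's real classes:

    class DataModuleStub:
        def train_cuts_gigaspeech(self):
            return "gigaspeech-train-cuts"

        def valid_cuts_ultravox(self):
            return "ultravox-valid-cuts"

        def train_cuts_emilia_en(self):
            return "emilia-en-train-cuts"

        def valid_cuts_emilia_en(self):
            return "emilia-en-valid-cuts"

    def select_cuts(dataset: str, data_module):
        if dataset == "gigaspeech":
            return data_module.train_cuts_gigaspeech(), data_module.valid_cuts_ultravox()
        elif dataset == "emilia_en":
            return data_module.train_cuts_emilia_en(), data_module.valid_cuts_emilia_en()
        else:
            raise ValueError(f"Unknown dataset: {dataset}")

    print(select_cuts("emilia_en", DataModuleStub()))
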