mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
add tts stage
This commit is contained in:
parent
dd858f0cd1
commit
e6e1f3fa4f
@ -24,7 +24,7 @@ from pathlib import Path
|
|||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from datasets import interleave_datasets, load_dataset
|
from datasets import interleave_datasets, load_dataset, Audio, Features, Value, Sequence
|
||||||
from lhotse import (
|
from lhotse import (
|
||||||
CutSet,
|
CutSet,
|
||||||
WhisperFbank,
|
WhisperFbank,
|
||||||
@ -49,7 +49,9 @@ from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
|||||||
from lhotse.utils import fix_random_seed
|
from lhotse.utils import fix_random_seed
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from utils import get_local_rank, str2bool
|
from utils import get_local_rank, str2bool
|
||||||
|
import io
|
||||||
|
import wave
|
||||||
|
import random
|
||||||
|
|
||||||
class _SeedWorkers:
|
class _SeedWorkers:
|
||||||
def __init__(self, seed: int):
|
def __init__(self, seed: int):
|
||||||
@ -752,3 +754,60 @@ class AsrDataModule:
|
|||||||
28539,
|
28539,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@lru_cache()
def train_cuts_emilia_en(self) -> CutSet:
    """Build the Emilia-EN training cuts from a streaming HuggingFace dataset.

    The dataset is opened with ``streaming=True``. Every example gets a
    placeholder ``wav`` field (one second of random 16-bit mono PCM at
    16 kHz inside a WAV container), the ``code`` column is renamed to
    ``speech_token`` and the dataset is cast to an explicit feature schema
    before it is handed to lhotse.

    Returns:
        The ``CutSet`` produced by ``CutSet.from_huggingface_dataset`` with
        ``wav`` as the audio key and ``text`` as the text key.
    """
    logging.info("About to get train cuts")

    # TODO: make this configurable, e.g. derive it from
    # ``self.args.huggingface_dataset_path_or_name`` (appending "/emilia_en")
    # and fall back to the hub dataset "yuekai/emilia_en" when it is unset.
    data_path = "/lustre/fsw/general_sa/yuekaiz/s2s/emilia_en"

    # Shared by the placeholder WAV writer and the ``Audio`` feature below so
    # the two can never disagree.
    sampling_rate = 16000

    def _attach_placeholder_wav(example):
        """Store an in-memory WAV file of random noise under ``example["wav"]``."""
        duration = 1  # seconds; arbitrary length for the placeholder audio
        num_channels = 1  # mono
        sample_width = 2  # bytes per sample, i.e. 16-bit PCM

        num_frames = int(duration * sampling_rate)
        num_bytes = num_frames * num_channels * sample_width

        # Random noise, but structurally valid PCM. A single getrandbits call
        # replaces one Python-level randint call per byte.
        pcm_data = random.getrandbits(8 * num_bytes).to_bytes(num_bytes, "little")

        # Assemble the WAV container in memory; the header is finalised when
        # the ``wave`` writer is closed at the end of the ``with`` block.
        audio_buffer = io.BytesIO()
        with wave.open(audio_buffer, "wb") as wf:
            wf.setnchannels(num_channels)
            wf.setsampwidth(sample_width)
            wf.setframerate(sampling_rate)
            wf.writeframes(pcm_data)

        example["wav"] = audio_buffer.getvalue()
        return example

    emilia_en_data = load_dataset(data_path, split="train", streaming=True)
    emilia_en_data = emilia_en_data.map(_attach_placeholder_wav)

    # Explicit schema: ``wav`` must be declared as an ``Audio`` feature and
    # the renamed ``speech_token`` column as a sequence of int32.
    current_features = Features(
        {
            "id": Value("string"),
            "text": Value("string"),
            "duration": Value("float"),
            "language": Value("string"),
            "dnsmos": Value("float"),
            "speech_token": Sequence(Value("int32")),
            "wav": Audio(sampling_rate=sampling_rate),
        }
    )
    emilia_en_data = emilia_en_data.rename_column("code", "speech_token")
    emilia_en_data = emilia_en_data.cast(current_features)

    emilia_en_train_cuts = CutSet.from_huggingface_dataset(
        emilia_en_data,
        audio_key="wav",
        text_key="text",
    )
    return emilia_en_train_cuts
|
@ -1032,6 +1032,9 @@ def run(rank, world_size, args):
|
|||||||
elif params.dataset == "gigaspeech":
|
elif params.dataset == "gigaspeech":
|
||||||
train_cuts = data_module.train_cuts_gigaspeech()
|
train_cuts = data_module.train_cuts_gigaspeech()
|
||||||
valid_cuts = data_module.valid_cuts_ultravox()
|
valid_cuts = data_module.valid_cuts_ultravox()
|
||||||
|
elif params.dataset == "emilia_en":
|
||||||
|
train_cuts = data_module.train_cuts_emilia_en()
|
||||||
|
valid_cuts = data_module.valid_cuts_emilia_en()
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown dataset: {params.dataset}")
|
raise ValueError(f"Unknown dataset: {params.dataset}")
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user