mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Merge 1653b76deb5bffec958d17cf5440ace4f776732f into 34fc1fdf0d8ff520e2bb18267d046ca207c78ef9
This commit is contained in:
commit
5c73a91e72
76
egs/emilia/TTS/README.md
Normal file
76
egs/emilia/TTS/README.md
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
# Results
|
||||||
|
| LLM Model | Flow matching Model | Seed-TTS test_zh CER | Comment |
|
||||||
|
|---------------------------------------|----------|-----------|--------|
|
||||||
|
| pretrained cosyvoice2 llm | pretrained cosyvoice2 unet | 1.45% | See [paper](https://arxiv.org/abs/2412.10117)|
|
||||||
|
| pretrained cosyvoice2 llm | f5-tts-small (wenetspeech4tts) | 1.79% (16 steps) | See [PR](https://github.com/k2-fsa/icefall/pull/1880)|
|
||||||
|
| llasa_cosyvoice2_token llm (Emilia 50k hours ZH) | f5-tts-small (wenetspeech4tts) | 1.81% (16 steps) | |
|
||||||
|
|
||||||
|
# Introduction
|
||||||
|
|
||||||
|
[**Emilia**](https://huggingface.co/datasets/amphion/Emilia-Dataset) starts with over 101k
|
||||||
|
hours of speech across six languages, covering a wide range of speaking styles to enable more natural and spontaneous speech generation.
|
||||||
|
|
||||||
|
See https://arxiv.org/pdf/2407.05361.
|
||||||
|
|
||||||
|
# Llasa (cosyvoice2 token)
|
||||||
|
|
||||||
|
./llasa_cosyvoice2_token contains the code for training qwen2.5-0.5b models to predict cosyvoice2 semantic tokens.
|
||||||
|
|
||||||
|
Generated samples and training logs of [Emilia](https://huggingface.co/datasets/amphion/Emilia-Dataset) 50k hours Chinese data can be found [here](https://huggingface.co/yuekai/llasa_cosyvoice2_token_qwen_0.5b/tree/main).
|
||||||
|
|
||||||
|
Preparation:
|
||||||
|
|
||||||
|
```
|
||||||
|
# extract cosyvoice2 semantic tokens
|
||||||
|
bash prepare.sh --stage 3 --stop_stage 4
|
||||||
|
|
||||||
|
# Or you could use the prepared tokens.
|
||||||
|
huggingface-cli download yuekai/emilia_cosyvoice_v2_token --local-dir emilia_cosyvoice_v2_token
|
||||||
|
```
|
||||||
|
|
||||||
|
The training command is given below:
|
||||||
|
|
||||||
|
```
|
||||||
|
# docker: ghcr.io/swivid/f5-tts:main
|
||||||
|
# pip install -r llasa_cosyvoice2_token/requirements.txt
|
||||||
|
|
||||||
|
WANDB_KEY=$your_wandb_key
|
||||||
|
wandb login ${WANDB_KEY}
|
||||||
|
huggingface-cli download Qwen/Qwen2.5-0.5B-Instruct --local-dir Qwen2.5-0.5B-Instruct
|
||||||
|
torchrun --nproc_per_node=8 train.py config.json
|
||||||
|
```
|
||||||
|
|
||||||
|
To inference with Icefall Emilia trained Chinese Llasa_cosyvoice2_token model, we need to use cosyvoice2 token flow matching [model](https://github.com/k2-fsa/icefall/pull/1880):
|
||||||
|
```
|
||||||
|
cd icefall/egs/wenetspeech4tts/TTS
|
||||||
|
huggingface-cli login
|
||||||
|
huggingface-cli download --local-dir ${exp_dir} yuekai/llasa_cosyvoice2_token_qwen_0.5b
|
||||||
|
huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x
|
||||||
|
vocoder=./bigvgan_v2_24khz_100band_256x
|
||||||
|
split=test_zh
|
||||||
|
llm_path=llasa_cosyvoice2_token_qwen_0.5b/checkpoint-800000
|
||||||
|
|
||||||
|
huggingface-cli download --local-dir f5-tts-small-wenetspeech4tts-basic yuekai/f5-tts-semantic-token-small-wenetspeech4tts-basic
|
||||||
|
model_path=f5-tts-small-wenetspeech4tts-basic/epoch-10-avg-5.pt
|
||||||
|
torchrun --nproc_per_node=2 \
|
||||||
|
f5-tts/infer_dist.py \
|
||||||
|
--output_dir $output_dir \
|
||||||
|
--batch_size 1 \
|
||||||
|
--num_workers 2 \
|
||||||
|
--llm-model-name-or-path $llm_path \
|
||||||
|
--flow-matching-model-path $model_path \
|
||||||
|
--decoder-dim 768 --nhead 12 --num-decoder-layers 18 \
|
||||||
|
--use-cosyvoice-semantic-token True \
|
||||||
|
--vocoder-dir $vocoder \
|
||||||
|
--split-name $split -top-k 50 -top-p 0.95 -temperature 0.8 \
|
||||||
|
--tokenizer-dir Qwen/Qwen2.5-0.5B-Instruct
|
||||||
|
# compute cer
|
||||||
|
huggingface-cli download yuekai/seed_tts_eval --local-dir seed_tts_eval --repo-type dataset
|
||||||
|
manifest=./seed_tts_eval/seedtts_testset/zh/meta.lst
|
||||||
|
bash local/compute_wer.sh $output_dir $manifest
|
||||||
|
```
|
||||||
|
|
||||||
|
# Credits
|
||||||
|
- [Llasa](https://arxiv.org/abs/2502.04128)
|
||||||
|
- [CosyVoice](https://github.com/FunAudioLLM/CosyVoice)
|
||||||
|
- [S3Tokenizer](https://github.com/xingchensong/S3Tokenizer/tree/main)
|
27
egs/emilia/TTS/llasa_cosyvoice2_token/config.json
Normal file
27
egs/emilia/TTS/llasa_cosyvoice2_token/config.json
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
{
|
||||||
|
"llm_model_name_or_path": "./Qwen2.5-0.5B-Instruct",
|
||||||
|
"data_path": ["./emilia_cosyvoice_v2_token/cosy_v2_tokens_ZH.jsonl"],
|
||||||
|
"bf16": false,
|
||||||
|
"output_dir": "./exp_zh",
|
||||||
|
"num_train_epochs": 3,
|
||||||
|
"per_device_train_batch_size": 8,
|
||||||
|
"per_device_eval_batch_size": 8,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"evaluation_strategy": "steps",
|
||||||
|
"eval_steps": 1000,
|
||||||
|
"save_strategy": "steps",
|
||||||
|
"save_steps": 5000,
|
||||||
|
"save_total_limit": 100,
|
||||||
|
"learning_rate": 0.00005,
|
||||||
|
"weight_decay": 0.01,
|
||||||
|
"adam_beta2": 0.95,
|
||||||
|
"warmup_ratio": 0.03,
|
||||||
|
"lr_scheduler_type": "cosine",
|
||||||
|
"logging_steps": 100,
|
||||||
|
"report_to": "wandb",
|
||||||
|
"model_max_length": 2048,
|
||||||
|
"gradient_checkpointing": false,
|
||||||
|
"dataloader_num_workers": 4,
|
||||||
|
"dataloader_prefetch_factor": 4,
|
||||||
|
"deepspeed": "ds_config_zero2.json"
|
||||||
|
}
|
47
egs/emilia/TTS/llasa_cosyvoice2_token/ds_config_zero2.json
Normal file
47
egs/emilia/TTS/llasa_cosyvoice2_token/ds_config_zero2.json
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
{
|
||||||
|
"fp16": {
|
||||||
|
"enabled": "auto",
|
||||||
|
"loss_scale": 0,
|
||||||
|
"loss_scale_window": 1000,
|
||||||
|
"initial_scale_power": 64,
|
||||||
|
"hysteresis": 2,
|
||||||
|
"min_loss_scale": 1
|
||||||
|
},
|
||||||
|
"bf16": {
|
||||||
|
"enabled": "auto"
|
||||||
|
},
|
||||||
|
"optimizer": {
|
||||||
|
"type": "AdamW",
|
||||||
|
"params": {
|
||||||
|
"lr": "auto",
|
||||||
|
"betas": "auto",
|
||||||
|
"eps": "auto",
|
||||||
|
"weight_decay": "auto"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
"scheduler": {
|
||||||
|
"type": "WarmupCosineLR",
|
||||||
|
"params": {
|
||||||
|
"total_num_steps": "auto",
|
||||||
|
"warmup_min_ratio": 0.03,
|
||||||
|
"warmup_num_steps": "auto",
|
||||||
|
"cos_min_ratio": 0.1
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
"zero_optimization": {
|
||||||
|
"stage": 2,
|
||||||
|
"overlap_comm": false,
|
||||||
|
"contiguous_gradients": true,
|
||||||
|
"sub_group_size": 1e9,
|
||||||
|
"reduce_bucket_size": "auto"
|
||||||
|
},
|
||||||
|
|
||||||
|
"gradient_accumulation_steps": "auto",
|
||||||
|
"gradient_clipping": 1.0,
|
||||||
|
"steps_per_print": 100,
|
||||||
|
"train_batch_size": "auto",
|
||||||
|
"train_micro_batch_size_per_gpu": "auto",
|
||||||
|
"wall_clock_breakdown": false
|
||||||
|
}
|
8
egs/emilia/TTS/llasa_cosyvoice2_token/requirements.txt
Normal file
8
egs/emilia/TTS/llasa_cosyvoice2_token/requirements.txt
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
torch
|
||||||
|
transformers
|
||||||
|
wandb
|
||||||
|
datasets
|
||||||
|
accelerate>=0.26.0
|
||||||
|
deepspeed
|
||||||
|
flash-attn
|
||||||
|
s3tokenizer
|
184
egs/emilia/TTS/llasa_cosyvoice2_token/train.py
Normal file
184
egs/emilia/TTS/llasa_cosyvoice2_token/train.py
Normal file
@ -0,0 +1,184 @@
|
|||||||
|
# Modified from https://github.com/zhenye234/LLaSA_training/blob/main/train_tts.py
|
||||||
|
""" Example Usage
|
||||||
|
WANDB_KEY=$your_wandb_key
|
||||||
|
wandb login ${WANDB_KEY}
|
||||||
|
huggingface-cli download yuekai/emilia_cosyvoice_v2_token --local-dir emilia_cosyvoice_v2_token
|
||||||
|
huggingface-cli download Qwen/Qwen2.5-0.5B-Instruct --local-dir Qwen2.5-0.5B-Instruct
|
||||||
|
torchrun --nproc_per_node=8 train.py config.json
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import sys
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from functools import partial
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import transformers
|
||||||
|
import wandb
|
||||||
|
from datasets import load_dataset
|
||||||
|
from transformers import (
|
||||||
|
AutoConfig,
|
||||||
|
AutoModelForCausalLM,
|
||||||
|
AutoTokenizer,
|
||||||
|
DataCollatorForSeq2Seq,
|
||||||
|
HfArgumentParser,
|
||||||
|
Trainer,
|
||||||
|
TrainingArguments,
|
||||||
|
)
|
||||||
|
from transformers.trainer_pt_utils import LabelSmoother
|
||||||
|
|
||||||
|
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
|
||||||
|
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelArguments:
|
||||||
|
llm_model_name_or_path: Optional[str] = field(
|
||||||
|
default="meta-llama/Llama-3.2-1B-Instruct"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DataArguments:
|
||||||
|
data_path: List[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Root path(s) to the data. Can be single path or list."},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CustomTrainingArguments(TrainingArguments):
|
||||||
|
optim: str = field(default="adamw_torch_fused")
|
||||||
|
model_max_length: int = field(
|
||||||
|
default=2048,
|
||||||
|
metadata={"help": "Maximum sequence length"},
|
||||||
|
)
|
||||||
|
logging_steps: int = field(default=100, metadata={"help": "Log every X updates"})
|
||||||
|
report_to: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "The integration to report the results and logs to."},
|
||||||
|
)
|
||||||
|
run_name: Optional[str] = field(
|
||||||
|
default=None, metadata={"help": "The name of the run for logging."}
|
||||||
|
)
|
||||||
|
gradient_checkpointing: bool = field(default=False)
|
||||||
|
lr_scheduler_type: str = field(
|
||||||
|
default="cosine", metadata={"help": "The learning rate scheduler to use."}
|
||||||
|
)
|
||||||
|
remove_unused_columns: bool = field(default=False)
|
||||||
|
|
||||||
|
|
||||||
|
def data_collator(batch, tokenizer, original_tokenizer_vocab_size, cut_off_len=2048):
|
||||||
|
speech_generation_start_index = tokenizer.convert_tokens_to_ids(
|
||||||
|
"<|SPEECH_GENERATION_START|>"
|
||||||
|
)
|
||||||
|
assistant_index = tokenizer.convert_tokens_to_ids("assistant")
|
||||||
|
input_ids_list = []
|
||||||
|
for i, item in enumerate(batch):
|
||||||
|
text, code = item["text"], item["code"]
|
||||||
|
message = [
|
||||||
|
{"role": "user", "content": f"Convert the text to speech: {text}"},
|
||||||
|
{"role": "assistant", "content": "<|SPEECH_GENERATION_START|>"},
|
||||||
|
]
|
||||||
|
|
||||||
|
input_ids = tokenizer.apply_chat_template(
|
||||||
|
message,
|
||||||
|
tokenize=True,
|
||||||
|
chat_template=TEMPLATE,
|
||||||
|
)
|
||||||
|
|
||||||
|
code = [c + original_tokenizer_vocab_size for c in code]
|
||||||
|
|
||||||
|
idx = input_ids.index(speech_generation_start_index)
|
||||||
|
input_ids = input_ids[:idx] + code + input_ids[idx + 1 :]
|
||||||
|
if len(input_ids) < cut_off_len:
|
||||||
|
input_ids_list.append(input_ids)
|
||||||
|
|
||||||
|
max_len = max([len(input_ids) for input_ids in input_ids_list])
|
||||||
|
input_ids_list = [
|
||||||
|
input_ids + [tokenizer.pad_token_id] * (max_len - len(input_ids))
|
||||||
|
for input_ids in input_ids_list
|
||||||
|
]
|
||||||
|
input_ids = torch.tensor(input_ids_list, dtype=torch.int)
|
||||||
|
attention_mask = input_ids.ne(tokenizer.pad_token_id)
|
||||||
|
|
||||||
|
target_ids = input_ids.clone()
|
||||||
|
target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
|
||||||
|
mask_indices = torch.where(input_ids == assistant_index)
|
||||||
|
for i in range(mask_indices[0].size(0)):
|
||||||
|
row = mask_indices[0][i]
|
||||||
|
col = mask_indices[1][i]
|
||||||
|
# + 2 to skip: 'assistant', '\n'
|
||||||
|
target_ids[row, : col + 2] = IGNORE_TOKEN_ID
|
||||||
|
return {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"labels": target_ids.to(dtype=torch.int64),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = transformers.HfArgumentParser(
|
||||||
|
(ModelArguments, DataArguments, CustomTrainingArguments)
|
||||||
|
)
|
||||||
|
assert len(sys.argv) == 2 and sys.argv[1].endswith(".json")
|
||||||
|
(
|
||||||
|
model_args,
|
||||||
|
data_args,
|
||||||
|
training_args,
|
||||||
|
) = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||||
|
|
||||||
|
is_main_process = training_args.local_rank in [-1, 0]
|
||||||
|
if training_args.report_to == "wandb" and is_main_process:
|
||||||
|
wandb.init(
|
||||||
|
project="llm_tts",
|
||||||
|
config=training_args.to_sanitized_dict(),
|
||||||
|
name=training_args.run_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_args.llm_model_name_or_path,
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
attn_implementation="flash_attention_2",
|
||||||
|
)
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_args.llm_model_name_or_path)
|
||||||
|
original_tokenizer_vocab_size = len(tokenizer)
|
||||||
|
cosyvoice2_token_size = 6561
|
||||||
|
new_tokens = [f"<|s_{i}|>" for i in range(cosyvoice2_token_size)] + [
|
||||||
|
"<|SPEECH_GENERATION_START|>"
|
||||||
|
]
|
||||||
|
num_added_tokens = tokenizer.add_tokens(new_tokens)
|
||||||
|
|
||||||
|
model.resize_token_embeddings(len(tokenizer))
|
||||||
|
model.vocab_size = len(tokenizer)
|
||||||
|
|
||||||
|
dataset = load_dataset("json", data_files=data_args.data_path)
|
||||||
|
dataset = dataset["train"]
|
||||||
|
train_test_split = dataset.train_test_split(test_size=100, seed=42)
|
||||||
|
train_dataset, eval_dataset = train_test_split["train"], train_test_split["test"]
|
||||||
|
|
||||||
|
trainer = Trainer(
|
||||||
|
model=model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
args=training_args,
|
||||||
|
train_dataset=train_dataset,
|
||||||
|
eval_dataset=eval_dataset,
|
||||||
|
data_collator=lambda features: data_collator(
|
||||||
|
features, tokenizer, original_tokenizer_vocab_size
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_main_process:
|
||||||
|
trainer.add_callback(transformers.integrations.WandbCallback())
|
||||||
|
|
||||||
|
trainer.train(resume_from_checkpoint=None)
|
||||||
|
trainer.save_model(training_args.output_dir)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
200
egs/emilia/TTS/local/extract_cosyvoice2_token.py
Normal file
200
egs/emilia/TTS/local/extract_cosyvoice2_token.py
Normal file
@ -0,0 +1,200 @@
|
|||||||
|
# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
|
||||||
|
# 2025 (authors: Yuekai Zhang)
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
""" Example Usage
|
||||||
|
torchrun --nproc_per_node=8 --nnodes=1 \
|
||||||
|
--rdzv_id=2024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
|
||||||
|
local/extract_cosyvoice2_token.py --data_dir $data_dir \
|
||||||
|
--jsonl_file $jsonl_file_basename \
|
||||||
|
--device "cuda" \
|
||||||
|
--output_dir $output_dir \
|
||||||
|
--batch_size 32 \
|
||||||
|
--num_workers 2 \
|
||||||
|
--model "speech_tokenizer_v2_25hz"
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import s3tokenizer
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from lhotse.serialization import load_jsonl
|
||||||
|
from torch.utils.data import DataLoader, Dataset, DistributedSampler
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
class AudioDataset(Dataset):
|
||||||
|
def __init__(self, data_dir, jsonl_file):
|
||||||
|
self.data = []
|
||||||
|
# convert data_dir to Path object
|
||||||
|
self.data_dir = Path(data_dir)
|
||||||
|
# jsonl_files = self.data_dir.glob("*.jsonl")
|
||||||
|
jsonl_files = [self.data_dir / jsonl_file]
|
||||||
|
for jsonl_file in jsonl_files:
|
||||||
|
for item in tqdm(
|
||||||
|
# Note: People's Speech manifest.json is really a JSONL.
|
||||||
|
load_jsonl(jsonl_file),
|
||||||
|
desc=f"Processing {jsonl_file}",
|
||||||
|
):
|
||||||
|
self.data.append(item)
|
||||||
|
break
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.data)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
file_path = self.data_dir / self.data[idx]["wav"]
|
||||||
|
audio = s3tokenizer.load_audio(file_path)
|
||||||
|
if audio.shape[0] / 16000 > 30:
|
||||||
|
print(
|
||||||
|
f"do not support extract speech token for audio longer than 30s, file_path: {file_path}" # noqa
|
||||||
|
)
|
||||||
|
mel = torch.zeros(128, 0)
|
||||||
|
else:
|
||||||
|
mel = s3tokenizer.log_mel_spectrogram(audio)
|
||||||
|
return self.data[idx], mel
|
||||||
|
|
||||||
|
|
||||||
|
def collate_fn(batch):
|
||||||
|
keys = [item[0] for item in batch]
|
||||||
|
mels = [item[1] for item in batch]
|
||||||
|
mels, mels_lens = s3tokenizer.padding(mels)
|
||||||
|
return keys, mels, mels_lens
|
||||||
|
|
||||||
|
|
||||||
|
def init_distributed():
|
||||||
|
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||||
|
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||||
|
rank = int(os.environ.get("RANK", 0))
|
||||||
|
print(
|
||||||
|
"Inference on multiple gpus, this gpu {}".format(local_rank)
|
||||||
|
+ ", rank {}, world_size {}".format(rank, world_size)
|
||||||
|
)
|
||||||
|
torch.cuda.set_device(local_rank)
|
||||||
|
dist.init_process_group("nccl")
|
||||||
|
return world_size, local_rank, rank
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser(description="extract speech code")
|
||||||
|
parser.add_argument(
|
||||||
|
"--model",
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
choices=[
|
||||||
|
"speech_tokenizer_v1",
|
||||||
|
"speech_tokenizer_v1_25hz",
|
||||||
|
"speech_tokenizer_v2_25hz",
|
||||||
|
],
|
||||||
|
help="model version",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--data_dir",
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
help="each line contains `wav_name wav_path`",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--jsonl_file",
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
help="each line contains `wav_name wav_path`",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--device",
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
choices=["cuda", "cpu"],
|
||||||
|
help="device for inference",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output_dir", required=True, type=str, help="dir to save result"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch_size",
|
||||||
|
required=True,
|
||||||
|
type=int,
|
||||||
|
help="batch size (per-device) for inference",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num_workers", type=int, default=4, help="workers for dataloader"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--prefetch", type=int, default=5, help="prefetch for dataloader"
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = get_args()
|
||||||
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
if args.device == "cuda":
|
||||||
|
assert torch.cuda.is_available()
|
||||||
|
world_size, local_rank, rank = init_distributed()
|
||||||
|
else:
|
||||||
|
world_size, local_rank, rank = 1, 0, 0
|
||||||
|
|
||||||
|
device = torch.device(args.device)
|
||||||
|
model = s3tokenizer.load_model(args.model).to(device)
|
||||||
|
dataset = AudioDataset(args.data_dir, args.jsonl_file)
|
||||||
|
|
||||||
|
if args.device == "cuda":
|
||||||
|
model = torch.nn.parallel.DistributedDataParallel(
|
||||||
|
model, device_ids=[local_rank]
|
||||||
|
)
|
||||||
|
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
|
||||||
|
else:
|
||||||
|
sampler = None
|
||||||
|
|
||||||
|
dataloader = DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
sampler=sampler,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=args.num_workers,
|
||||||
|
prefetch_factor=args.prefetch,
|
||||||
|
collate_fn=collate_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
total_steps = len(dataset)
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
|
||||||
|
|
||||||
|
writer = open(f"{args.output_dir}/part_{rank + 1}_of_{world_size}", "w")
|
||||||
|
for keys, mels, mels_lens in dataloader:
|
||||||
|
codes, codes_lens = model(mels.to(device), mels_lens.to(device))
|
||||||
|
for i, k in enumerate(keys):
|
||||||
|
code = codes[i, : codes_lens[i].item()].tolist()
|
||||||
|
k["code"] = code
|
||||||
|
writer.write(json.dumps(k, ensure_ascii=False) + "\n")
|
||||||
|
if rank == 0:
|
||||||
|
progress_bar.update(world_size * len(keys))
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
progress_bar.close()
|
||||||
|
writer.close()
|
||||||
|
if args.device == "cuda":
|
||||||
|
dist.barrier()
|
||||||
|
dist.destroy_process_group()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
115
egs/emilia/TTS/prepare.sh
Executable file
115
egs/emilia/TTS/prepare.sh
Executable file
@ -0,0 +1,115 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
set -eou pipefail
|
||||||
|
|
||||||
|
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
|
||||||
|
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||||
|
|
||||||
|
stage=3
|
||||||
|
stop_stage=4
|
||||||
|
|
||||||
|
# Please download the OpenDataLab format from HuggingFace, you can specify the revision argument to fc71e07e8572f5f3be1dbd02ed3172a4d298f152, which is the old format.
|
||||||
|
# https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07e8572f5f3be1dbd02ed3172a4d298f152
|
||||||
|
dl_dir=$PWD/download
|
||||||
|
|
||||||
|
prefix="emilia"
|
||||||
|
# zh, en, ja, ko, de, fr
|
||||||
|
lang_set=("de" "en" "zh" "ja" "ko" "fr")
|
||||||
|
. shared/parse_options.sh || exit 1
|
||||||
|
|
||||||
|
|
||||||
|
# All files generated by this script are saved in "data".
|
||||||
|
# You can safely remove "data" and rerun this script to regenerate it.
|
||||||
|
mkdir -p data
|
||||||
|
log() {
|
||||||
|
# This function is from espnet
|
||||||
|
local fname=${BASH_SOURCE[1]##*/}
|
||||||
|
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||||
|
}
|
||||||
|
|
||||||
|
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||||
|
log "dl_dir: $dl_dir"
|
||||||
|
log "Stage 0: Download data"
|
||||||
|
# Extract the downloaded data:
|
||||||
|
cat $dl_dir/raw/EN/EN_B00008.tar.gz.* > $dl_dir/raw/EN/EN_B00008.tar.gz
|
||||||
|
for lang in "${lang_set[@]}"; do
|
||||||
|
lang_upper=$(echo "${lang}" | tr '[:lower:]' '[:upper:]')
|
||||||
|
folder=$dl_dir/raw/${lang_upper}
|
||||||
|
for file in $folder/*.tar.gz; do
|
||||||
|
echo "Processing ${file}"
|
||||||
|
tar -xzvf $file -C $folder
|
||||||
|
done
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||||
|
log "Stage 1: Prepare emilia manifest (used by ./f5-tts)"
|
||||||
|
# We assume that you have downloaded the Emilia corpus
|
||||||
|
# to $dl_dir/emilia
|
||||||
|
mkdir -p data/manifests
|
||||||
|
for lang in "${lang_set[@]}"; do
|
||||||
|
echo "Processing ${lang}"
|
||||||
|
if [ ! -e data/manifests/.emilia.${lang}.done ]; then
|
||||||
|
lhotse prepare emilia $dl_dir data/manifests --num-jobs 30 --lang "${lang}"
|
||||||
|
touch data/manifests/.emilia.${lang}.done
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||||
|
log "Stage 2: Generate fbank (used by ./f5-tts)"
|
||||||
|
mkdir -p data/fbank
|
||||||
|
for lang in "${lang_set[@]}"; do
|
||||||
|
echo "Processing ${lang}"
|
||||||
|
if [ ! -e data/fbank/.emilia.${lang}.done ]; then
|
||||||
|
./local/compute_mel_feat.py --dataset-parts $lang --split 100 --prefix ${prefix}
|
||||||
|
touch data/fbank/.emilia.${lang}.done
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||||
|
log "Stage 3: Extract cosyvoice2 FSQ token (used by ./llaasa_cosyvoice2_token)"
|
||||||
|
for lang in "${lang_set[@]}"; do
|
||||||
|
lang_upper=$(echo "${lang}" | tr '[:lower:]' '[:upper:]')
|
||||||
|
data_dir=$dl_dir/raw/${lang_upper}
|
||||||
|
# for all jsonl files in data_dir
|
||||||
|
for jsonl_file in $data_dir/*.jsonl; do
|
||||||
|
# get the file basename
|
||||||
|
jsonl_file_basename=$(basename $jsonl_file)
|
||||||
|
echo "Processing $jsonl_file"
|
||||||
|
output_dir="./cosy_v2_tokens_${lang_upper}/${jsonl_file_basename%.jsonl}"
|
||||||
|
echo "output_dir: $output_dir"
|
||||||
|
# skip if the output_dir exists
|
||||||
|
if [ -e $output_dir ]; then
|
||||||
|
echo "Output directory $output_dir already exists, skipping"
|
||||||
|
continue
|
||||||
|
fi
|
||||||
|
mkdir -p $output_dir
|
||||||
|
torchrun --nproc_per_node=8 --nnodes=1 \
|
||||||
|
--rdzv_id=2024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
|
||||||
|
local/extract_cosyvoice2_token.py --data_dir $data_dir \
|
||||||
|
--jsonl_file $jsonl_file_basename \
|
||||||
|
--device "cuda" \
|
||||||
|
--output_dir $output_dir \
|
||||||
|
--batch_size 32 \
|
||||||
|
--num_workers 2 \
|
||||||
|
--model "speech_tokenizer_v2_25hz" # or "speech_tokenizer_v1_25hz
|
||||||
|
done
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||||
|
log "Stage 4: Merge cosyvoice2 FSQ token (used by ./llaasa_cosyvoice2_token)"
|
||||||
|
for lang in "${lang_set[@]}"; do
|
||||||
|
lang_upper=$(echo "${lang}" | tr '[:lower:]' '[:upper:]')
|
||||||
|
cosy_token_dir="./cosy_v2_tokens_${lang_upper}"
|
||||||
|
for dir in $cosy_token_dir/*; do
|
||||||
|
echo "Processing $dir"
|
||||||
|
dir_basename=$(basename $dir)
|
||||||
|
echo "dir_basename: $dir_basename"
|
||||||
|
cat $dir/part* > $dir/${dir_basename}.jsonl
|
||||||
|
done
|
||||||
|
cat $cosy_token_dir/${lang_upper}*/*.jsonl > $cosy_token_dir/cosy_v2_tokens_${lang_upper}.jsonl
|
||||||
|
done
|
||||||
|
fi
|
1
egs/emilia/TTS/shared
Symbolic link
1
egs/emilia/TTS/shared
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../icefall/shared/
|
@ -9,20 +9,6 @@
|
|||||||
|
|
||||||
[**WenetSpeech4TTS**](https://huggingface.co/datasets/Wenetspeech4TTS/WenetSpeech4TTS) is a multi-domain **Mandarin** corpus derived from the open-sourced [WenetSpeech](https://arxiv.org/abs/2110.03370) dataset.
|
[**WenetSpeech4TTS**](https://huggingface.co/datasets/Wenetspeech4TTS/WenetSpeech4TTS) is a multi-domain **Mandarin** corpus derived from the open-sourced [WenetSpeech](https://arxiv.org/abs/2110.03370) dataset.
|
||||||
|
|
||||||
> [!CAUTION]
|
|
||||||
> The next-gen Kaldi framework provides tools and models for generating high-quality, synthetic speech (Text-to-Speech, TTS).
|
|
||||||
> While these recipes has the potential to advance various fields such as accessibility, language education, and AI-driven solutions, it also carries certain ethical and legal responsibilities.
|
|
||||||
>
|
|
||||||
> By using this framework, you agree to the following:
|
|
||||||
> 1. Legal and Ethical Use: You shall not use this framework, or any models derived from it, for any unlawful or unethical purposes. This includes, but is not limited to: Creating voice clones without the explicit, informed consent of the individual whose voice is being cloned. Engaging in any form of identity theft, impersonation, or fraud using cloned voices. Violating any local, national, or international laws regarding privacy, intellectual property, or personal data.
|
|
||||||
>
|
|
||||||
> 2. Responsibility of Use: The users of this framework are solely responsible for ensuring that their use of voice cloning technologies complies with all applicable laws and ethical guidelines. We explicitly disclaim any liability for misuse of the technology.
|
|
||||||
>
|
|
||||||
> 3. Attribution and Use of Open-Source Components: This project is provided under the Apache 2.0 license. Users must adhere to the terms of this license and provide appropriate attribution when required.
|
|
||||||
>
|
|
||||||
> 4. No Warranty: This framework is provided “as-is,” without warranty of any kind, either express or implied. We do not guarantee that the use of this software will comply with legal requirements or that it will not infringe the rights of third parties.
|
|
||||||
|
|
||||||
|
|
||||||
# [VALL-E](https://arxiv.org/abs/2301.02111)
|
# [VALL-E](https://arxiv.org/abs/2301.02111)
|
||||||
|
|
||||||
./valle contains the code for training VALL-E TTS model.
|
./valle contains the code for training VALL-E TTS model.
|
||||||
@ -186,3 +172,5 @@ bash local/compute_wer.sh $output_dir $manifest
|
|||||||
- [VALL-E](https://github.com/lifeiteng/vall-e)
|
- [VALL-E](https://github.com/lifeiteng/vall-e)
|
||||||
- [F5-TTS](https://github.com/SWivid/F5-TTS)
|
- [F5-TTS](https://github.com/SWivid/F5-TTS)
|
||||||
- [CosyVoice](https://github.com/FunAudioLLM/CosyVoice)
|
- [CosyVoice](https://github.com/FunAudioLLM/CosyVoice)
|
||||||
|
- [S3Tokenizer](https://github.com/xingchensong/S3Tokenizer/tree/main)
|
||||||
|
- [Spark-TTS](https://github.com/SparkAudio/Spark-TTS)
|
||||||
|
@ -108,13 +108,6 @@ def get_parser():
|
|||||||
help="Interpolate semantic token to match mel frames for CosyVoice",
|
help="Interpolate semantic token to match mel frames for CosyVoice",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--use-cosyvoice-semantic-token",
|
|
||||||
type=str2bool,
|
|
||||||
default=False,
|
|
||||||
help="Whether to use cosyvoice semantic token to replace text token.",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--split-name",
|
"--split-name",
|
||||||
type=str,
|
type=str,
|
||||||
|
373
egs/wenetspeech4tts/TTS/f5-tts/infer_dist.py
Normal file
373
egs/wenetspeech4tts/TTS/f5-tts/infer_dist.py
Normal file
@ -0,0 +1,373 @@
|
|||||||
|
# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
|
||||||
|
# 2025 (authors: Yuekai Zhang)
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Modified from https://github.com/xingchensong/S3Tokenizer/blob/main/s3tokenizer/cli.py
|
||||||
|
""" Example Usage
|
||||||
|
split=test_zh
|
||||||
|
llm_path=f5-tts/exp_zh/checkpoint-805000
|
||||||
|
huggingface-cli download --local-dir f5-tts-small-wenetspeech4tts-basic yuekai/f5-tts-semantic-token-small-wenetspeech4tts-basic
|
||||||
|
model_path=f5-tts-small-wenetspeech4tts-basic/epoch-10-avg-5.pt
|
||||||
|
huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir ./bigvgan_v2_24khz_100band_256x
|
||||||
|
vocoder=./bigvgan_v2_24khz_100band_256x
|
||||||
|
torchrun --nproc_per_node=2 \
|
||||||
|
f5-tts/infer_dist.py \
|
||||||
|
--output_dir $output_dir \
|
||||||
|
--batch_size 1 \
|
||||||
|
--num_workers 2 \
|
||||||
|
--llm-model-name-or-path $llm_path \
|
||||||
|
--flow-matching-model-path $model_path \
|
||||||
|
--decoder-dim 768 --nhead 12 --num-decoder-layers 18 \
|
||||||
|
--use-cosyvoice-semantic-token True \
|
||||||
|
--vocoder-dir $vocoder \
|
||||||
|
--split-name $split -top-k 50 -top-p 0.95 -temperature 0.8 \
|
||||||
|
--tokenizer-dir Qwen/Qwen2.5-0.5B-Instruct
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import s3tokenizer
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torchaudio
|
||||||
|
from bigvganinference import BigVGANInference
|
||||||
|
from datasets import load_dataset
|
||||||
|
from lhotse.serialization import load_jsonl
|
||||||
|
from llm_tts import LLMTTS
|
||||||
|
from model.modules import MelSpec
|
||||||
|
from torch.utils.data import DataLoader, Dataset, DistributedSampler
|
||||||
|
from tqdm import tqdm
|
||||||
|
from train import (
|
||||||
|
add_model_arguments,
|
||||||
|
get_model,
|
||||||
|
get_tokenizer,
|
||||||
|
interpolate_tokens,
|
||||||
|
load_F5_TTS_pretrained_checkpoint,
|
||||||
|
)
|
||||||
|
|
||||||
|
from icefall.checkpoint import load_checkpoint
|
||||||
|
|
||||||
|
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser(description="extract speech code")
|
||||||
|
parser.add_argument(
|
||||||
|
"--s3-tokenizer-name",
|
||||||
|
required=False,
|
||||||
|
type=str,
|
||||||
|
choices=[
|
||||||
|
"speech_tokenizer_v1",
|
||||||
|
"speech_tokenizer_v1_25hz",
|
||||||
|
"speech_tokenizer_v2_25hz",
|
||||||
|
],
|
||||||
|
help="model version",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--split-name",
|
||||||
|
type=str,
|
||||||
|
default="wenetspeech4tts",
|
||||||
|
choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
|
||||||
|
help="huggingface dataset split name",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-dir", required=True, type=str, help="dir to save result"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch-size",
|
||||||
|
required=True,
|
||||||
|
type=int,
|
||||||
|
help="batch size (per-device) for inference",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-workers", type=int, default=4, help="workers for dataloader"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--prefetch", type=int, default=5, help="prefetch for dataloader"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--llm-model-name-or-path",
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
help="model version",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tokenizer-dir",
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
help="tokenizer dir",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--vocoder-dir",
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
help="vocoder dir",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--flow-matching-model-path",
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
help="flow matching model path",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--top-k",
|
||||||
|
type=int,
|
||||||
|
default=50,
|
||||||
|
help="top k for sampling",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--top-p",
|
||||||
|
type=float,
|
||||||
|
default=0.95,
|
||||||
|
help="top p for sampling",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--temperature",
|
||||||
|
type=float,
|
||||||
|
default=0.8,
|
||||||
|
help="temperature for sampling",
|
||||||
|
)
|
||||||
|
add_model_arguments(parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def padded_mel_batch(ref_mels):
|
||||||
|
max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
|
||||||
|
padded_ref_mels = []
|
||||||
|
for mel in ref_mels:
|
||||||
|
padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0)
|
||||||
|
padded_ref_mels.append(padded_ref_mel)
|
||||||
|
padded_ref_mels = torch.stack(padded_ref_mels)
|
||||||
|
padded_ref_mels = padded_ref_mels.permute(0, 2, 1)
|
||||||
|
return padded_ref_mels
|
||||||
|
|
||||||
|
|
||||||
|
def data_collator(batch, tokenizer, mel_spectrogram):
|
||||||
|
speech_generation_start_index = tokenizer.convert_tokens_to_ids(
|
||||||
|
"<|SPEECH_GENERATION_START|>"
|
||||||
|
)
|
||||||
|
assistant_index = tokenizer.convert_tokens_to_ids("assistant")
|
||||||
|
target_sample_rate = 24000
|
||||||
|
hop_length = 256
|
||||||
|
target_rms = 0.1
|
||||||
|
input_ids_list, ref_mel_list, ref_mel_len_list = [], [], []
|
||||||
|
for i, item in enumerate(batch):
|
||||||
|
prompt_text, target_text, prompt_audio_codes = (
|
||||||
|
item["prompt_text"],
|
||||||
|
item["target_text"],
|
||||||
|
item["prompt_audio_cosy2_tokens"],
|
||||||
|
)
|
||||||
|
message = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": f"Convert the text to speech: {prompt_text + target_text}",
|
||||||
|
},
|
||||||
|
{"role": "assistant", "content": "<|SPEECH_GENERATION_START|>"},
|
||||||
|
]
|
||||||
|
|
||||||
|
input_ids = tokenizer.apply_chat_template(
|
||||||
|
message,
|
||||||
|
tokenize=True,
|
||||||
|
chat_template=TEMPLATE,
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt_audio_codes = [c + 151665 for c in prompt_audio_codes]
|
||||||
|
|
||||||
|
idx = input_ids.index(speech_generation_start_index)
|
||||||
|
input_ids = input_ids[:idx] + prompt_audio_codes
|
||||||
|
input_ids_list.append(input_ids)
|
||||||
|
|
||||||
|
# get flow matching model's prompt mel spectrogram
|
||||||
|
ref_audio_org, ref_sr = (
|
||||||
|
item["prompt_audio"]["array"],
|
||||||
|
item["prompt_audio"]["sampling_rate"],
|
||||||
|
)
|
||||||
|
ref_audio_org = torch.from_numpy(ref_audio_org).unsqueeze(0).float()
|
||||||
|
ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org)))
|
||||||
|
if ref_rms < target_rms:
|
||||||
|
ref_audio_org = ref_audio_org * target_rms / ref_rms
|
||||||
|
|
||||||
|
if ref_sr != target_sample_rate:
|
||||||
|
resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
|
||||||
|
ref_audio = resampler(ref_audio_org)
|
||||||
|
else:
|
||||||
|
ref_audio = ref_audio_org
|
||||||
|
|
||||||
|
# Duration in mel frame length
|
||||||
|
ref_mel_len = ref_audio.shape[-1] // hop_length
|
||||||
|
# to mel spectrogram
|
||||||
|
ref_mel = mel_spectrogram(ref_audio)
|
||||||
|
ref_mel = ref_mel.squeeze(0)
|
||||||
|
|
||||||
|
ref_mel_list.append(ref_mel)
|
||||||
|
ref_mel_len_list.append(ref_mel_len)
|
||||||
|
|
||||||
|
max_len = max([len(input_ids) for input_ids in input_ids_list])
|
||||||
|
input_ids_list = [
|
||||||
|
[tokenizer.pad_token_id] * (max_len - len(input_ids)) + input_ids
|
||||||
|
for input_ids in input_ids_list
|
||||||
|
]
|
||||||
|
input_ids = torch.tensor(input_ids_list, dtype=torch.int64)
|
||||||
|
attention_mask = input_ids.ne(tokenizer.pad_token_id).long()
|
||||||
|
ids = [item["id"] for item in batch]
|
||||||
|
|
||||||
|
ref_mel_batch = padded_mel_batch(ref_mel_list)
|
||||||
|
ref_mel_len_batch = torch.LongTensor(ref_mel_len_list)
|
||||||
|
return {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"ids": ids,
|
||||||
|
"ref_mel_batch": ref_mel_batch,
|
||||||
|
"ref_mel_len_batch": ref_mel_len_batch,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def init_distributed():
|
||||||
|
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||||
|
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||||
|
rank = int(os.environ.get("RANK", 0))
|
||||||
|
print(
|
||||||
|
"Inference on multiple gpus, this gpu {}".format(local_rank)
|
||||||
|
+ ", rank {}, world_size {}".format(rank, world_size)
|
||||||
|
)
|
||||||
|
torch.cuda.set_device(local_rank)
|
||||||
|
dist.init_process_group("nccl")
|
||||||
|
return world_size, local_rank, rank
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = get_args()
|
||||||
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
assert torch.cuda.is_available()
|
||||||
|
world_size, local_rank, rank = init_distributed()
|
||||||
|
device = torch.device(f"cuda:{local_rank}")
|
||||||
|
model = LLMTTS(
|
||||||
|
model_dir=args.llm_model_name_or_path,
|
||||||
|
tokenizer_dir=args.tokenizer_dir,
|
||||||
|
s3_tokenizer_name=args.s3_tokenizer_name,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
vocoder = BigVGANInference.from_pretrained(args.vocoder_dir, use_cuda_kernel=False)
|
||||||
|
vocoder = vocoder.eval().to(device)
|
||||||
|
|
||||||
|
flow_matching_model = get_model(args).eval().to(device)
|
||||||
|
_ = load_checkpoint(
|
||||||
|
args.flow_matching_model_path,
|
||||||
|
model=flow_matching_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset = load_dataset(
|
||||||
|
"yuekai/seed_tts_cosy2",
|
||||||
|
split=args.split_name,
|
||||||
|
trust_remote_code=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
|
||||||
|
|
||||||
|
mel_spectrogram = MelSpec(
|
||||||
|
n_fft=1024,
|
||||||
|
hop_length=256,
|
||||||
|
win_length=1024,
|
||||||
|
n_mel_channels=100,
|
||||||
|
target_sample_rate=24000,
|
||||||
|
mel_spec_type="bigvgan",
|
||||||
|
)
|
||||||
|
|
||||||
|
dataloader = DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
sampler=sampler,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=args.num_workers,
|
||||||
|
prefetch_factor=args.prefetch,
|
||||||
|
collate_fn=lambda x: data_collator(x, model.tokenizer, mel_spectrogram),
|
||||||
|
)
|
||||||
|
|
||||||
|
total_steps = len(dataset)
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
|
||||||
|
|
||||||
|
for batch in dataloader:
|
||||||
|
generate_codes = model.inference_batch(
|
||||||
|
batch["input_ids"],
|
||||||
|
batch["attention_mask"],
|
||||||
|
top_k=args.top_k,
|
||||||
|
top_p=args.top_p,
|
||||||
|
temperature=args.temperature,
|
||||||
|
)
|
||||||
|
flow_matching_input_tokens, total_mel_lens = [], []
|
||||||
|
for i, code in enumerate(generate_codes):
|
||||||
|
flow_matching_input_token = interpolate_tokens(code)
|
||||||
|
total_mel_len = len(flow_matching_input_token)
|
||||||
|
flow_matching_input_tokens.append(flow_matching_input_token)
|
||||||
|
total_mel_lens.append(total_mel_len)
|
||||||
|
total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)
|
||||||
|
ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch[
|
||||||
|
"ref_mel_len_batch"
|
||||||
|
].to(device)
|
||||||
|
|
||||||
|
max_len = max([len(tokens) for tokens in flow_matching_input_tokens])
|
||||||
|
# pad tokens to the same length
|
||||||
|
for i, tokens in enumerate(flow_matching_input_tokens):
|
||||||
|
flow_matching_input_tokens[i] = torch.tensor(
|
||||||
|
tokens + [-1] * (max_len - len(tokens)), dtype=torch.long
|
||||||
|
)
|
||||||
|
flow_matching_input_tokens = torch.stack(flow_matching_input_tokens).to(device)
|
||||||
|
generated, _ = flow_matching_model.sample(
|
||||||
|
cond=ref_mels,
|
||||||
|
text=flow_matching_input_tokens,
|
||||||
|
duration=total_mel_lens,
|
||||||
|
lens=ref_mel_lens,
|
||||||
|
steps=16,
|
||||||
|
cfg_strength=2.0,
|
||||||
|
sway_sampling_coef=-1,
|
||||||
|
no_ref_audio=False,
|
||||||
|
seed=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
for i, gen in enumerate(generated):
|
||||||
|
gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
|
||||||
|
gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
|
||||||
|
|
||||||
|
generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
|
||||||
|
target_rms = 0.1
|
||||||
|
target_sample_rate = 24_000
|
||||||
|
# if ref_rms_list[i] < target_rms:
|
||||||
|
# generated_wave = generated_wave * ref_rms_list[i] / target_rms
|
||||||
|
utt = batch["ids"][i]
|
||||||
|
torchaudio.save(
|
||||||
|
f"{args.output_dir}/{utt}.wav",
|
||||||
|
generated_wave,
|
||||||
|
target_sample_rate,
|
||||||
|
)
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
progress_bar.update(world_size * len(batch["ids"]))
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
progress_bar.close()
|
||||||
|
|
||||||
|
dist.barrier()
|
||||||
|
dist.destroy_process_group()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
110
egs/wenetspeech4tts/TTS/f5-tts/llm_tts.py
Normal file
110
egs/wenetspeech4tts/TTS/f5-tts/llm_tts.py
Normal file
@ -0,0 +1,110 @@
|
|||||||
|
# Copyright (c) 2025 SparkAudio
|
||||||
|
# 2025 Xinsheng Wang (w.xinshawn@gmail.com)
|
||||||
|
# 2025 Yuekai Zhang
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Modified from https://github.com/SparkAudio/Spark-TTS/blob/main/cli/SparkTTS.py
|
||||||
|
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class LLMTTS:
|
||||||
|
"""
|
||||||
|
LLM-TTS for text-to-speech generation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_dir: Path,
|
||||||
|
tokenizer_dir: Path,
|
||||||
|
s3_tokenizer_name: str,
|
||||||
|
device: torch.device,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initializes the LLMTTS model with the provided configurations and device.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_dir (Path): Directory containing the model and config files.
|
||||||
|
tokenizer_dir (Path): Directory containing the tokenizer files.
|
||||||
|
s3_tokenizer_name (str): Name of the tokenizer file on S3.
|
||||||
|
device (torch.device): Device to run the model on.
|
||||||
|
"""
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
self.model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_dir,
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
device_map=device,
|
||||||
|
attn_implementation="flash_attention_2",
|
||||||
|
)
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
|
||||||
|
self.original_vocab_size = len(tokenizer)
|
||||||
|
self.cosyvoice2_token_vocab_size = 6561
|
||||||
|
new_tokens = [f"<|s_{i}|>" for i in range(self.cosyvoice2_token_vocab_size)] + [
|
||||||
|
"<|SPEECH_GENERATION_START|>"
|
||||||
|
]
|
||||||
|
num_added_tokens = tokenizer.add_tokens(new_tokens)
|
||||||
|
tokenizer.padding_side = "left"
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.assistant_index = tokenizer.convert_tokens_to_ids("assistant")
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def inference_batch(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
temperature: float = 0.8,
|
||||||
|
top_k: float = 50,
|
||||||
|
top_p: float = 0.95,
|
||||||
|
max_new_tokens: int = 1024,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Performs inference to generate speech from text, incorporating prompt audio and/or text.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_ids (torch.Tensor): Input IDs for the model.
|
||||||
|
attention_mask (torch.Tensor): Attention mask for the model.
|
||||||
|
temperature (float, optional): Sampling temperature for controlling randomness. Default is 0.8.
|
||||||
|
top_k (float, optional): Top-k sampling parameter. Default is 50.
|
||||||
|
top_p (float, optional): Top-p (nucleus) sampling parameter. Default is 0.95.
|
||||||
|
max_new_tokens (int, optional): Maximum number of tokens to generate. Default is 1024.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Generated waveform as a tensor.
|
||||||
|
"""
|
||||||
|
generated_ids = self.model.generate(
|
||||||
|
input_ids=input_ids.to(self.device),
|
||||||
|
attention_mask=attention_mask.to(self.device),
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
do_sample=True,
|
||||||
|
top_k=top_k,
|
||||||
|
top_p=top_p,
|
||||||
|
temperature=temperature,
|
||||||
|
)
|
||||||
|
results = []
|
||||||
|
generated_ids = generated_ids.cpu().tolist()
|
||||||
|
for i in range(len(generated_ids)):
|
||||||
|
assistant_index = generated_ids[i].index(self.assistant_index)
|
||||||
|
padding_index = len(generated_ids[i])
|
||||||
|
# WAR: harding coding assistant_index + 2, for the current template Assistant: \n
|
||||||
|
result = generated_ids[i][assistant_index + 2 :]
|
||||||
|
result = [token - self.original_vocab_size for token in result]
|
||||||
|
result = [token for token in result if token >= 0]
|
||||||
|
results.append(result)
|
||||||
|
return results
|
@ -118,6 +118,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
|||||||
help="Number of Decoder layers.",
|
help="Number of Decoder layers.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-cosyvoice-semantic-token",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="Whether to use cosyvoice semantic token to replace text token.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
@ -313,13 +320,6 @@ def get_parser():
|
|||||||
help="perform OOM check on dataloader batches before starting training.",
|
help="perform OOM check on dataloader batches before starting training.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--use-cosyvoice-semantic-token",
|
|
||||||
type=str2bool,
|
|
||||||
default=False,
|
|
||||||
help="Whether to use cosyvoice semantic token to replace text token.",
|
|
||||||
)
|
|
||||||
|
|
||||||
add_model_arguments(parser)
|
add_model_arguments(parser)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
Loading…
x
Reference in New Issue
Block a user