Add F5-TTS with semantic token training results (#1880)

* add cosy token

* update inference code

* add extract cosy token

* update results

* add requirements.txt

* update readme

---------

Co-authored-by: yuekaiz <yuekaiz@h20-7.cm.cluster>
Co-authored-by: yuekaiz <yuekaiz@mgmt1-login.cm.cluster>
Yuekai Zhang 2025-02-24 13:58:47 +08:00 committed by GitHub
parent da597ad782
commit 2ba665abca
8 changed files with 770 additions and 28 deletions

View File

@ -1,3 +1,10 @@
# Results
| Model | Seed-TTS test_zh CER | Comment |
|---------------------------------------|---------------------|--------|
| [vall-e](./valle) | 4.33% | ~150M |
| [f5-tts](./f5-tts) | 3.02% (16 steps) / 2.42% (32 steps) | F5-TTS-Small Config, ~155M |
| [f5-tts-semantic-token](./f5-tts) | 1.79% (16 steps) | Using pretrained cosyvoice2 semantic tokens as inputs rather than text tokens, ~155M |
# Introduction
[**WenetSpeech4TTS**](https://huggingface.co/datasets/Wenetspeech4TTS/WenetSpeech4TTS) is a multi-domain **Mandarin** corpus derived from the open-sourced [WenetSpeech](https://arxiv.org/abs/2110.03370) dataset.
@ -131,6 +138,51 @@ accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-f
bash local/compute_wer.sh $output_dir $manifest
```
# F5-TTS-Semantic-Token
./f5-tts contains the code for training F5-TTS-Semantic-Token, a variant of F5-TTS whose text-token inputs are replaced with pretrained CosyVoice2 semantic tokens. During inference, the pretrained CosyVoice2 LLM predicts the semantic tokens for the target audio. We observed that this approach converges faster and models prosody better.
Generated samples and training logs for the WenetSpeech4TTS Basic subset (about 7,000 hours) can be found [here](https://huggingface.co/yuekai/f5-tts-semantic-token-small-wenetspeech4tts-basic/tree/main).
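CosyVoice2 produces 25 semantic tokens per second, while the BigVGAN mel features used by F5-TTS run at 24000/256 = 93.75 frames per second, so the token sequence is interpolated so that every 4 tokens cover 15 mel frames before it replaces the text input. A minimal runnable sketch of this interpolation (it mirrors `interpolate_tokens` in `f5-tts/train.py`; the pad value -1 matches the padding convention used elsewhere in the recipe):
```
# Sketch of the 25 Hz -> 93.75 Hz token interpolation used by this recipe.
def interpolate_tokens(cosy_tokens, pad_token=-1):
    out = []
    for i, tok in enumerate(cosy_tokens):
        out.append(tok)
        # pattern 4,4,4,3: three groups of (token + 3 pads), then (token + 2 pads)
        out.extend([pad_token] * (3 if i % 4 < 3 else 2))
    return out

print(interpolate_tokens([11, 22, 33, 44]))
# -> [11, -1, -1, -1, 22, -1, -1, -1, 33, -1, -1, -1, 44, -1, -1]   (15 frames)
```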
Preparation:
```
# extract cosyvoice2 semantic tokens
bash prepare.sh --stage 5 --stop_stage 7
```
The training command is given below:
```
# docker: ghcr.io/swivid/f5-tts:main
# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html
# pip install kaldialign lhotse tensorboard bigvganinference sentencepiece
world_size=8
exp_dir=exp/f5-tts-semantic-token-small
python3 f5-tts/train.py --max-duration 700 --filter-min-duration 0.5 --filter-max-duration 20 \
--num-buckets 6 --dtype "bfloat16" --save-every-n 5000 --valid-interval 10000 \
--base-lr 1e-4 --warmup-steps 20000 --average-period 0 \
--num-epochs 10 --start-epoch 1 --start-batch 0 \
--num-decoder-layers 18 --nhead 12 --decoder-dim 768 \
--exp-dir ${exp_dir} --world-size ${world_size} \
--decay-steps 600000 --prefix wenetspeech4tts_cosy_token --use-cosyvoice-semantic-token True
```
To run inference with the Icefall-trained WenetSpeech4TTS F5-Small-Semantic-Token model, use:
```
huggingface-cli login
huggingface-cli download --local-dir ${exp_dir} yuekai/f5-tts-semantic-token-small-wenetspeech4tts-basic
huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x
split=test_zh
model_path=f5-tts-small-wenetspeech4tts-basic/epoch-10-avg-5.pt
accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --split-name $split --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18 --use-cosyvoice-semantic-token True
bash local/compute_wer.sh $output_dir $manifest
```
# Credits
- [VALL-E](https://github.com/lifeiteng/vall-e)
- [F5-TTS](https://github.com/SWivid/F5-TTS)
- [CosyVoice](https://github.com/FunAudioLLM/CosyVoice)

View File

@ -11,7 +11,14 @@ python3 f5-tts/generate_averaged_model.py \
--epoch 56 \
--avg 14 --decoder-dim 768 --nhead 12 --num-decoder-layers 18 \
--exp-dir exp/f5_small
# command for text token input
accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-file $manifest --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18
# command for cosyvoice semantic token input
split=test_zh # seed_tts_eval test_zh
accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --split-name $split --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18 --use-cosyvoice-semantic-token True
bash local/compute_wer.sh $output_dir $manifest
"""
import argparse
@ -22,6 +29,7 @@ import random
import time
from pathlib import Path
import datasets
import torch
import torch.nn.functional as F
import torchaudio
@ -36,10 +44,12 @@ from train import (
add_model_arguments,
get_model,
get_tokenizer,
interpolate_tokens,
load_F5_TTS_pretrained_checkpoint,
)
from icefall.checkpoint import load_checkpoint
from icefall.utils import str2bool
def get_parser():
@ -78,7 +88,7 @@ def get_parser():
parser.add_argument(
"--manifest-file",
type=str,
default="/path/seed_tts_eval/seedtts_testset/zh/meta.lst",
default=None,
help="The manifest file in seed_tts_eval format",
)
@ -90,6 +100,29 @@ def get_parser():
)
parser.add_argument("-ss", "--swaysampling", default=-1, type=float)
parser.add_argument(
"--interpolate-token",
type=str2bool,
default=True,
help="Interpolate semantic token to match mel frames for CosyVoice",
)
parser.add_argument(
"--use-cosyvoice-semantic-token",
type=str2bool,
default=False,
help="Whether to use cosyvoice semantic token to replace text token.",
)
parser.add_argument(
"--split-name",
type=str,
default="wenetspeech4tts",
choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
help="huggingface dataset split name",
)
add_model_arguments(parser)
return parser.parse_args()
@ -243,6 +276,392 @@ def get_inference_prompt(
return prompts_all
def get_inference_prompt_cosy_voice_huggingface(
dataset,
speed=1.0,
tokenizer="pinyin",
polyphone=True,
target_sample_rate=24000,
n_fft=1024,
win_length=1024,
n_mel_channels=100,
hop_length=256,
mel_spec_type="bigvgan",
target_rms=0.1,
use_truth_duration=False,
infer_batch_size=1,
num_buckets=200,
min_secs=3,
max_secs=40,
interpolate_token=False,
):
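# Build bucketed inference batches from a HuggingFace dataset whose rows already carry
# precomputed CosyVoice2 tokens (prompt_audio_cosy2_tokens / target_audio_cosy2_tokens),
# so no CosyVoice2 LLM call is needed here.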
prompts_all = []
min_tokens = min_secs * target_sample_rate // hop_length
max_tokens = max_secs * target_sample_rate // hop_length
batch_accum = [0] * num_buckets
utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = (
[[] for _ in range(num_buckets)] for _ in range(6)
)
mel_spectrogram = MelSpec(
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
n_mel_channels=n_mel_channels,
target_sample_rate=target_sample_rate,
mel_spec_type=mel_spec_type,
)
for i in range(len(dataset)):
utt = dataset[i]["id"]
ref_audio_org, ref_sr = (
dataset[i]["prompt_audio"]["array"],
dataset[i]["prompt_audio"]["sampling_rate"],
)
ref_audio_org = torch.from_numpy(ref_audio_org).unsqueeze(0).float()
audio_tokens = dataset[i]["target_audio_cosy2_tokens"]
prompt_audio_tokens = dataset[i]["prompt_audio_cosy2_tokens"]
ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org)))
if ref_rms < target_rms:
ref_audio_org = ref_audio_org * target_rms / ref_rms
if ref_sr != target_sample_rate:
resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
ref_audio = resampler(ref_audio_org)
else:
ref_audio = ref_audio_org
input_tokens = prompt_audio_tokens + audio_tokens
if interpolate_token:
input_tokens = interpolate_tokens(input_tokens)
text_list = input_tokens
# Duration, mel frame length
ref_mel_len = ref_audio.shape[-1] // hop_length
total_mel_len = len(input_tokens)
if not interpolate_token:
total_mel_len = int(total_mel_len / 4 * 15)
# to mel spectrogram
ref_mel = mel_spectrogram(ref_audio)
ref_mel = ref_mel.squeeze(0)
# deal with batch
assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
if total_mel_len > max_tokens:
print(
f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
)
continue
assert (
min_tokens <= total_mel_len <= max_tokens
), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
bucket_i = math.floor(
(total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets
)
utts[bucket_i].append(utt)
ref_rms_list[bucket_i].append(ref_rms)
ref_mels[bucket_i].append(ref_mel)
ref_mel_lens[bucket_i].append(ref_mel_len)
total_mel_lens[bucket_i].append(total_mel_len)
# final_text_list[bucket_i].extend(text_list)
final_text_list[bucket_i].append(text_list)
batch_accum[bucket_i] += total_mel_len
if batch_accum[bucket_i] >= infer_batch_size:
prompts_all.append(
(
utts[bucket_i],
ref_rms_list[bucket_i],
padded_mel_batch(ref_mels[bucket_i]),
ref_mel_lens[bucket_i],
total_mel_lens[bucket_i],
final_text_list[bucket_i],
)
)
batch_accum[bucket_i] = 0
(
utts[bucket_i],
ref_rms_list[bucket_i],
ref_mels[bucket_i],
ref_mel_lens[bucket_i],
total_mel_lens[bucket_i],
final_text_list[bucket_i],
) = (
[],
[],
[],
[],
[],
[],
)
# add residual
for bucket_i, bucket_frames in enumerate(batch_accum):
if bucket_frames > 0:
prompts_all.append(
(
utts[bucket_i],
ref_rms_list[bucket_i],
padded_mel_batch(ref_mels[bucket_i]),
ref_mel_lens[bucket_i],
total_mel_lens[bucket_i],
final_text_list[bucket_i],
)
)
# shuffle so that the last workers are not left with only the short, easy batches
random.seed(666)
random.shuffle(prompts_all)
return prompts_all
def inference_speech_token(
cosyvoice,
tts_text,
prompt_text,
prompt_speech_16k,
stream=False,
speed=1.0,
text_frontend=True,
):
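# Use the CosyVoice2 frontend and LLM to predict semantic tokens for tts_text,
# conditioned on prompt_text and the 16 kHz prompt speech; returns
# (predicted target tokens, prompt speech tokens).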
tokens = []
prompt_text = cosyvoice.frontend.text_normalize(
prompt_text, split=False, text_frontend=text_frontend
)
for i in cosyvoice.frontend.text_normalize(
tts_text, split=True, text_frontend=text_frontend
):
tts_text_token, tts_text_token_len = cosyvoice.frontend._extract_text_token(i)
(
prompt_text_token,
prompt_text_token_len,
) = cosyvoice.frontend._extract_text_token(prompt_text)
speech_token, speech_token_len = cosyvoice.frontend._extract_speech_token(
prompt_speech_16k
)
for i in cosyvoice.model.llm.inference(
text=tts_text_token.to(cosyvoice.model.device),
text_len=torch.tensor([tts_text_token.shape[1]], dtype=torch.int32).to(
cosyvoice.model.device
),
prompt_text=prompt_text_token.to(cosyvoice.model.device),
prompt_text_len=torch.tensor(
[prompt_text_token.shape[1]], dtype=torch.int32
).to(cosyvoice.model.device),
prompt_speech_token=speech_token.to(cosyvoice.model.device),
prompt_speech_token_len=torch.tensor(
[speech_token.shape[1]], dtype=torch.int32
).to(cosyvoice.model.device),
embedding=None,
):
tokens.append(i)
return tokens, speech_token
def get_inference_prompt_cosy_voice(
metainfo,
speed=1.0,
tokenizer="pinyin",
polyphone=True,
target_sample_rate=24000,
n_fft=1024,
win_length=1024,
n_mel_channels=100,
hop_length=256,
mel_spec_type="bigvgan",
target_rms=0.1,
use_truth_duration=False,
infer_batch_size=1,
num_buckets=200,
min_secs=3,
max_secs=40,
interpolate_token=False,
):
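# Same bucketing as get_inference_prompt, but the conditioning sequence is built from
# CosyVoice2 semantic tokens predicted on the fly by the CosyVoice2 LLM for each
# seed_tts_eval metainfo entry.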
import sys
# please change the paths to your CosyVoice checkout accordingly
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
sys.path.append("/workspace/CosyVoice")
from cosyvoice.cli.cosyvoice import CosyVoice2
# please download the cosyvoice model first
cosyvoice = CosyVoice2(
"/workspace/CosyVoice2-0.5B", load_jit=False, load_trt=False, fp16=False
)
prompts_all = []
min_tokens = min_secs * target_sample_rate // hop_length
max_tokens = max_secs * target_sample_rate // hop_length
batch_accum = [0] * num_buckets
utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = (
[[] for _ in range(num_buckets)] for _ in range(6)
)
mel_spectrogram = MelSpec(
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
n_mel_channels=n_mel_channels,
target_sample_rate=target_sample_rate,
mel_spec_type=mel_spec_type,
)
for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(
metainfo, desc="Processing prompts..."
):
# Audio
ref_audio_org, ref_sr = torchaudio.load(prompt_wav)
# cosy voice
if ref_sr != 16000:
resampler = torchaudio.transforms.Resample(ref_sr, 16000)
ref_audio_16k = resampler(ref_audio_org)
else:
ref_audio_16k = ref_audio_org
audio_tokens, prompt_audio_tokens = inference_speech_token(
cosyvoice, gt_text, prompt_text, ref_audio_16k, stream=False
)
ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org)))
if ref_rms < target_rms:
ref_audio_org = ref_audio_org * target_rms / ref_rms
assert (
ref_audio_org.shape[-1] > 5000
), f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue."
if ref_sr != target_sample_rate:
resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
ref_audio = resampler(ref_audio_org)
else:
ref_audio = ref_audio_org
# Text
# if len(prompt_text[-1].encode("utf-8")) == 1:
# prompt_text = prompt_text + " "
# text = [prompt_text + gt_text]
# if tokenizer == "pinyin":
# text_list = convert_char_to_pinyin(text, polyphone=polyphone)
# else:
# text_list = text
# concatenate the prompt speech tokens with the predicted target speech tokens
# prompt_audio_tokens: tensor of shape (1, num_prompt_tokens); audio_tokens: list of predicted tokens
prompt_audio_tokens = prompt_audio_tokens.squeeze().cpu().tolist()
input_tokens = prompt_audio_tokens + audio_tokens
# convert it into a list
# input_tokens_list = input_tokens.squeeze().cpu().tolist()
if interpolate_token:
input_tokens = interpolate_tokens(input_tokens)
text_list = input_tokens
# Duration, mel frame length
ref_mel_len = ref_audio.shape[-1] // hop_length
if use_truth_duration:
gt_audio, gt_sr = torchaudio.load(gt_wav)
if gt_sr != target_sample_rate:
resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate)
gt_audio = resampler(gt_audio)
total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed)
# # test vocoder resynthesis
# ref_audio = gt_audio
else:
ref_text_len = len(prompt_text.encode("utf-8"))
gen_text_len = len(gt_text.encode("utf-8"))
total_mel_len_compute = ref_mel_len + int(
ref_mel_len / ref_text_len * gen_text_len / speed
)
total_mel_len = len(input_tokens)
if not interpolate_token:
total_mel_len = int(total_mel_len / 4 * 15)
print(
f"total_mel_len_compute: {total_mel_len_compute}, total_mel_len: {total_mel_len}"
)
# to mel spectrogram
ref_mel = mel_spectrogram(ref_audio)
ref_mel = ref_mel.squeeze(0)
# deal with batch
assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
assert (
min_tokens <= total_mel_len <= max_tokens
), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
bucket_i = math.floor(
(total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets
)
utts[bucket_i].append(utt)
ref_rms_list[bucket_i].append(ref_rms)
ref_mels[bucket_i].append(ref_mel)
ref_mel_lens[bucket_i].append(ref_mel_len)
total_mel_lens[bucket_i].append(total_mel_len)
# final_text_list[bucket_i].extend(text_list)
final_text_list[bucket_i].append(text_list)
batch_accum[bucket_i] += total_mel_len
if batch_accum[bucket_i] >= infer_batch_size:
prompts_all.append(
(
utts[bucket_i],
ref_rms_list[bucket_i],
padded_mel_batch(ref_mels[bucket_i]),
ref_mel_lens[bucket_i],
total_mel_lens[bucket_i],
final_text_list[bucket_i],
)
)
batch_accum[bucket_i] = 0
(
utts[bucket_i],
ref_rms_list[bucket_i],
ref_mels[bucket_i],
ref_mel_lens[bucket_i],
total_mel_lens[bucket_i],
final_text_list[bucket_i],
) = (
[],
[],
[],
[],
[],
[],
)
# add residual
for bucket_i, bucket_frames in enumerate(batch_accum):
if bucket_frames > 0:
prompts_all.append(
(
utts[bucket_i],
ref_rms_list[bucket_i],
padded_mel_batch(ref_mels[bucket_i]),
ref_mel_lens[bucket_i],
total_mel_lens[bucket_i],
final_text_list[bucket_i],
)
)
# shuffle so that the last workers are not left with only the short, easy batches
random.seed(666)
random.shuffle(prompts_all)
return prompts_all
def padded_mel_batch(ref_mels):
max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
padded_ref_mels = []
@ -275,20 +694,55 @@ def main():
accelerator = Accelerator()
device = f"cuda:{accelerator.process_index}"
metainfo = get_seedtts_testset_metainfo(args.manifest_file)
prompts_all = get_inference_prompt(
metainfo,
speed=1.0,
tokenizer="pinyin",
target_sample_rate=24_000,
n_mel_channels=100,
hop_length=256,
mel_spec_type="bigvgan",
target_rms=0.1,
use_truth_duration=False,
infer_batch_size=1,
)
if args.manifest_file:
metainfo = get_seedtts_testset_metainfo(args.manifest_file)
if not args.use_cosyvoice_semantic_token:
prompts_all = get_inference_prompt(
metainfo,
speed=1.0,
tokenizer="pinyin",
target_sample_rate=24_000,
n_mel_channels=100,
hop_length=256,
mel_spec_type="bigvgan",
target_rms=0.1,
use_truth_duration=False,
infer_batch_size=1,
)
else:
prompts_all = get_inference_prompt_cosy_voice(
metainfo,
speed=1.0,
tokenizer="pinyin",
target_sample_rate=24_000,
n_mel_channels=100,
hop_length=256,
mel_spec_type="bigvgan",
target_rms=0.1,
use_truth_duration=False,
infer_batch_size=1,
interpolate_token=args.interpolate_token,
)
else:
assert args.use_cosyvoice_semantic_token
dataset = datasets.load_dataset(
"yuekai/seed_tts_cosy2",
split=args.split_name,
trust_remote_code=True,
)
prompts_all = get_inference_prompt_cosy_voice_huggingface(
dataset,
speed=1.0,
tokenizer="pinyin",
target_sample_rate=24_000,
n_mel_channels=100,
hop_length=256,
mel_spec_type="bigvgan",
target_rms=0.1,
use_truth_duration=False,
infer_batch_size=1,
interpolate_token=args.interpolate_token,
)
vocoder = BigVGANInference.from_pretrained(
"./bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False
@ -324,6 +778,16 @@ def main():
ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)
if args.use_cosyvoice_semantic_token:
# pad the per-utterance token lists to the same length (pad value -1) and stack into a (B, T) tensor
max_len = max([len(tokens) for tokens in final_text_list])
for i, tokens in enumerate(final_text_list):
final_text_list[i] = torch.tensor(
tokens + [-1] * (max_len - len(tokens)), dtype=torch.long
)
final_text_list = torch.stack(final_text_list).to(device)
# Inference
with torch.inference_mode():
generated, _ = model.sample(

View File

@ -0,0 +1,36 @@
# F5-TTS
accelerate>=0.33.0
bitsandbytes>0.37.0
cached_path
click
datasets
ema_pytorch>=0.5.2
gradio>=3.45.2
hydra-core>=1.3.0
jieba
librosa
matplotlib
numpy<=1.26.4
pydub
pypinyin
safetensors
soundfile
tomli
torch>=2.0.0
torchaudio>=2.0.0
torchdiffeq
tqdm>=4.65.0
transformers
x_transformers>=1.31.14
# icefall
kaldialign
lhotse
tensorboard
bigvganinference
sentencepiece
sherpa-onnx
k2
# semantic experiment
s3tokenizer

View File

@ -82,9 +82,12 @@ class SpeechSynthesisDataset(torch.utils.data.Dataset):
text = [cut.supervisions[0].text for cut in cuts]
batch["text"] = text
if self.return_tokens:
if self.return_tokens and "speech_tokens" in cuts[0].supervisions[0].custom:
# tokens = [cut.tokens for cut in cuts]
tokens = [cut.supervisions[0].custom["tokens"]["text"] for cut in cuts]
# tokens = [cut.supervisions[0].custom["tokens"]["text"] for cut in cuts]
tokens = [cut.supervisions[0].custom["speech_tokens"] for cut in cuts]
# convert the space-separated token string into a list of ints
tokens = [list(map(int, token.split())) for token in tokens]
batch["tokens"] = tokens
if self.return_spk_ids:

View File

@ -31,6 +31,16 @@ python3 f5-tts/train.py --max-duration 700 --filter-min-duration 0.5 --filter-ma
--base-lr 7.5e-5 --warmup-steps 20000 --num-epochs 60 \
--num-decoder-layers 18 --nhead 12 --decoder-dim 768 \
--exp-dir ${exp_dir} --world-size ${world_size}
# command for training with cosyvoice semantic token
exp_dir=exp/f5-tts-cosyvoice
python3 f5-tts/train.py --max-duration 700 --filter-min-duration 0.5 --filter-max-duration 20 \
--num-buckets 6 --dtype "bfloat16" --save-every-n 5000 --valid-interval 10000 \
--base-lr 1e-4 --warmup-steps 20000 --average-period 0 \
--num-epochs 10 --start-epoch 1 --start-batch 0 \
--num-decoder-layers 18 --nhead 12 --decoder-dim 768 \
--exp-dir ${exp_dir} --world-size ${world_size} \
--decay-steps 600000 --prefix wenetspeech4tts_cosy_token --use-cosyvoice-semantic-token True
"""
import argparse
@ -303,6 +313,13 @@ def get_parser():
help="perform OOM check on dataloader batches before starting training.",
)
parser.add_argument(
"--use-cosyvoice-semantic-token",
type=str2bool,
default=False,
help="Whether to use cosyvoice semantic token to replace text token.",
)
add_model_arguments(parser)
return parser
@ -378,7 +395,11 @@ def get_tokenizer(vocab_file_path: str):
def get_model(params):
vocab_char_map, vocab_size = get_tokenizer(params.tokens)
if params.use_cosyvoice_semantic_token:
# https://www.modelscope.cn/models/iic/CosyVoice2-0.5B/file/view/master?fileName=cosyvoice.yaml&status=1#L36
vocab_char_map, vocab_size = None, 6561
else:
vocab_char_map, vocab_size = get_tokenizer(params.tokens)
# bigvgan 100 dim features
n_mel_channels = 100
n_fft = 1024
@ -556,14 +577,44 @@ def save_checkpoint(
copyfile(src=filename, dst=best_valid_filename)
def prepare_input(batch: dict, device: torch.device):
"""Parse batch data"""
text_inputs = batch["text"]
# texts.extend(convert_char_to_pinyin([text], polyphone=true))
text_inputs = convert_char_to_pinyin(text_inputs, polyphone=True)
def interpolate_tokens(cosy_tokens, pad_token=-1):
"""Interpolate cosyvoice tokens to match bigvgan frames length"""
# cosyvoice: 25 tokens/sec
# bigvgan: sample_rate / hop_length = 24000 / 256 = 93.75 frames/sec
# For every 4 cosyvoice tokens, insert pad tokens to extend them to 15 frames
# (pattern 4, 4, 4, 3) so the token sequence matches the bigvgan frame length
three, two = [pad_token] * 3, [pad_token] * 2
return [
x
for i, e in enumerate(cosy_tokens)
for x in ([e] + three if i % 4 < 3 else [e] + two)
]
def prepare_input(
batch: dict, device: torch.device, use_cosyvoice_semantic_token: bool
):
"""Parse batch data"""
mel_spec = batch["features"]
mel_lengths = batch["features_lens"]
if use_cosyvoice_semantic_token:
semantic_tokens = []
for i in range(len(batch["tokens"])):
tokens = batch["tokens"][i]
tokens = interpolate_tokens(tokens)
semantic_tokens.append(tokens)
# pad to the same length, B,T, with pad value -1
max_len = max([len(tokens) for tokens in semantic_tokens])
text_inputs = torch.full(
(len(semantic_tokens), max_len), -1, dtype=torch.long
).to(device)
for i, tokens in enumerate(semantic_tokens):
text_inputs[i, : len(tokens)] = torch.tensor(tokens, dtype=torch.long)
else:
text_inputs = batch["text"]
text_inputs = convert_char_to_pinyin(text_inputs, polyphone=True)
return text_inputs, mel_spec.to(device), mel_lengths.to(device)
@ -593,7 +644,11 @@ def compute_loss(
values >= 1.0 are fully warmed up and have all modules present.
"""
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
(text_inputs, mel_spec, mel_lengths) = prepare_input(batch, device=device)
(text_inputs, mel_spec, mel_lengths) = prepare_input(
batch,
device=device,
use_cosyvoice_semantic_token=params.use_cosyvoice_semantic_token,
)
# at entry, TextTokens is (N, P)
with torch.set_grad_enabled(is_training):

View File

@ -174,7 +174,7 @@ class TtsDataModule:
logging.info("About to create train dataset")
train = SpeechSynthesisDataset(
return_text=True,
return_tokens=False,
return_tokens=True,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
@ -234,7 +234,7 @@ class TtsDataModule:
else:
validate = SpeechSynthesisDataset(
return_text=True,
return_tokens=False,
return_tokens=True,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
@ -265,7 +265,7 @@ class TtsDataModule:
else:
test = SpeechSynthesisDataset(
return_text=True,
return_tokens=False,
return_tokens=True,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)

View File

@ -0,0 +1,108 @@
#!/usr/bin/env python3
# Copyright 2025 author: Yuekai Zhang
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import gzip
import json
import logging
import s3tokenizer
from lhotse import CutSet, load_manifest_lazy
from tqdm import tqdm
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--manifest-dir",
type=str,
default="data/fbank",
help="Directory to store the manifest files",
)
parser.add_argument(
"--jsonl-prefix",
type=str,
default="wenetspeech4tts_cuts_valid",
help="The training subset for wenetspeech.",
)
parser.add_argument(
"--tokens-path",
type=str,
default="./s3_tokens_valid/wenetspeech4tts_valid.json",
help="json file containing the speech tokens",
)
return parser
def get_speech_tokens(tokens_path):
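# Each line of the tokens file is a JSON object; assuming the format consumed below,
# e.g. {"key": "<cut_id>", "code": [1542, 87, 903, ...]} (values are illustrative),
# the integer codes are re-joined into one space-separated string per utterance.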
id2tokens = {}
with open(tokens_path, "r") as fin:
for line in fin:
line = json.loads(line)
id2tokens[line["key"]] = " ".join(map(str, line["code"]))
return id2tokens
def attach_manifest(manifest, fixed_manifest_path, id2tokens):
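# Copy the manifest cut by cut, attaching the space-separated CosyVoice2 token string
# to supervision.custom["speech_tokens"] whenever the cut id has extracted tokens.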
with CutSet.open_writer(fixed_manifest_path) as manifest_writer:
fixed_item = 0
for i, cut in enumerate(tqdm(manifest)):
cut_id = cut.supervisions[0].id
if cut_id in id2tokens:
code = id2tokens[cut_id]
cut.supervisions[0].custom = {
**cut.supervisions[0].custom,
**{"speech_tokens": code},
}
else:
print(f"cut_id {cut_id} not in id2tokens")
fixed_item += 1
manifest_writer.write(cut)
logging.info(f"Fixed {fixed_item} items in the manifest")
def main():
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
manifest_path = args.manifest_dir + "/" + f"{args.jsonl_prefix}.jsonl.gz"
attached_manifest_path = (
args.manifest_dir + "/" + f"{args.jsonl_prefix}_attached_cosyvoice_v2.jsonl.gz"
)
logging.info(f"Loading manifest from {manifest_path}")
cuts_manifest = load_manifest_lazy(manifest_path)
logging.info(f"Loading manifest from {manifest_path} done")
id2tokens = get_speech_tokens(args.tokens_path)
logging.info(f"Loaded id2tokens with {len(id2tokens)} entries")
attach_manifest(cuts_manifest, attached_manifest_path, id2tokens)
logging.info(
f"Manifest with speech tokens attached is saved to {attached_manifest_path}"
)
if __name__ == "__main__":
main()

View File

@ -111,7 +111,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 7: Split the ${prefix} cuts into train, valid and test sets (used by ./f5-tts)"
log "Stage 6: Split the ${prefix} cuts into train, valid and test sets (used by ./f5-tts)"
if [ ! -f data/fbank/${prefix}_cuts_${subset}.jsonl.gz ]; then
echo "Combining ${prefix} cuts"
pieces=$(find data/fbank/ -name "${prefix}_cuts_${subset}.*.jsonl.gz")
@ -139,3 +139,27 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
touch data/fbank/.${prefix}_split.done
fi
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
log "Stage 7: Extract cosyvoice2 FSQ token (used by ./f5-tts semantic token experiment)"
split_name=("valid" "test" "train")
for split in "${split_name[@]}"; do
echo "Processing $split"
wav_scp_file=wav_${split}.scp
output_dir="./cosy_v2_tokens_${split}"
original_jsonl_file=data/fbank/${prefix}_cuts_${split}.jsonl.gz
mkdir -p $output_dir
zcat $original_jsonl_file | jq -r '.recording.id + " " + .recording.sources[0].source' > $wav_scp_file
torchrun --nproc_per_node=8 --nnodes=1 \
--rdzv_id=2024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
`which s3tokenizer` --wav_scp $wav_scp_file \
--device "cuda" \
--output_dir $output_dir \
--batch_size 32 \
--num_workers 4 \
--model "speech_tokenizer_v2_25hz" # or "speech_tokenizer_v1_25hz
cat $output_dir/* > $output_dir/${prefix}_${split}_cosy_v2_tokens.json
python3 local/attach_speech_tokens.py --jsonl-prefix ${prefix}_cuts_${split} --tokens-path $output_dir/${prefix}_${split}_cosy_v2_tokens.json --manifest-dir data/fbank
done
fi