mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Add F5-TTS with semantic token training results (#1880)
* add cosy token * update inference code * add extract cosy token * update results * add requirements.txt * update readme --------- Co-authored-by: yuekaiz <yuekaiz@h20-7.cm.cluster> Co-authored-by: yuekaiz <yuekaiz@mgmt1-login.cm.cluster>
This commit is contained in:
parent
da597ad782
commit
2ba665abca
@ -1,3 +1,10 @@
|
||||
# Results
|
||||
| Model | Seed-TTS test_zh CER | Comment |
|
||||
|---------------------------------------|---------------------|--------|
|
||||
| [vall-e](./valle) | 4.33% | ~150M |
|
||||
| [f5-tts](./f5-tts) | 3.02% (16 steps) / 2.42% (32 steps) | F5-TTS-Small Config, ~155M |
|
||||
| [f5-tts-semantic-token](./f5-tts) | 1.79% (16 steps) | Using pretrained cosyvoice2 semantic tokens as inputs rather than text tokens, ~155M |
|
||||
|
||||
# Introduction
|
||||
|
||||
[**WenetSpeech4TTS**](https://huggingface.co/datasets/Wenetspeech4TTS/WenetSpeech4TTS) is a multi-domain **Mandarin** corpus derived from the open-sourced [WenetSpeech](https://arxiv.org/abs/2110.03370) dataset.
|
||||
@ -131,6 +138,51 @@ accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-f
|
||||
bash local/compute_wer.sh $output_dir $manifest
|
||||
```
|
||||
|
||||
# F5-TTS-Semantic-Token
|
||||
|
||||
./f5-tts contains the code for training F5-TTS-Semantic-Token. We replaced the text tokens in F5-TTS with pretrained cosyvoice2 semantic tokens. During inference, we use the pretrained CosyVoice2 LLM to predict the semantic tokens for target audios. We observed that this approach leads to faster convergence and improved prosody modeling results.
|
||||
|
||||
Generated samples and training logs of wenetspeech basic 7k hours data can be found [here](https://huggingface.co/yuekai/f5-tts-semantic-token-small-wenetspeech4tts-basic/tree/main).
|
||||
|
||||
Preparation:
|
||||
|
||||
```
|
||||
# extract cosyvoice2 semantic tokens
|
||||
bash prepare.sh --stage 5 --stop_stage 7
|
||||
```
|
||||
|
||||
The training command is given below:
|
||||
|
||||
```
|
||||
# docker: ghcr.io/swivid/f5-tts:main
|
||||
# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html
|
||||
# pip install kaldialign lhotse tensorboard bigvganinference sentencepiece
|
||||
|
||||
world_size=8
|
||||
exp_dir=exp/f5-tts-semantic-token-small
|
||||
python3 f5-tts/train.py --max-duration 700 --filter-min-duration 0.5 --filter-max-duration 20 \
|
||||
--num-buckets 6 --dtype "bfloat16" --save-every-n 5000 --valid-interval 10000 \
|
||||
--base-lr 1e-4 --warmup-steps 20000 --average-period 0 \
|
||||
--num-epochs 10 --start-epoch 1 --start-batch 0 \
|
||||
--num-decoder-layers 18 --nhead 12 --decoder-dim 768 \
|
||||
--exp-dir ${exp_dir} --world-size ${world_size} \
|
||||
--decay-steps 600000 --prefix wenetspeech4tts_cosy_token --use-cosyvoice-semantic-token True
|
||||
```
|
||||
|
||||
To inference with Icefall Wenetspeech4TTS trained F5-Small-Semantic-Token, use:
|
||||
```
|
||||
huggingface-cli login
|
||||
huggingface-cli download --local-dir ${exp_dir} yuekai/f5-tts-semantic-token-small-wenetspeech4tts-basic
|
||||
huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x
|
||||
|
||||
split=test_zh
|
||||
model_path=f5-tts-small-wenetspeech4tts-basic/epoch-10-avg-5.pt
|
||||
|
||||
accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --split-name $split --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18 --use-cosyvoice-semantic-token True
|
||||
bash local/compute_wer.sh $output_dir $manifest
|
||||
```
|
||||
|
||||
# Credits
|
||||
- [VALL-E](https://github.com/lifeiteng/vall-e)
|
||||
- [F5-TTS](https://github.com/SWivid/F5-TTS)
|
||||
- [CosyVoice](https://github.com/FunAudioLLM/CosyVoice)
|
||||
|
@ -11,7 +11,14 @@ python3 f5-tts/generate_averaged_model.py \
|
||||
--epoch 56 \
|
||||
--avg 14 --decoder-dim 768 --nhead 12 --num-decoder-layers 18 \
|
||||
--exp-dir exp/f5_small
|
||||
|
||||
# command for text token input
|
||||
accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-file $manifest --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18
|
||||
|
||||
# command for cosyvoice semantic token input
|
||||
split=test_zh # seed_tts_eval test_zh
|
||||
accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --split-name $split --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18 --use-cosyvoice-semantic-token True
|
||||
|
||||
bash local/compute_wer.sh $output_dir $manifest
|
||||
"""
|
||||
import argparse
|
||||
@ -22,6 +29,7 @@ import random
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
@ -36,10 +44,12 @@ from train import (
|
||||
add_model_arguments,
|
||||
get_model,
|
||||
get_tokenizer,
|
||||
interpolate_tokens,
|
||||
load_F5_TTS_pretrained_checkpoint,
|
||||
)
|
||||
|
||||
from icefall.checkpoint import load_checkpoint
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
@ -78,7 +88,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--manifest-file",
|
||||
type=str,
|
||||
default="/path/seed_tts_eval/seedtts_testset/zh/meta.lst",
|
||||
default=None,
|
||||
help="The manifest file in seed_tts_eval format",
|
||||
)
|
||||
|
||||
@ -90,6 +100,29 @@ def get_parser():
|
||||
)
|
||||
|
||||
parser.add_argument("-ss", "--swaysampling", default=-1, type=float)
|
||||
|
||||
parser.add_argument(
|
||||
"--interpolate-token",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Interpolate semantic token to match mel frames for CosyVoice",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-cosyvoice-semantic-token",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether to use cosyvoice semantic token to replace text token.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--split-name",
|
||||
type=str,
|
||||
default="wenetspeech4tts",
|
||||
choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
|
||||
help="huggingface dataset split name",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
return parser.parse_args()
|
||||
|
||||
@ -243,6 +276,392 @@ def get_inference_prompt(
|
||||
return prompts_all
|
||||
|
||||
|
||||
def get_inference_prompt_cosy_voice_huggingface(
|
||||
dataset,
|
||||
speed=1.0,
|
||||
tokenizer="pinyin",
|
||||
polyphone=True,
|
||||
target_sample_rate=24000,
|
||||
n_fft=1024,
|
||||
win_length=1024,
|
||||
n_mel_channels=100,
|
||||
hop_length=256,
|
||||
mel_spec_type="bigvgan",
|
||||
target_rms=0.1,
|
||||
use_truth_duration=False,
|
||||
infer_batch_size=1,
|
||||
num_buckets=200,
|
||||
min_secs=3,
|
||||
max_secs=40,
|
||||
interpolate_token=False,
|
||||
):
|
||||
prompts_all = []
|
||||
|
||||
min_tokens = min_secs * target_sample_rate // hop_length
|
||||
max_tokens = max_secs * target_sample_rate // hop_length
|
||||
|
||||
batch_accum = [0] * num_buckets
|
||||
utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = (
|
||||
[[] for _ in range(num_buckets)] for _ in range(6)
|
||||
)
|
||||
|
||||
mel_spectrogram = MelSpec(
|
||||
n_fft=n_fft,
|
||||
hop_length=hop_length,
|
||||
win_length=win_length,
|
||||
n_mel_channels=n_mel_channels,
|
||||
target_sample_rate=target_sample_rate,
|
||||
mel_spec_type=mel_spec_type,
|
||||
)
|
||||
|
||||
for i in range(len(dataset)):
|
||||
utt = dataset[i]["id"]
|
||||
ref_audio_org, ref_sr = (
|
||||
dataset[i]["prompt_audio"]["array"],
|
||||
dataset[i]["prompt_audio"]["sampling_rate"],
|
||||
)
|
||||
ref_audio_org = torch.from_numpy(ref_audio_org).unsqueeze(0).float()
|
||||
audio_tokens = dataset[i]["target_audio_cosy2_tokens"]
|
||||
prompt_audio_tokens = dataset[i]["prompt_audio_cosy2_tokens"]
|
||||
|
||||
ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org)))
|
||||
if ref_rms < target_rms:
|
||||
ref_audio_org = ref_audio_org * target_rms / ref_rms
|
||||
|
||||
if ref_sr != target_sample_rate:
|
||||
resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
|
||||
ref_audio = resampler(ref_audio_org)
|
||||
else:
|
||||
ref_audio = ref_audio_org
|
||||
input_tokens = prompt_audio_tokens + audio_tokens
|
||||
|
||||
if interpolate_token:
|
||||
input_tokens = interpolate_tokens(input_tokens)
|
||||
text_list = input_tokens
|
||||
|
||||
# Duration, mel frame length
|
||||
ref_mel_len = ref_audio.shape[-1] // hop_length
|
||||
|
||||
total_mel_len = len(input_tokens)
|
||||
if not interpolate_token:
|
||||
total_mel_len = int(total_mel_len / 4 * 15)
|
||||
|
||||
# to mel spectrogram
|
||||
ref_mel = mel_spectrogram(ref_audio)
|
||||
ref_mel = ref_mel.squeeze(0)
|
||||
|
||||
# deal with batch
|
||||
assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
|
||||
if total_mel_len > max_tokens:
|
||||
print(
|
||||
f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
|
||||
)
|
||||
continue
|
||||
assert (
|
||||
min_tokens <= total_mel_len <= max_tokens
|
||||
), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
|
||||
bucket_i = math.floor(
|
||||
(total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets
|
||||
)
|
||||
|
||||
utts[bucket_i].append(utt)
|
||||
ref_rms_list[bucket_i].append(ref_rms)
|
||||
ref_mels[bucket_i].append(ref_mel)
|
||||
ref_mel_lens[bucket_i].append(ref_mel_len)
|
||||
total_mel_lens[bucket_i].append(total_mel_len)
|
||||
# final_text_list[bucket_i].extend(text_list)
|
||||
final_text_list[bucket_i].append(text_list)
|
||||
|
||||
batch_accum[bucket_i] += total_mel_len
|
||||
|
||||
if batch_accum[bucket_i] >= infer_batch_size:
|
||||
prompts_all.append(
|
||||
(
|
||||
utts[bucket_i],
|
||||
ref_rms_list[bucket_i],
|
||||
padded_mel_batch(ref_mels[bucket_i]),
|
||||
ref_mel_lens[bucket_i],
|
||||
total_mel_lens[bucket_i],
|
||||
final_text_list[bucket_i],
|
||||
)
|
||||
)
|
||||
batch_accum[bucket_i] = 0
|
||||
(
|
||||
utts[bucket_i],
|
||||
ref_rms_list[bucket_i],
|
||||
ref_mels[bucket_i],
|
||||
ref_mel_lens[bucket_i],
|
||||
total_mel_lens[bucket_i],
|
||||
final_text_list[bucket_i],
|
||||
) = (
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
)
|
||||
|
||||
# add residual
|
||||
for bucket_i, bucket_frames in enumerate(batch_accum):
|
||||
if bucket_frames > 0:
|
||||
prompts_all.append(
|
||||
(
|
||||
utts[bucket_i],
|
||||
ref_rms_list[bucket_i],
|
||||
padded_mel_batch(ref_mels[bucket_i]),
|
||||
ref_mel_lens[bucket_i],
|
||||
total_mel_lens[bucket_i],
|
||||
final_text_list[bucket_i],
|
||||
)
|
||||
)
|
||||
# not only leave easy work for last workers
|
||||
random.seed(666)
|
||||
random.shuffle(prompts_all)
|
||||
|
||||
return prompts_all
|
||||
|
||||
|
||||
def inference_speech_token(
|
||||
cosyvoice,
|
||||
tts_text,
|
||||
prompt_text,
|
||||
prompt_speech_16k,
|
||||
stream=False,
|
||||
speed=1.0,
|
||||
text_frontend=True,
|
||||
):
|
||||
tokens = []
|
||||
prompt_text = cosyvoice.frontend.text_normalize(
|
||||
prompt_text, split=False, text_frontend=text_frontend
|
||||
)
|
||||
for i in cosyvoice.frontend.text_normalize(
|
||||
tts_text, split=True, text_frontend=text_frontend
|
||||
):
|
||||
|
||||
tts_text_token, tts_text_token_len = cosyvoice.frontend._extract_text_token(i)
|
||||
(
|
||||
prompt_text_token,
|
||||
prompt_text_token_len,
|
||||
) = cosyvoice.frontend._extract_text_token(prompt_text)
|
||||
speech_token, speech_token_len = cosyvoice.frontend._extract_speech_token(
|
||||
prompt_speech_16k
|
||||
)
|
||||
|
||||
for i in cosyvoice.model.llm.inference(
|
||||
text=tts_text_token.to(cosyvoice.model.device),
|
||||
text_len=torch.tensor([tts_text_token.shape[1]], dtype=torch.int32).to(
|
||||
cosyvoice.model.device
|
||||
),
|
||||
prompt_text=prompt_text_token.to(cosyvoice.model.device),
|
||||
prompt_text_len=torch.tensor(
|
||||
[prompt_text_token.shape[1]], dtype=torch.int32
|
||||
).to(cosyvoice.model.device),
|
||||
prompt_speech_token=speech_token.to(cosyvoice.model.device),
|
||||
prompt_speech_token_len=torch.tensor(
|
||||
[speech_token.shape[1]], dtype=torch.int32
|
||||
).to(cosyvoice.model.device),
|
||||
embedding=None,
|
||||
):
|
||||
tokens.append(i)
|
||||
return tokens, speech_token
|
||||
|
||||
|
||||
def get_inference_prompt_cosy_voice(
|
||||
metainfo,
|
||||
speed=1.0,
|
||||
tokenizer="pinyin",
|
||||
polyphone=True,
|
||||
target_sample_rate=24000,
|
||||
n_fft=1024,
|
||||
win_length=1024,
|
||||
n_mel_channels=100,
|
||||
hop_length=256,
|
||||
mel_spec_type="bigvgan",
|
||||
target_rms=0.1,
|
||||
use_truth_duration=False,
|
||||
infer_batch_size=1,
|
||||
num_buckets=200,
|
||||
min_secs=3,
|
||||
max_secs=40,
|
||||
interpolate_token=False,
|
||||
):
|
||||
|
||||
import sys
|
||||
|
||||
# please change the path to the cosyvoice accordingly
|
||||
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
|
||||
sys.path.append("/workspace/CosyVoice")
|
||||
from cosyvoice.cli.cosyvoice import CosyVoice2
|
||||
|
||||
# please download the cosyvoice model first
|
||||
cosyvoice = CosyVoice2(
|
||||
"/workspace/CosyVoice2-0.5B", load_jit=False, load_trt=False, fp16=False
|
||||
)
|
||||
|
||||
prompts_all = []
|
||||
|
||||
min_tokens = min_secs * target_sample_rate // hop_length
|
||||
max_tokens = max_secs * target_sample_rate // hop_length
|
||||
|
||||
batch_accum = [0] * num_buckets
|
||||
utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = (
|
||||
[[] for _ in range(num_buckets)] for _ in range(6)
|
||||
)
|
||||
|
||||
mel_spectrogram = MelSpec(
|
||||
n_fft=n_fft,
|
||||
hop_length=hop_length,
|
||||
win_length=win_length,
|
||||
n_mel_channels=n_mel_channels,
|
||||
target_sample_rate=target_sample_rate,
|
||||
mel_spec_type=mel_spec_type,
|
||||
)
|
||||
|
||||
for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(
|
||||
metainfo, desc="Processing prompts..."
|
||||
):
|
||||
# Audio
|
||||
ref_audio_org, ref_sr = torchaudio.load(prompt_wav)
|
||||
|
||||
# cosy voice
|
||||
if ref_sr != 16000:
|
||||
resampler = torchaudio.transforms.Resample(ref_sr, 16000)
|
||||
ref_audio_16k = resampler(ref_audio_org)
|
||||
else:
|
||||
ref_audio_16k = ref_audio_org
|
||||
audio_tokens, prompt_audio_tokens = inference_speech_token(
|
||||
cosyvoice, gt_text, prompt_text, ref_audio_16k, stream=False
|
||||
)
|
||||
|
||||
ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org)))
|
||||
if ref_rms < target_rms:
|
||||
ref_audio_org = ref_audio_org * target_rms / ref_rms
|
||||
assert (
|
||||
ref_audio_org.shape[-1] > 5000
|
||||
), f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue."
|
||||
if ref_sr != target_sample_rate:
|
||||
resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
|
||||
ref_audio = resampler(ref_audio_org)
|
||||
else:
|
||||
ref_audio = ref_audio_org
|
||||
|
||||
# Text
|
||||
# if len(prompt_text[-1].encode("utf-8")) == 1:
|
||||
# prompt_text = prompt_text + " "
|
||||
# text = [prompt_text + gt_text]
|
||||
# if tokenizer == "pinyin":
|
||||
# text_list = convert_char_to_pinyin(text, polyphone=polyphone)
|
||||
# else:
|
||||
# text_list = text
|
||||
|
||||
# concat two tensors: prompt audio tokens with audio tokens --> shape 1, prompt_audio_tokens + audio_tokens
|
||||
# prompt_audio_tokens shape 1, prompt_audio_tokens
|
||||
# audio_tokens shape 1, audio_tokens
|
||||
prompt_audio_tokens = prompt_audio_tokens.squeeze().cpu().tolist()
|
||||
input_tokens = prompt_audio_tokens + audio_tokens
|
||||
|
||||
# convert it into a list
|
||||
# input_tokens_list = input_tokens.squeeze().cpu().tolist()
|
||||
if interpolate_token:
|
||||
input_tokens = interpolate_tokens(input_tokens)
|
||||
text_list = input_tokens
|
||||
|
||||
# Duration, mel frame length
|
||||
ref_mel_len = ref_audio.shape[-1] // hop_length
|
||||
if use_truth_duration:
|
||||
gt_audio, gt_sr = torchaudio.load(gt_wav)
|
||||
if gt_sr != target_sample_rate:
|
||||
resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate)
|
||||
gt_audio = resampler(gt_audio)
|
||||
total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed)
|
||||
|
||||
# # test vocoder resynthesis
|
||||
# ref_audio = gt_audio
|
||||
else:
|
||||
ref_text_len = len(prompt_text.encode("utf-8"))
|
||||
gen_text_len = len(gt_text.encode("utf-8"))
|
||||
total_mel_len_compute = ref_mel_len + int(
|
||||
ref_mel_len / ref_text_len * gen_text_len / speed
|
||||
)
|
||||
total_mel_len = len(input_tokens)
|
||||
if not interpolate_token:
|
||||
total_mel_len = int(total_mel_len / 4 * 15)
|
||||
print(
|
||||
f"total_mel_len_compute: {total_mel_len_compute}, total_mel_len: {total_mel_len}"
|
||||
)
|
||||
|
||||
# to mel spectrogram
|
||||
ref_mel = mel_spectrogram(ref_audio)
|
||||
ref_mel = ref_mel.squeeze(0)
|
||||
|
||||
# deal with batch
|
||||
assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
|
||||
assert (
|
||||
min_tokens <= total_mel_len <= max_tokens
|
||||
), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
|
||||
bucket_i = math.floor(
|
||||
(total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets
|
||||
)
|
||||
|
||||
utts[bucket_i].append(utt)
|
||||
ref_rms_list[bucket_i].append(ref_rms)
|
||||
ref_mels[bucket_i].append(ref_mel)
|
||||
ref_mel_lens[bucket_i].append(ref_mel_len)
|
||||
total_mel_lens[bucket_i].append(total_mel_len)
|
||||
# final_text_list[bucket_i].extend(text_list)
|
||||
final_text_list[bucket_i].append(text_list)
|
||||
|
||||
batch_accum[bucket_i] += total_mel_len
|
||||
|
||||
if batch_accum[bucket_i] >= infer_batch_size:
|
||||
prompts_all.append(
|
||||
(
|
||||
utts[bucket_i],
|
||||
ref_rms_list[bucket_i],
|
||||
padded_mel_batch(ref_mels[bucket_i]),
|
||||
ref_mel_lens[bucket_i],
|
||||
total_mel_lens[bucket_i],
|
||||
final_text_list[bucket_i],
|
||||
)
|
||||
)
|
||||
batch_accum[bucket_i] = 0
|
||||
(
|
||||
utts[bucket_i],
|
||||
ref_rms_list[bucket_i],
|
||||
ref_mels[bucket_i],
|
||||
ref_mel_lens[bucket_i],
|
||||
total_mel_lens[bucket_i],
|
||||
final_text_list[bucket_i],
|
||||
) = (
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
)
|
||||
|
||||
# add residual
|
||||
for bucket_i, bucket_frames in enumerate(batch_accum):
|
||||
if bucket_frames > 0:
|
||||
prompts_all.append(
|
||||
(
|
||||
utts[bucket_i],
|
||||
ref_rms_list[bucket_i],
|
||||
padded_mel_batch(ref_mels[bucket_i]),
|
||||
ref_mel_lens[bucket_i],
|
||||
total_mel_lens[bucket_i],
|
||||
final_text_list[bucket_i],
|
||||
)
|
||||
)
|
||||
# not only leave easy work for last workers
|
||||
random.seed(666)
|
||||
random.shuffle(prompts_all)
|
||||
|
||||
return prompts_all
|
||||
|
||||
|
||||
def padded_mel_batch(ref_mels):
|
||||
max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
|
||||
padded_ref_mels = []
|
||||
@ -275,8 +694,9 @@ def main():
|
||||
|
||||
accelerator = Accelerator()
|
||||
device = f"cuda:{accelerator.process_index}"
|
||||
|
||||
if args.manifest_file:
|
||||
metainfo = get_seedtts_testset_metainfo(args.manifest_file)
|
||||
if not args.use_cosyvoice_semantic_token:
|
||||
prompts_all = get_inference_prompt(
|
||||
metainfo,
|
||||
speed=1.0,
|
||||
@ -289,6 +709,40 @@ def main():
|
||||
use_truth_duration=False,
|
||||
infer_batch_size=1,
|
||||
)
|
||||
else:
|
||||
prompts_all = get_inference_prompt_cosy_voice(
|
||||
metainfo,
|
||||
speed=1.0,
|
||||
tokenizer="pinyin",
|
||||
target_sample_rate=24_000,
|
||||
n_mel_channels=100,
|
||||
hop_length=256,
|
||||
mel_spec_type="bigvgan",
|
||||
target_rms=0.1,
|
||||
use_truth_duration=False,
|
||||
infer_batch_size=1,
|
||||
interpolate_token=args.interpolate_token,
|
||||
)
|
||||
else:
|
||||
assert args.use_cosyvoice_semantic_token
|
||||
dataset = datasets.load_dataset(
|
||||
"yuekai/seed_tts_cosy2",
|
||||
split=args.split_name,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
prompts_all = get_inference_prompt_cosy_voice_huggingface(
|
||||
dataset,
|
||||
speed=1.0,
|
||||
tokenizer="pinyin",
|
||||
target_sample_rate=24_000,
|
||||
n_mel_channels=100,
|
||||
hop_length=256,
|
||||
mel_spec_type="bigvgan",
|
||||
target_rms=0.1,
|
||||
use_truth_duration=False,
|
||||
infer_batch_size=1,
|
||||
interpolate_token=args.interpolate_token,
|
||||
)
|
||||
|
||||
vocoder = BigVGANInference.from_pretrained(
|
||||
"./bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False
|
||||
@ -324,6 +778,16 @@ def main():
|
||||
ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
|
||||
total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)
|
||||
|
||||
if args.use_cosyvoice_semantic_token:
|
||||
# concat final_text_list
|
||||
max_len = max([len(tokens) for tokens in final_text_list])
|
||||
# pad tokens to the same length
|
||||
for i, tokens in enumerate(final_text_list):
|
||||
final_text_list[i] = torch.tensor(
|
||||
tokens + [-1] * (max_len - len(tokens)), dtype=torch.long
|
||||
)
|
||||
final_text_list = torch.stack(final_text_list).to(device)
|
||||
|
||||
# Inference
|
||||
with torch.inference_mode():
|
||||
generated, _ = model.sample(
|
||||
|
36
egs/wenetspeech4tts/TTS/f5-tts/requirements.txt
Normal file
36
egs/wenetspeech4tts/TTS/f5-tts/requirements.txt
Normal file
@ -0,0 +1,36 @@
|
||||
# F5-TTS
|
||||
accelerate>=0.33.0
|
||||
bitsandbytes>0.37.0
|
||||
cached_path
|
||||
click
|
||||
datasets
|
||||
ema_pytorch>=0.5.2
|
||||
gradio>=3.45.2
|
||||
hydra-core>=1.3.0
|
||||
jieba
|
||||
librosa
|
||||
matplotlib
|
||||
numpy<=1.26.4
|
||||
pydub
|
||||
pypinyin
|
||||
safetensors
|
||||
soundfile
|
||||
tomli
|
||||
torch>=2.0.0
|
||||
torchaudio>=2.0.0
|
||||
torchdiffeq
|
||||
tqdm>=4.65.0
|
||||
transformers
|
||||
x_transformers>=1.31.14
|
||||
|
||||
# icefall
|
||||
kaldialign
|
||||
lhotse
|
||||
tensorboard
|
||||
bigvganinference
|
||||
sentencepiece
|
||||
sherpa-onnx
|
||||
k2
|
||||
|
||||
# semantic experiment
|
||||
s3tokenizer
|
@ -82,9 +82,12 @@ class SpeechSynthesisDataset(torch.utils.data.Dataset):
|
||||
text = [cut.supervisions[0].text for cut in cuts]
|
||||
batch["text"] = text
|
||||
|
||||
if self.return_tokens:
|
||||
if self.return_tokens and "speech_tokens" in cuts[0].supervisions[0].custom:
|
||||
# tokens = [cut.tokens for cut in cuts]
|
||||
tokens = [cut.supervisions[0].custom["tokens"]["text"] for cut in cuts]
|
||||
# tokens = [cut.supervisions[0].custom["tokens"]["text"] for cut in cuts]
|
||||
tokens = [cut.supervisions[0].custom["speech_tokens"] for cut in cuts]
|
||||
# change str into list
|
||||
tokens = [list(map(int, token.split())) for token in tokens]
|
||||
batch["tokens"] = tokens
|
||||
|
||||
if self.return_spk_ids:
|
||||
|
@ -31,6 +31,16 @@ python3 f5-tts/train.py --max-duration 700 --filter-min-duration 0.5 --filter-ma
|
||||
--base-lr 7.5e-5 --warmup-steps 20000 --num-epochs 60 \
|
||||
--num-decoder-layers 18 --nhead 12 --decoder-dim 768 \
|
||||
--exp-dir ${exp_dir} --world-size ${world_size}
|
||||
|
||||
# command for training with cosyvoice semantic token
|
||||
exp_dir=exp/f5-tts-cosyvoice
|
||||
python3 f5-tts/train.py --max-duration 700 --filter-min-duration 0.5 --filter-max-duration 20 \
|
||||
--num-buckets 6 --dtype "bfloat16" --save-every-n 5000 --valid-interval 10000 \
|
||||
--base-lr 1e-4 --warmup-steps 20000 --average-period 0 \
|
||||
--num-epochs 10 --start-epoch 1 --start-batch 0 \
|
||||
--num-decoder-layers 18 --nhead 12 --decoder-dim 768 \
|
||||
--exp-dir ${exp_dir} --world-size ${world_size} \
|
||||
--decay-steps 600000 --prefix wenetspeech4tts_cosy_token --use-cosyvoice-semantic-token True
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@ -303,6 +313,13 @@ def get_parser():
|
||||
help="perform OOM check on dataloader batches before starting training.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-cosyvoice-semantic-token",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether to use cosyvoice semantic token to replace text token.",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
@ -378,6 +395,10 @@ def get_tokenizer(vocab_file_path: str):
|
||||
|
||||
|
||||
def get_model(params):
|
||||
if params.use_cosyvoice_semantic_token:
|
||||
# https://www.modelscope.cn/models/iic/CosyVoice2-0.5B/file/view/master?fileName=cosyvoice.yaml&status=1#L36
|
||||
vocab_char_map, vocab_size = None, 6561
|
||||
else:
|
||||
vocab_char_map, vocab_size = get_tokenizer(params.tokens)
|
||||
# bigvgan 100 dim features
|
||||
n_mel_channels = 100
|
||||
@ -556,14 +577,44 @@ def save_checkpoint(
|
||||
copyfile(src=filename, dst=best_valid_filename)
|
||||
|
||||
|
||||
def prepare_input(batch: dict, device: torch.device):
|
||||
"""Parse batch data"""
|
||||
text_inputs = batch["text"]
|
||||
# texts.extend(convert_char_to_pinyin([text], polyphone=true))
|
||||
text_inputs = convert_char_to_pinyin(text_inputs, polyphone=True)
|
||||
def interpolate_tokens(cosy_tokens, pad_token=-1):
|
||||
"""Interpolate cosyvoice tokens to match bigvgan frames length"""
|
||||
# cosyvoice, 25 tokens/sec
|
||||
# bigvgan sample_rate/hop_length 24000/256 frames/sec
|
||||
# For every 4 cosyvoice tokens, insert pad tokens to extend it to 15 tokens to match bigvgan frames length
|
||||
# We choose 4,4,4,3 to match 15 frames
|
||||
three, two = [pad_token] * 3, [pad_token] * 2
|
||||
return [
|
||||
x
|
||||
for i, e in enumerate(cosy_tokens)
|
||||
for x in ([e] + three if i % 4 < 3 else [e] + two)
|
||||
]
|
||||
|
||||
|
||||
def prepare_input(
|
||||
batch: dict, device: torch.device, use_cosyvoice_semantic_token: bool
|
||||
):
|
||||
"""Parse batch data"""
|
||||
mel_spec = batch["features"]
|
||||
mel_lengths = batch["features_lens"]
|
||||
|
||||
if use_cosyvoice_semantic_token:
|
||||
semantic_tokens = []
|
||||
for i in range(len(batch["tokens"])):
|
||||
tokens = batch["tokens"][i]
|
||||
tokens = interpolate_tokens(tokens)
|
||||
semantic_tokens.append(tokens)
|
||||
# pad to the same length, B,T, with pad value -1
|
||||
max_len = max([len(tokens) for tokens in semantic_tokens])
|
||||
text_inputs = torch.full(
|
||||
(len(semantic_tokens), max_len), -1, dtype=torch.long
|
||||
).to(device)
|
||||
for i, tokens in enumerate(semantic_tokens):
|
||||
text_inputs[i, : len(tokens)] = torch.tensor(tokens, dtype=torch.long)
|
||||
else:
|
||||
text_inputs = batch["text"]
|
||||
text_inputs = convert_char_to_pinyin(text_inputs, polyphone=True)
|
||||
|
||||
return text_inputs, mel_spec.to(device), mel_lengths.to(device)
|
||||
|
||||
|
||||
@ -593,7 +644,11 @@ def compute_loss(
|
||||
values >= 1.0 are fully warmed up and have all modules present.
|
||||
"""
|
||||
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
||||
(text_inputs, mel_spec, mel_lengths) = prepare_input(batch, device=device)
|
||||
(text_inputs, mel_spec, mel_lengths) = prepare_input(
|
||||
batch,
|
||||
device=device,
|
||||
use_cosyvoice_semantic_token=params.use_cosyvoice_semantic_token,
|
||||
)
|
||||
# at entry, TextTokens is (N, P)
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
|
@ -174,7 +174,7 @@ class TtsDataModule:
|
||||
logging.info("About to create train dataset")
|
||||
train = SpeechSynthesisDataset(
|
||||
return_text=True,
|
||||
return_tokens=False,
|
||||
return_tokens=True,
|
||||
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
@ -234,7 +234,7 @@ class TtsDataModule:
|
||||
else:
|
||||
validate = SpeechSynthesisDataset(
|
||||
return_text=True,
|
||||
return_tokens=False,
|
||||
return_tokens=True,
|
||||
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
@ -265,7 +265,7 @@ class TtsDataModule:
|
||||
else:
|
||||
test = SpeechSynthesisDataset(
|
||||
return_text=True,
|
||||
return_tokens=False,
|
||||
return_tokens=True,
|
||||
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
108
egs/wenetspeech4tts/TTS/local/attach_speech_tokens.py
Normal file
108
egs/wenetspeech4tts/TTS/local/attach_speech_tokens.py
Normal file
@ -0,0 +1,108 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2025 author: Yuekai Zhang
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import gzip
|
||||
import json
|
||||
import logging
|
||||
|
||||
import s3tokenizer
|
||||
from lhotse import CutSet, load_manifest_lazy
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--manifest-dir",
|
||||
type=str,
|
||||
default="data/fbank",
|
||||
help="Directory to store the manifest files",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--jsonl-prefix",
|
||||
type=str,
|
||||
default="wenetspeech4tts_cuts_valid",
|
||||
help="The training subset for wenetspeech.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens-path",
|
||||
type=str,
|
||||
default="./s3_tokens_valid/wenetspeech4tts_valid.json",
|
||||
help="json file containing the speech tokens",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def get_speech_tokens(tokens_path):
|
||||
id2tokens = {}
|
||||
with open(tokens_path, "r") as fin:
|
||||
for line in fin:
|
||||
line = json.loads(line)
|
||||
id2tokens[line["key"]] = " ".join(map(str, line["code"]))
|
||||
return id2tokens
|
||||
|
||||
|
||||
def attach_manifest(manifest, fixed_manifest_path, id2tokens):
|
||||
with CutSet.open_writer(fixed_manifest_path) as manifest_writer:
|
||||
fixed_item = 0
|
||||
for i, cut in enumerate(tqdm(manifest)):
|
||||
cut_id = cut.supervisions[0].id
|
||||
if cut_id in id2tokens:
|
||||
code = id2tokens[cut_id]
|
||||
cut.supervisions[0].custom = {
|
||||
**cut.supervisions[0].custom,
|
||||
**{"speech_tokens": code},
|
||||
}
|
||||
else:
|
||||
print(f"cut_id {cut_id} not in id2tokens")
|
||||
fixed_item += 1
|
||||
manifest_writer.write(cut)
|
||||
logging.info(f"Fixed {fixed_item} items in the manifest")
|
||||
|
||||
|
||||
def main():
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
manifest_path = args.manifest_dir + "/" + f"{args.jsonl_prefix}.jsonl.gz"
|
||||
attached_manifest_path = (
|
||||
args.manifest_dir + "/" + f"{args.jsonl_prefix}_attached_cosyvoice_v2.jsonl.gz"
|
||||
)
|
||||
logging.info(f"Loading manifest from {manifest_path}")
|
||||
cuts_manifest = load_manifest_lazy(manifest_path)
|
||||
logging.info(f"Loading manifest from {manifest_path} done")
|
||||
id2tokens = get_speech_tokens(args.tokens_path)
|
||||
logging.info(f"Loaded id2tokens with {len(id2tokens)} entries")
|
||||
|
||||
attach_manifest(cuts_manifest, attached_manifest_path, id2tokens)
|
||||
logging.info(
|
||||
f"Manifest with speech tokens attached is saved to {attached_manifest_path}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -111,7 +111,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
fi
|
||||
|
||||
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
log "Stage 7: Split the ${prefix} cuts into train, valid and test sets (used by ./f5-tts)"
|
||||
log "Stage 6: Split the ${prefix} cuts into train, valid and test sets (used by ./f5-tts)"
|
||||
if [ ! -f data/fbank/${prefix}_cuts_${subset}.jsonl.gz ]; then
|
||||
echo "Combining ${prefix} cuts"
|
||||
pieces=$(find data/fbank/ -name "${prefix}_cuts_${subset}.*.jsonl.gz")
|
||||
@ -139,3 +139,27 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
touch data/fbank/.${prefix}_split.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
|
||||
log "Stage 7: Extract cosyvoice2 FSQ token (used by ./f5-tts semantic token experiment)"
|
||||
split_name=("valid" "test" "train")
|
||||
for split in "${split_name[@]}"; do
|
||||
echo "Processing $split"
|
||||
wav_scp_file=wav_${split}.scp
|
||||
output_dir="./cosy_v2_tokens_${split}"
|
||||
oringinal_jsonl_file=data/fbank/${prefix}_cuts_${split}.jsonl.gz
|
||||
mkdir -p $output_dir
|
||||
zcat $oringinal_jsonl_file | jq -r '.recording.id + " " + .recording.sources[0].source' > $wav_scp_file
|
||||
torchrun --nproc_per_node=8 --nnodes=1 \
|
||||
--rdzv_id=2024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
|
||||
`which s3tokenizer` --wav_scp $wav_scp_file \
|
||||
--device "cuda" \
|
||||
--output_dir $output_dir \
|
||||
--batch_size 32 \
|
||||
--num_workers 4 \
|
||||
--model "speech_tokenizer_v2_25hz" # or "speech_tokenizer_v1_25hz
|
||||
|
||||
cat $output_dir/* > $output_dir/${prefix}_${split}_cosy_v2_tokens.json
|
||||
python3 local/attach_speech_tokens.py --jsonl-prefix ${prefix}_cuts_${split} --tokens-path $output_dir/${prefix}_${split}_cosy_v2_tokens.json --manifest-dir data/fbank
|
||||
done
|
||||
fi
|
||||
|
Loading…
x
Reference in New Issue
Block a user