mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00

commit 2edaf685e1 (parent a54a0469a2): update results
README.md
@@ -1,3 +1,10 @@
+# Results
+| Model | Seed-TTS test_zh CER | Comment |
+|---------------------------------------|---------------------|--------|
+| [vall-e](./valle) | 4.33% | ~150M |
+| [f5-tts](./f5-tts) | 3.02% (16 steps) / 2.42% (32 steps) | F5-TTS-Small Config, ~155M |
+| [f5-tts-semantic-token](./f5-tts) | 1.79% (16 steps) | Using pretrained cosyvoice2 semantic tokens as inputs rather than text tokens, ~155M |
+
 # Introduction
 
 [**WenetSpeech4TTS**](https://huggingface.co/datasets/Wenetspeech4TTS/WenetSpeech4TTS) is a multi-domain **Mandarin** corpus derived from the open-sourced [WenetSpeech](https://arxiv.org/abs/2110.03370) dataset.
@@ -131,6 +138,53 @@ accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-f
 bash local/compute_wer.sh $output_dir $manifest
 ```
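For reference, `local/compute_wer.sh` scores ASR transcripts of the synthesized audio against the reference text to produce the CERs in the table above. A minimal sketch of a character error rate computation with `kaldialign` (one of the recipe's pip dependencies); the actual script's text normalization may differ:

```
# Minimal CER sketch using kaldialign; compute_wer.sh may normalize text differently.
from kaldialign import edit_distance


def cer(ref: str, hyp: str) -> float:
    ref_chars = list(ref.replace(" ", ""))  # Mandarin CER is computed per character
    hyp_chars = list(hyp.replace(" ", ""))
    stats = edit_distance(ref_chars, hyp_chars)  # {'ins': ..., 'del': ..., 'sub': ..., 'total': ...}
    return stats["total"] / max(len(ref_chars), 1)


print(f"{cer('今天天气很好', '今天天汽很好'):.2%}")  # one substitution over six chars -> 16.67%
```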
 
+# F5-TTS-Semantic-Token
+
+./f5-tts contains the code for training F5-TTS-Semantic-Token. We replaced the text tokens in F5-TTS with pretrained cosyvoice2 semantic tokens.
+
+We observed faster convergence and better prosody modeling by doing this.
+
+Generated samples and training logs for the 7k-hour WenetSpeech4TTS Basic subset can be found [here](https://huggingface.co/yuekai/f5-tts-semantic-token-small-wenetspeech4tts-basic/tree/main).
+
+Preparation:
+
+```
+# extract cosyvoice2 semantic tokens
+bash prepare.sh --stage 5 --stop_stage 7
+```
+
+The training command is given below:
+
+```
+# docker: ghcr.io/swivid/f5-tts:main
+# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html
+# pip install kaldialign lhotse tensorboard bigvganinference sentencepiece
+
+world_size=8
+exp_dir=exp/f5-tts-semantic-token-small
+python3 f5-tts/train.py --max-duration 700 --filter-min-duration 0.5 --filter-max-duration 20 \
+    --num-buckets 6 --dtype "bfloat16" --save-every-n 5000 --valid-interval 10000 \
+    --base-lr 1e-4 --warmup-steps 20000 --average-period 0 \
+    --num-epochs 10 --start-epoch 1 --start-batch 0 \
+    --num-decoder-layers 18 --nhead 12 --decoder-dim 768 \
+    --exp-dir ${exp_dir} --world-size ${world_size} \
+    --decay-steps 600000 --prefix wenetspeech4tts_cosy_token --use-cosyvoice-semantic-token True
+```
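For context, `--max-duration 700` and `--num-buckets 6` configure lhotse-style dynamic bucketing: each batch is capped at roughly 700 seconds of audio and drawn from one of six duration buckets to minimize padding. A hedged sketch of the equivalent sampler setup; the manifest path is illustrative, not the recipe's actual one:

```
# Sketch of what --max-duration / --num-buckets control via lhotse.
from lhotse import CutSet
from lhotse.dataset import DynamicBucketingSampler

cuts = CutSet.from_file("data/fbank/cuts_train.jsonl.gz")  # hypothetical path
sampler = DynamicBucketingSampler(
    cuts,
    max_duration=700,  # total seconds of audio per batch
    num_buckets=6,     # batch together cuts of similar duration to reduce padding
    shuffle=True,
)
```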
+
+To run inference with the icefall WenetSpeech4TTS-trained F5-Small-Semantic-Token model, use:
+```
+huggingface-cli login
+huggingface-cli download --local-dir ${exp_dir} yuekai/f5-tts-semantic-token-small-wenetspeech4tts-basic
+huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x
+
+split=test_zh
+model_path=f5-tts-small-wenetspeech4tts-basic/epoch-10-avg-5.pt
+
+accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --split-name $split --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18 --use-cosyvoice-semantic-token True
+bash local/compute_wer.sh $output_dir $manifest
+```
 # Credits
 - [VALL-E](https://github.com/lifeiteng/vall-e)
 - [F5-TTS](https://github.com/SWivid/F5-TTS)
+- [CosyVoice](https://github.com/FunAudioLLM/CosyVoice)
f5-tts/infer.py
@@ -11,7 +11,14 @@ python3 f5-tts/generate_averaged_model.py \
     --epoch 56 \
     --avg 14 --decoder-dim 768 --nhead 12 --num-decoder-layers 18 \
     --exp-dir exp/f5_small
 
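Here `--epoch 56 --avg 14` requests the average of the model weights from epochs 43 through 56. A minimal sketch of that averaging, assuming each checkpoint stores its weights under a "model" key as icefall checkpoints do (the real generate_averaged_model.py has more bookkeeping):

```
# Sketch of the checkpoint averaging behind generate_averaged_model.py.
import torch


def average_checkpoints(paths):
    avg = None
    for p in paths:
        state = torch.load(p, map_location="cpu")["model"]
        if avg is None:
            avg = {k: v.float().clone() for k, v in state.items()}
        else:
            for k in avg:
                avg[k] += state[k].float()
    return {k: v / len(paths) for k, v in avg.items()}


# e.g. average_checkpoints([f"exp/f5_small/epoch-{i}.pt" for i in range(43, 57)])
```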
+# command for text token input
 accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-file $manifest --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18
+
+# command for cosyvoice semantic token input
+split=test_zh # seed_tts_eval test_zh
+accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --split-name $split --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18 --use-cosyvoice-semantic-token True
+
 bash local/compute_wer.sh $output_dir $manifest
 """
 import argparse
@@ -37,11 +44,12 @@ from train import (
     add_model_arguments,
     get_model,
     get_tokenizer,
-    insert_zeros_optimized,
+    interpolate_tokens,
     load_F5_TTS_pretrained_checkpoint,
 )
 
 from icefall.checkpoint import load_checkpoint
+from icefall.utils import str2bool
 
 
 def get_parser():
@@ -94,9 +102,17 @@ def get_parser():
     parser.add_argument("-ss", "--swaysampling", default=-1, type=float)
 
     parser.add_argument(
-        "--insert-zero",
-        action="store_true",
-        help="Insert zeros for CosyVoice",
+        "--interpolate-token",
+        type=str2bool,
+        default=True,
+        help="Interpolate semantic token to match mel frames for CosyVoice",
+    )
+
+    parser.add_argument(
+        "--use-cosyvoice-semantic-token",
+        type=str2bool,
+        default=False,
+        help="Whether to use cosyvoice semantic token to replace text token.",
     )
 
     parser.add_argument(
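The switch from `action="store_true"` to `type=str2bool` lets the flag take an explicit value on the command line, e.g. `--interpolate-token False`. A small self-contained sketch of the pattern; `icefall.utils.str2bool` is equivalent in spirit:

```
# How a str2bool-typed flag behaves (sketch; icefall.utils.str2bool is similar).
import argparse


def str2bool(v):
    if v.lower() in ("true", "t", "yes", "y", "1"):
        return True
    if v.lower() in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"not a boolean: {v!r}")


p = argparse.ArgumentParser()
p.add_argument("--interpolate-token", type=str2bool, default=True)
print(p.parse_args(["--interpolate-token", "False"]).interpolate_token)  # False
```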
@@ -277,7 +293,7 @@ def get_inference_prompt_cosy_voice_huggingface(
     num_buckets=200,
     min_secs=3,
     max_secs=40,
-    insert_zero=False,
+    interpolate_token=False,
 ):
     prompts_all = []
 
@@ -319,15 +335,15 @@ def get_inference_prompt_cosy_voice_huggingface(
             ref_audio = ref_audio_org
         input_tokens = prompt_audio_tokens + audio_tokens
 
-        if insert_zero:
-            input_tokens = insert_zeros_optimized(input_tokens)
+        if interpolate_token:
+            input_tokens = interpolate_tokens(input_tokens)
         text_list = input_tokens
 
         # Duration, mel frame length
         ref_mel_len = ref_audio.shape[-1] // hop_length
 
         total_mel_len = len(input_tokens)
-        if not insert_zero:
+        if not interpolate_token:
             total_mel_len = int(total_mel_len / 4 * 15)
 
         # to mel spectrogram
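The `int(total_mel_len / 4 * 15)` fallback comes straight from the two frame rates involved: cosyvoice2 emits 25 semantic tokens per second, while bigvgan mel frames run at 24000 / 256 = 93.75 per second, i.e. 3.75 frames per token, or exactly 15 frames per 4 tokens:

```
# Frame-rate arithmetic behind total_mel_len = int(total_mel_len / 4 * 15).
token_rate = 25            # cosyvoice2 semantic tokens per second
frame_rate = 24000 / 256   # bigvgan mel frames per second = 93.75
print(frame_rate / token_rate)  # 3.75 mel frames per token
print(15 / 4)                   # same ratio: 15 mel frames per 4 tokens
```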
@@ -406,6 +422,51 @@ def get_inference_prompt_cosy_voice_huggingface(
     return prompts_all
 
 
+def inference_speech_token(
+    cosyvoice,
+    tts_text,
+    prompt_text,
+    prompt_speech_16k,
+    stream=False,
+    speed=1.0,
+    text_frontend=True,
+):
+    tokens = []
+    prompt_text = cosyvoice.frontend.text_normalize(
+        prompt_text, split=False, text_frontend=text_frontend
+    )
+    for i in cosyvoice.frontend.text_normalize(
+        tts_text, split=True, text_frontend=text_frontend
+    ):
+
+        tts_text_token, tts_text_token_len = cosyvoice.frontend._extract_text_token(i)
+        (
+            prompt_text_token,
+            prompt_text_token_len,
+        ) = cosyvoice.frontend._extract_text_token(prompt_text)
+        speech_token, speech_token_len = cosyvoice.frontend._extract_speech_token(
+            prompt_speech_16k
+        )
+
+        for i in cosyvoice.model.llm.inference(
+            text=tts_text_token.to(cosyvoice.model.device),
+            text_len=torch.tensor([tts_text_token.shape[1]], dtype=torch.int32).to(
+                cosyvoice.model.device
+            ),
+            prompt_text=prompt_text_token.to(cosyvoice.model.device),
+            prompt_text_len=torch.tensor(
+                [prompt_text_token.shape[1]], dtype=torch.int32
+            ).to(cosyvoice.model.device),
+            prompt_speech_token=speech_token.to(cosyvoice.model.device),
+            prompt_speech_token_len=torch.tensor(
+                [speech_token.shape[1]], dtype=torch.int32
+            ).to(cosyvoice.model.device),
+            embedding=None,
+        ):
+            tokens.append(i)
+    return tokens, speech_token
+
+
 def get_inference_prompt_cosy_voice(
     metainfo,
     speed=1.0,
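For orientation, a hypothetical call to the `inference_speech_token` helper added above; the CosyVoice paths mirror the hard-coded ones used later in this file, and `prompt.wav` plus the Chinese strings are placeholders:

```
# Hypothetical usage of inference_speech_token(); paths and inputs are placeholders.
import sys

import torchaudio

sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
sys.path.append("/workspace/CosyVoice")
from cosyvoice.cli.cosyvoice import CosyVoice2

cosyvoice = CosyVoice2(
    "/workspace/CosyVoice2-0.5B", load_jit=False, load_trt=False, fp16=False
)
wav, sr = torchaudio.load("prompt.wav")
prompt_16k = torchaudio.functional.resample(wav, sr, 16000)
tokens, prompt_tokens = inference_speech_token(
    cosyvoice, "要合成的文本。", "提示音频对应的文本。", prompt_16k, stream=False
)
print(len(tokens))  # roughly 25 semantic tokens per second of generated speech
```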
@@ -423,18 +484,21 @@ def get_inference_prompt_cosy_voice(
     num_buckets=200,
     min_secs=3,
     max_secs=40,
-    insert_zero=False,
+    interpolate_token=False,
 ):
 
     import sys
 
+    # please change the path to your CosyVoice checkout accordingly
     sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
     sys.path.append("/workspace/CosyVoice")
     from cosyvoice.cli.cosyvoice import CosyVoice2
 
+    # please download the cosyvoice model first
     cosyvoice = CosyVoice2(
         "/workspace/CosyVoice2-0.5B", load_jit=False, load_trt=False, fp16=False
     )
 
     prompts_all = []
 
     min_tokens = min_secs * target_sample_rate // hop_length
@@ -466,8 +530,8 @@ def get_inference_prompt_cosy_voice(
             ref_audio_16k = resampler(ref_audio_org)
         else:
             ref_audio_16k = ref_audio_org
-        audio_tokens, prompt_audio_tokens = cosyvoice.inference_speech_token(
-            gt_text, prompt_text, ref_audio_16k, stream=False
+        audio_tokens, prompt_audio_tokens = inference_speech_token(
+            cosyvoice, gt_text, prompt_text, ref_audio_16k, stream=False
         )
 
         ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org)))
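`ref_rms` measures the prompt's loudness. F5-TTS-style inference boosts prompts quieter than `target_rms=0.1` before synthesis and uses the stored RMS to scale the generated audio back afterwards; a sketch of that convention (the helper name is ours, the actual rescaling lives in this file):

```
# Sketch of the RMS loudness convention around target_rms=0.1.
import torch


def normalize_rms(audio, target_rms=0.1):
    rms = torch.sqrt(torch.mean(torch.square(audio)))
    if rms < target_rms:  # only boost prompts that are too quiet
        audio = audio * target_rms / rms
    return audio, rms  # keep rms to undo the scaling on the generated audio
```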
@@ -499,8 +563,8 @@ def get_inference_prompt_cosy_voice(
 
         # convert it into a list
         # input_tokens_list = input_tokens.squeeze().cpu().tolist()
-        if insert_zero:
-            input_tokens = insert_zeros_optimized(input_tokens)
+        if interpolate_token:
+            input_tokens = interpolate_tokens(input_tokens)
         text_list = input_tokens
 
         # Duration, mel frame length
@@ -521,7 +585,7 @@ def get_inference_prompt_cosy_voice(
             ref_mel_len / ref_text_len * gen_text_len / speed
         )
         total_mel_len = len(input_tokens)
-        if not insert_zero:
+        if not interpolate_token:
             total_mel_len = int(total_mel_len / 4 * 15)
         print(
             f"total_mel_len_compute: {total_mel_len_compute}, total_mel_len: {total_mel_len}"
@@ -632,33 +696,35 @@ def main():
     device = f"cuda:{accelerator.process_index}"
     if args.manifest_file:
         metainfo = get_seedtts_testset_metainfo(args.manifest_file)
-        # prompts_all = get_inference_prompt(
-        #     metainfo,
-        #     speed=1.0,
-        #     tokenizer="pinyin",
-        #     target_sample_rate=24_000,
-        #     n_mel_channels=100,
-        #     hop_length=256,
-        #     mel_spec_type="bigvgan",
-        #     target_rms=0.1,
-        #     use_truth_duration=False,
-        #     infer_batch_size=1,
-        # )
-        prompts_all = get_inference_prompt_cosy_voice(
-            metainfo,
-            speed=1.0,
-            tokenizer="pinyin",
-            target_sample_rate=24_000,
-            n_mel_channels=100,
-            hop_length=256,
-            mel_spec_type="bigvgan",
-            target_rms=0.1,
-            use_truth_duration=False,
-            infer_batch_size=1,
-            insert_zero=args.insert_zero,
-        )
+        if not args.use_cosyvoice_semantic_token:
+            prompts_all = get_inference_prompt(
+                metainfo,
+                speed=1.0,
+                tokenizer="pinyin",
+                target_sample_rate=24_000,
+                n_mel_channels=100,
+                hop_length=256,
+                mel_spec_type="bigvgan",
+                target_rms=0.1,
+                use_truth_duration=False,
+                infer_batch_size=1,
+            )
+        else:
+            prompts_all = get_inference_prompt_cosy_voice(
+                metainfo,
+                speed=1.0,
+                tokenizer="pinyin",
+                target_sample_rate=24_000,
+                n_mel_channels=100,
+                hop_length=256,
+                mel_spec_type="bigvgan",
+                target_rms=0.1,
+                use_truth_duration=False,
+                infer_batch_size=1,
+                interpolate_token=args.interpolate_token,
+            )
     else:
+        assert args.use_cosyvoice_semantic_token
         dataset = datasets.load_dataset(
             "yuekai/seed_tts_cosy2",
             split=args.split_name,
@@ -675,7 +741,7 @@ def main():
             target_rms=0.1,
             use_truth_duration=False,
             infer_batch_size=1,
-            insert_zero=args.insert_zero,
+            interpolate_token=args.interpolate_token,
         )
 
     vocoder = BigVGANInference.from_pretrained(
@@ -712,14 +778,15 @@ def main():
     ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
     total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)
 
-    # concat final_text_list
-    max_len = max([len(tokens) for tokens in final_text_list])
-    # pad tokens to the same length
-    for i, tokens in enumerate(final_text_list):
-        final_text_list[i] = torch.tensor(
-            tokens + [-1] * (max_len - len(tokens)), dtype=torch.long
-        )
-    final_text_list = torch.stack(final_text_list).to(device)
+    if args.use_cosyvoice_semantic_token:
+        # concat final_text_list
+        max_len = max([len(tokens) for tokens in final_text_list])
+        # pad tokens to the same length
+        for i, tokens in enumerate(final_text_list):
+            final_text_list[i] = torch.tensor(
+                tokens + [-1] * (max_len - len(tokens)), dtype=torch.long
+            )
+        final_text_list = torch.stack(final_text_list).to(device)
 
     # Inference
     with torch.inference_mode():
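The `-1` padding turns a ragged batch of token lists into a dense `(B, T)` LongTensor; `-1` is the same pad value `interpolate_tokens` uses. A worked example:

```
# Worked example of the -1 padding above: ragged lists -> (B, T) tensor.
import torch

final_text_list = [[11, 12, 13], [21, 22]]
max_len = max(len(t) for t in final_text_list)
padded = torch.stack(
    [
        torch.tensor(t + [-1] * (max_len - len(t)), dtype=torch.long)
        for t in final_text_list
    ]
)
print(padded)  # tensor([[11, 12, 13], [21, 22, -1]])
```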
f5-tts/train.py
@@ -31,6 +31,16 @@ python3 f5-tts/train.py --max-duration 700 --filter-min-duration 0.5 --filter-ma
     --base-lr 7.5e-5 --warmup-steps 20000 --num-epochs 60 \
     --num-decoder-layers 18 --nhead 12 --decoder-dim 768 \
     --exp-dir ${exp_dir} --world-size ${world_size}
+
+# command for training with cosyvoice semantic token
+exp_dir=exp/f5-tts-cosyvoice
+python3 f5-tts/train.py --max-duration 700 --filter-min-duration 0.5 --filter-max-duration 20 \
+    --num-buckets 6 --dtype "bfloat16" --save-every-n 5000 --valid-interval 10000 \
+    --base-lr 1e-4 --warmup-steps 20000 --average-period 0 \
+    --num-epochs 10 --start-epoch 1 --start-batch 0 \
+    --num-decoder-layers 18 --nhead 12 --decoder-dim 768 \
+    --exp-dir ${exp_dir} --world-size ${world_size} \
+    --decay-steps 600000 --prefix wenetspeech4tts_cosy_token --use-cosyvoice-semantic-token True
 """
 
 import argparse
@@ -303,6 +313,13 @@ def get_parser():
         help="perform OOM check on dataloader batches before starting training.",
     )
 
+    parser.add_argument(
+        "--use-cosyvoice-semantic-token",
+        type=str2bool,
+        default=False,
+        help="Whether to use cosyvoice semantic token to replace text token.",
+    )
+
     add_model_arguments(parser)
 
     return parser
@@ -378,9 +395,11 @@ def get_tokenizer(vocab_file_path: str):
 
 
 def get_model(params):
-    vocab_char_map, vocab_size = get_tokenizer(params.tokens)
-    vocab_char_map, vocab_size = None, 6561
-    # https://www.modelscope.cn/models/iic/CosyVoice2-0.5B/file/view/master?fileName=cosyvoice.yaml&status=1#L36
+    if params.use_cosyvoice_semantic_token:
+        # https://www.modelscope.cn/models/iic/CosyVoice2-0.5B/file/view/master?fileName=cosyvoice.yaml&status=1#L36
+        vocab_char_map, vocab_size = None, 6561
+    else:
+        vocab_char_map, vocab_size = get_tokenizer(params.tokens)
     # bigvgan 100 dim features
     n_mel_channels = 100
     n_fft = 1024
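For reference, 6561 = 3^8 matches the CosyVoice2 speech tokenizer's FSQ codebook size given in the linked cosyvoice.yaml, so the model's "text" embedding table is sized for semantic-token ids instead of the pinyin/character vocabulary. A hedged sketch of the effective change; `embedding_dim` here is illustrative only:

```
# Sketch: with semantic tokens the text vocabulary becomes the CosyVoice2
# codebook (6561 = 3**8 ids); embedding_dim is illustrative, not the recipe's.
import torch.nn as nn

vocab_size = 6561
text_embed = nn.Embedding(num_embeddings=vocab_size + 1, embedding_dim=512)
# +1 mirrors F5-TTS's convention of reserving index 0 for the filler/pad token.
```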
@@ -558,37 +577,43 @@ def save_checkpoint(
         copyfile(src=filename, dst=best_valid_filename)
 
 
-def insert_zeros_optimized(arr):
+def interpolate_tokens(cosy_tokens, pad_token=-1):
+    """Interpolate cosyvoice tokens to match bigvgan frames length"""
     # cosyvoice, 25 tokens/sec
     # bigvgan sample_rate/hop_length 24000/256 frames/sec
     # For every 4 cosyvoice tokens, insert pad tokens to extend it to 15 tokens to match bigvgan frames length
     # We choose 4,4,4,3 to match 15 frames
-    three, two = [-1] * 3, [-1] * 2
+    three, two = [pad_token] * 3, [pad_token] * 2
     return [
-        x for i, e in enumerate(arr) for x in ([e] + three if i % 4 < 3 else [e] + two)
+        x
+        for i, e in enumerate(cosy_tokens)
+        for x in ([e] + three if i % 4 < 3 else [e] + two)
     ]
 
 
-def prepare_input(batch: dict, device: torch.device):
+def prepare_input(
+    batch: dict, device: torch.device, use_cosyvoice_semantic_token: bool
+):
     """Parse batch data"""
-    # text_inputs = batch["text"]
-    # text_inputs = convert_char_to_pinyin(text_inputs, polyphone=True)
 
     mel_spec = batch["features"]
     mel_lengths = batch["features_lens"]
 
-    semantic_tokens = []
-    for i in range(len(batch["tokens"])):
-        tokens = batch["tokens"][i]
-        # tokens = insert_zeros_optimized(tokens)
-        semantic_tokens.append(tokens)
-    # pad to the same length, B,T, with pad value -1
-    max_len = max([len(tokens) for tokens in semantic_tokens])
-    text_inputs = torch.full((len(semantic_tokens), max_len), -1, dtype=torch.long).to(
-        device
-    )
-    for i, tokens in enumerate(semantic_tokens):
-        text_inputs[i, : len(tokens)] = torch.tensor(tokens, dtype=torch.long)
+    if use_cosyvoice_semantic_token:
+        semantic_tokens = []
+        for i in range(len(batch["tokens"])):
+            tokens = batch["tokens"][i]
+            tokens = interpolate_tokens(tokens)
+            semantic_tokens.append(tokens)
+        # pad to the same length, B,T, with pad value -1
+        max_len = max([len(tokens) for tokens in semantic_tokens])
+        text_inputs = torch.full(
+            (len(semantic_tokens), max_len), -1, dtype=torch.long
+        ).to(device)
+        for i, tokens in enumerate(semantic_tokens):
+            text_inputs[i, : len(tokens)] = torch.tensor(tokens, dtype=torch.long)
+    else:
+        text_inputs = batch["text"]
+        text_inputs = convert_char_to_pinyin(text_inputs, polyphone=True)
 
     return text_inputs, mel_spec.to(device), mel_lengths.to(device)
 
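A quick check that `interpolate_tokens` produces the 4,4,4,3 pattern described in its comments, i.e. 15 output slots for every 4 input tokens:

```
# interpolate_tokens maps 4 tokens -> 15 slots (4 + 4 + 4 + 3 pattern).
tokens = [101, 102, 103, 104]
out = interpolate_tokens(tokens)
print(out)       # [101, -1, -1, -1, 102, -1, -1, -1, 103, -1, -1, -1, 104, -1, -1]
print(len(out))  # 15, matching 3.75 mel frames per token
```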
@@ -619,7 +644,11 @@ def compute_loss(
       values >= 1.0 are fully warmed up and have all modules present.
     """
     device = model.device if isinstance(model, DDP) else next(model.parameters()).device
-    (text_inputs, mel_spec, mel_lengths) = prepare_input(batch, device=device)
+    (text_inputs, mel_spec, mel_lengths) = prepare_input(
+        batch,
+        device=device,
+        use_cosyvoice_semantic_token=params.use_cosyvoice_semantic_token,
+    )
     # at entry, TextTokens is (N, P)
 
     with torch.set_grad_enabled(is_training):