From d2b473ad99e8b00fb3048dd6fe742af3a139c9c3 Mon Sep 17 00:00:00 2001
From: root
Date: Fri, 28 Feb 2025 09:54:22 +0000
Subject: [PATCH] add eval seed tts

---
 egs/wenetspeech4tts/TTS/f5-tts/infer_dist.py | 347 +++++++++++++++++++
 egs/wenetspeech4tts/TTS/f5-tts/llm_tts.py    | 108 ++++++
 2 files changed, 455 insertions(+)
 create mode 100644 egs/wenetspeech4tts/TTS/f5-tts/infer_dist.py
 create mode 100644 egs/wenetspeech4tts/TTS/f5-tts/llm_tts.py

diff --git a/egs/wenetspeech4tts/TTS/f5-tts/infer_dist.py b/egs/wenetspeech4tts/TTS/f5-tts/infer_dist.py
new file mode 100644
index 000000000..59e222a74
--- /dev/null
+++ b/egs/wenetspeech4tts/TTS/f5-tts/infer_dist.py
@@ -0,0 +1,347 @@
# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example Usage

Distributed inference on a single node with 8 GPUs:

torchrun --nproc_per_node=8 --nnodes=1 \
    --rdzv_id=2024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
    infer_dist.py \
        --split-name test_zh \
        --output_dir ./results \
        --batch_size 4 \
        --llm-model-name-or-path <llm_checkpoint_dir> \
        --tokenizer-dir <tokenizer_dir> \
        --vocoder-dir <bigvgan_checkpoint_dir> \
        --flow-matching-model-path <f5_tts_checkpoint>
"""

import argparse
import os

import torch
import torch.distributed as dist
import torch.nn.functional as F
import torchaudio
from bigvganinference import BigVGANInference
from datasets import load_dataset
from llm_tts import LLMTTS
from model.modules import MelSpec
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from train import add_model_arguments, get_model, interpolate_tokens

from icefall.checkpoint import load_checkpoint

TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"


def get_args():
    parser = argparse.ArgumentParser(
        description="distributed speech generation for the Seed-TTS eval sets"
    )
    parser.add_argument(
        "--s3-tokenizer-name",
        required=False,
        type=str,
        choices=[
            "speech_tokenizer_v1",
            "speech_tokenizer_v1_25hz",
            "speech_tokenizer_v2_25hz",
        ],
        help="S3 tokenizer model version",
    )
    parser.add_argument(
        "--split-name",
        type=str,
        default="wenetspeech4tts",
        choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
        help="huggingface dataset split name",
    )
    parser.add_argument(
        "--output_dir", required=True, type=str, help="dir to save generated wavs"
    )
    parser.add_argument(
        "--batch_size",
        required=True,
        type=int,
        help="batch size (per-device) for inference",
    )
    parser.add_argument(
        "--num_workers", type=int, default=4, help="workers for dataloader"
    )
    parser.add_argument(
        "--prefetch", type=int, default=5, help="prefetch for dataloader"
    )
    parser.add_argument(
        "--llm-model-name-or-path",
        required=True,
        type=str,
        help="LLM checkpoint directory or model name",
    )
    parser.add_argument(
"--tokenizer-dir", + required=True, + type=str, + help="tokenizer dir", + ) + parser.add_argument( + "--vocoder-dir", + required=True, + type=str, + help="vocoder dir", + ) + parser.add_argument( + "--flow-matching-model-path", + required=True, + type=str, + help="flow matching model path", + ) + add_model_arguments(parser) + args = parser.parse_args() + return args + + +def padded_mel_batch(ref_mels): + max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax() + padded_ref_mels = [] + for mel in ref_mels: + padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0) + padded_ref_mels.append(padded_ref_mel) + padded_ref_mels = torch.stack(padded_ref_mels) + padded_ref_mels = padded_ref_mels.permute(0, 2, 1) + return padded_ref_mels + + +def data_collator(batch, tokenizer, mel_spectrogram): + speech_generation_start_index = tokenizer.convert_tokens_to_ids( + "<|SPEECH_GENERATION_START|>" + ) + assistant_index = tokenizer.convert_tokens_to_ids("assistant") + target_sample_rate = 24000 + hop_length = 256 + target_rms = 0.1 + input_ids_list, ref_mel_list, ref_mel_len_list = [], [], [] + for i, item in enumerate(batch): + prompt_text, target_text, prompt_audio_codes = ( + item["prompt_text"], + item["target_text"], + item["prompt_audio_cosy2_tokens"], + ) + message = [ + { + "role": "user", + "content": f"Convert the text to speech: {prompt_text + target_text}", + }, + {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>"}, + ] + + input_ids = tokenizer.apply_chat_template( + message, + tokenize=True, + chat_template=TEMPLATE, + ) + + prompt_audio_codes = [c + 151665 for c in prompt_audio_codes] + + idx = input_ids.index(speech_generation_start_index) + input_ids = input_ids[:idx] + prompt_audio_codes + input_ids_list.append(input_ids) + + # get flow matching model's prompt mel spectrogram + ref_audio_org, ref_sr = ( + item["prompt_audio"]["array"], + item["prompt_audio"]["sampling_rate"], + ) + ref_audio_org = torch.from_numpy(ref_audio_org).unsqueeze(0).float() + ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org))) + if ref_rms < target_rms: + ref_audio_org = ref_audio_org * target_rms / ref_rms + + if ref_sr != target_sample_rate: + resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate) + ref_audio = resampler(ref_audio_org) + else: + ref_audio = ref_audio_org + + # Duration in mel frame length + ref_mel_len = ref_audio.shape[-1] // hop_length + # to mel spectrogram + ref_mel = mel_spectrogram(ref_audio) + ref_mel = ref_mel.squeeze(0) + + ref_mel_list.append(ref_mel) + ref_mel_len_list.append(ref_mel_len) + + max_len = max([len(input_ids) for input_ids in input_ids_list]) + input_ids_list = [ + [tokenizer.pad_token_id] * (max_len - len(input_ids)) + input_ids + for input_ids in input_ids_list + ] + input_ids = torch.tensor(input_ids_list, dtype=torch.int64) + attention_mask = input_ids.ne(tokenizer.pad_token_id).long() + ids = [item["id"] for item in batch] + + ref_mel_batch = padded_mel_batch(ref_mel_list) + ref_mel_len_batch = torch.LongTensor(ref_mel_len_list) + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "ids": ids, + "ref_mel_batch": ref_mel_batch, + "ref_mel_len_batch": ref_mel_len_batch, + } + + +def init_distributed(): + world_size = int(os.environ.get("WORLD_SIZE", 1)) + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + rank = int(os.environ.get("RANK", 0)) + print( + "Inference on multiple gpus, this gpu {}".format(local_rank) + + ", rank {}, world_size {}".format(rank, world_size) + 
    )
    torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl")
    return world_size, local_rank, rank


def main():
    args = get_args()
    os.makedirs(args.output_dir, exist_ok=True)

    assert torch.cuda.is_available()
    world_size, local_rank, rank = init_distributed()
    device = torch.device(f"cuda:{local_rank}")
    model = LLMTTS(
        model_dir=args.llm_model_name_or_path,
        tokenizer_dir=args.tokenizer_dir,
        s3_tokenizer_name=args.s3_tokenizer_name,
        device=device,
    )

    vocoder = BigVGANInference.from_pretrained(args.vocoder_dir, use_cuda_kernel=False)
    vocoder = vocoder.eval().to(device)

    flow_matching_model = get_model(args).eval().to(device)
    _ = load_checkpoint(
        args.flow_matching_model_path,
        model=flow_matching_model,
    )

    dataset = load_dataset(
        "yuekai/seed_tts_cosy2",
        split=args.split_name,
        trust_remote_code=True,
    )

    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)

    mel_spectrogram = MelSpec(
        n_fft=1024,
        hop_length=256,
        win_length=1024,
        n_mel_channels=100,
        target_sample_rate=24000,
        mel_spec_type="bigvgan",
    )

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=sampler,
        shuffle=False,
        num_workers=args.num_workers,
        prefetch_factor=args.prefetch,
        collate_fn=lambda x: data_collator(x, model.tokenizer, mel_spectrogram),
    )

    total_steps = len(dataset)

    if rank == 0:
        progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")

    for batch in dataloader:
        # Stage 1: generate discrete speech codes with the LLM.
        generated_codes = model.inference_batch(
            batch["input_ids"], batch["attention_mask"]
        )
        # Map speech codes to mel-frame resolution; the interpolated length
        # is the target mel length for the flow matching model.
        flow_matching_input_tokens, total_mel_lens = [], []
        for code in generated_codes:
            flow_matching_input_token = interpolate_tokens(code)
            total_mel_len = len(flow_matching_input_token)
            flow_matching_input_tokens.append(flow_matching_input_token)
            total_mel_lens.append(total_mel_len)
        total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)
        ref_mels = batch["ref_mel_batch"].to(device)
        ref_mel_lens = batch["ref_mel_len_batch"].to(device)

        max_len = max([len(tokens) for tokens in flow_matching_input_tokens])
        # pad tokens to the same length with -1
        for i, tokens in enumerate(flow_matching_input_tokens):
            flow_matching_input_tokens[i] = torch.tensor(
                tokens + [-1] * (max_len - len(tokens)), dtype=torch.long
            )
        flow_matching_input_tokens = torch.stack(flow_matching_input_tokens).to(device)

        # Stage 2: synthesize mel spectrograms conditioned on the prompt mels.
        generated, _ = flow_matching_model.sample(
            cond=ref_mels,
            text=flow_matching_input_tokens,
            duration=total_mel_lens,
            lens=ref_mel_lens,
            steps=16,
            cfg_strength=2.0,
            sway_sampling_coef=-1,
            no_ref_audio=False,
            seed=0,
        )

        # Stage 3: vocode and save, dropping the prompt part of each mel.
        for i, gen in enumerate(generated):
            gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
            gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)

            generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
            target_sample_rate = 24_000
            # NOTE: RMS matching of the generated wave against the prompt
            # audio is currently skipped.
            utt = batch["ids"][i]
            torchaudio.save(
                f"{args.output_dir}/{utt}.wav",
                generated_wave,
                target_sample_rate,
            )

        if rank == 0:
            progress_bar.update(world_size * len(batch["ids"]))

    if rank == 0:
        progress_bar.close()

    dist.barrier()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()

diff --git a/egs/wenetspeech4tts/TTS/f5-tts/llm_tts.py b/egs/wenetspeech4tts/TTS/f5-tts/llm_tts.py
new file mode 100644
index 000000000..bf878db51
--- /dev/null
+++ b/egs/wenetspeech4tts/TTS/f5-tts/llm_tts.py
@@ -0,0 +1,108 @@
# Copyright (c) 2025 SparkAudio
#               2025 Xinsheng Wang (w.xinshawn@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from
# https://github.com/SparkAudio/Spark-TTS/blob/main/cli/SparkTTS.py

from pathlib import Path
from typing import List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class LLMTTS:
    """
    Causal LLM that generates discrete speech codes from text prompts.
    """

    def __init__(
        self,
        model_dir: Path,
        tokenizer_dir: Path,
        s3_tokenizer_name: str,
        device: torch.device,
    ):
        """
        Initializes the LLMTTS model with the provided configurations and device.

        Args:
            model_dir (Path): Directory containing the model and config files.
            tokenizer_dir (Path): Directory containing the tokenizer files.
            s3_tokenizer_name (str): S3 tokenizer model version (currently unused here).
            device (torch.device): The device (CPU/GPU) to run the model on.
        """
        self.device = device

        self.model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            torch_dtype=torch.float16,
            device_map=device,
            attn_implementation="flash_attention_2",
        )

        tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
        # Extend the vocabulary with 6561 speech tokens plus the
        # generation-start marker; <|s_0|> is assigned id 151665.
        new_tokens = [f"<|s_{i}|>" for i in range(6561)] + [
            "<|SPEECH_GENERATION_START|>"
        ]
        tokenizer.add_tokens(new_tokens)
        # Left padding keeps each prompt contiguous with its generated
        # continuation when batching.
        tokenizer.padding_side = "left"
        self.tokenizer = tokenizer
        self.assistant_index = tokenizer.convert_tokens_to_ids("assistant")

    @torch.no_grad()
    def inference_batch(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 0.95,
    ) -> List[List[int]]:
        """
        Generates speech codes for a batch of tokenized chat prompts.

        Args:
            input_ids (torch.Tensor): Left-padded batch of prompt token ids.
            attention_mask (torch.Tensor): Mask matching input_ids (0 on padding).
            temperature (float, optional): Sampling temperature for controlling randomness. Default is 0.8.
            top_k (int, optional): Top-k sampling parameter. Default is 50.
            top_p (float, optional): Top-p (nucleus) sampling parameter. Default is 0.95.

        Returns:
            List[List[int]]: Speech codes for each item in the batch.
        """
        # Generate speech tokens using the model
        generated_ids = self.model.generate(
            input_ids=input_ids.to(self.device),
            attention_mask=attention_mask.to(self.device),
            max_new_tokens=1024,
            do_sample=True,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
        )

        results = []
        generated_ids = generated_ids.cpu().tolist()
        for i in range(len(generated_ids)):
            # Keep everything after "assistant\n": the prompt speech codes
            # followed by the generated ones.
            assistant_index = generated_ids[i].index(self.assistant_index)
            result = generated_ids[i][assistant_index + 2 :]
            # Map ids back to speech-code space (<|s_0|> = 151665) and drop
            # non-speech tokens such as <|im_end|>.
            result = [token - 151665 for token in result]
            result = [token for token in result if token >= 0]
            results.append(result)
        return results
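

if __name__ == "__main__":
    # Minimal usage sketch (not exercised by infer_dist.py). It mirrors the
    # prompt construction in infer_dist.py:data_collator for a single toy
    # example; the checkpoint/tokenizer paths and the three prompt speech
    # codes below are placeholders, not artifacts shipped with this patch.
    TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tts = LLMTTS(
        model_dir="path/to/llm_checkpoint",  # placeholder
        tokenizer_dir="path/to/tokenizer",  # placeholder
        s3_tokenizer_name="speech_tokenizer_v2_25hz",
        device=device,
    )
    message = [
        {"role": "user", "content": "Convert the text to speech: hello"},
        {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>"},
    ]
    input_ids = tts.tokenizer.apply_chat_template(
        message, tokenize=True, chat_template=TEMPLATE
    )
    # Replace <|SPEECH_GENERATION_START|> with shifted prompt speech codes,
    # exactly as the data collator does.
    start_id = tts.tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_START|>")
    idx = input_ids.index(start_id)
    input_ids = input_ids[:idx] + [c + 151665 for c in [0, 1, 2]]  # dummy codes
    input_ids = torch.tensor([input_ids], dtype=torch.int64)
    attention_mask = torch.ones_like(input_ids)
    codes = tts.inference_batch(input_ids, attention_mask)
    print(f"generated {len(codes[0])} speech codes")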