diff --git a/egs/wenetspeech4tts/TTS/f5-tts/infer_llasa.py b/egs/wenetspeech4tts/TTS/f5-tts/infer_llasa.py new file mode 100644 index 000000000..6964a43be --- /dev/null +++ b/egs/wenetspeech4tts/TTS/f5-tts/infer_llasa.py @@ -0,0 +1,828 @@ +#!/usr/bin/env python3 +# Modified from https://github.com/SWivid/F5-TTS/blob/main/src/f5_tts/eval/eval_infer_batch.py +""" +Usage: +# docker: ghcr.io/swivid/f5-tts:main +# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html +# pip install kaldialign lhotse tensorboard bigvganinference sentencepiece sherpa-onnx +# huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x +manifest=/path/seed_tts_eval/seedtts_testset/zh/meta.lst +python3 f5-tts/generate_averaged_model.py \ + --epoch 56 \ + --avg 14 --decoder-dim 768 --nhead 12 --num-decoder-layers 18 \ + --exp-dir exp/f5_small + +# command for text token input +accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-file $manifest --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18 + +# command for cosyvoice semantic token input +split=test_zh # seed_tts_eval test_zh +accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --split-name $split --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18 --use-cosyvoice-semantic-token True + +bash local/compute_wer.sh $output_dir $manifest +""" +import argparse +import logging +import math +import os +import random +import time +from pathlib import Path + +import datasets +import torch +import torch.nn.functional as F +import torchaudio +from accelerate import Accelerator +from bigvganinference import BigVGANInference +from model.cfm import CFM +from model.dit import DiT +from model.modules import MelSpec +from model.utils import convert_char_to_pinyin +from tqdm import tqdm +from train import ( + add_model_arguments, + get_model, + get_tokenizer, + interpolate_tokens, + load_F5_TTS_pretrained_checkpoint, +) + +from icefall.checkpoint import load_checkpoint +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--tokens", + type=str, + default="f5-tts/vocab.txt", + help="Path to the unique text tokens file", + ) + + parser.add_argument( + "--model-path", + type=str, + default="/home/yuekaiz/HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", + help="Path to the unique text tokens file", + ) + + parser.add_argument( + "--seed", + type=int, + default=0, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--nfe", + type=int, + default=16, + help="The number of steps for the neural ODE", + ) + + parser.add_argument( + "--manifest-file", + type=str, + default=None, + help="The manifest file in seed_tts_eval format", + ) + + parser.add_argument( + "--output-dir", + type=Path, + default="results", + help="The output directory to save the generated wavs", + ) + + parser.add_argument("-ss", "--swaysampling", default=-1, type=float) + + parser.add_argument( + "--interpolate-token", + type=str2bool, + default=True, + help="Interpolate semantic token to match mel frames for CosyVoice", + ) + + parser.add_argument( + "--use-cosyvoice-semantic-token", + type=str2bool, + default=False, + help="Whether to use cosyvoice semantic token to replace text token.", + ) + + parser.add_argument( + "--split-name", + type=str, + default="wenetspeech4tts", + choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"], + help="huggingface dataset split name", + ) + + add_model_arguments(parser) + return parser.parse_args() + + +def get_inference_prompt( + metainfo, + speed=1.0, + tokenizer="pinyin", + polyphone=True, + target_sample_rate=24000, + n_fft=1024, + win_length=1024, + n_mel_channels=100, + hop_length=256, + mel_spec_type="bigvgan", + target_rms=0.1, + use_truth_duration=False, + infer_batch_size=1, + num_buckets=200, + min_secs=3, + max_secs=40, +): + prompts_all = [] + + min_tokens = min_secs * target_sample_rate // hop_length + max_tokens = max_secs * target_sample_rate // hop_length + + batch_accum = [0] * num_buckets + utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = ( + [[] for _ in range(num_buckets)] for _ in range(6) + ) + + mel_spectrogram = MelSpec( + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + n_mel_channels=n_mel_channels, + target_sample_rate=target_sample_rate, + mel_spec_type=mel_spec_type, + ) + + for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm( + metainfo, desc="Processing prompts..." + ): + # Audio + ref_audio, ref_sr = torchaudio.load(prompt_wav) + ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio))) + if ref_rms < target_rms: + ref_audio = ref_audio * target_rms / ref_rms + assert ( + ref_audio.shape[-1] > 5000 + ), f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue." + if ref_sr != target_sample_rate: + resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate) + ref_audio = resampler(ref_audio) + + # Text + if len(prompt_text[-1].encode("utf-8")) == 1: + prompt_text = prompt_text + " " + text = [prompt_text + gt_text] + if tokenizer == "pinyin": + text_list = convert_char_to_pinyin(text, polyphone=polyphone) + else: + text_list = text + + # Duration, mel frame length + ref_mel_len = ref_audio.shape[-1] // hop_length + if use_truth_duration: + gt_audio, gt_sr = torchaudio.load(gt_wav) + if gt_sr != target_sample_rate: + resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate) + gt_audio = resampler(gt_audio) + total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed) + + # # test vocoder resynthesis + # ref_audio = gt_audio + else: + ref_text_len = len(prompt_text.encode("utf-8")) + gen_text_len = len(gt_text.encode("utf-8")) + total_mel_len = ref_mel_len + int( + ref_mel_len / ref_text_len * gen_text_len / speed + ) + + # to mel spectrogram + ref_mel = mel_spectrogram(ref_audio) + ref_mel = ref_mel.squeeze(0) + + # deal with batch + assert infer_batch_size > 0, "infer_batch_size should be greater than 0." + assert ( + min_tokens <= total_mel_len <= max_tokens + ), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]." + bucket_i = math.floor( + (total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets + ) + + utts[bucket_i].append(utt) + ref_rms_list[bucket_i].append(ref_rms) + ref_mels[bucket_i].append(ref_mel) + ref_mel_lens[bucket_i].append(ref_mel_len) + total_mel_lens[bucket_i].append(total_mel_len) + final_text_list[bucket_i].extend(text_list) + + batch_accum[bucket_i] += total_mel_len + + if batch_accum[bucket_i] >= infer_batch_size: + prompts_all.append( + ( + utts[bucket_i], + ref_rms_list[bucket_i], + padded_mel_batch(ref_mels[bucket_i]), + ref_mel_lens[bucket_i], + total_mel_lens[bucket_i], + final_text_list[bucket_i], + ) + ) + batch_accum[bucket_i] = 0 + ( + utts[bucket_i], + ref_rms_list[bucket_i], + ref_mels[bucket_i], + ref_mel_lens[bucket_i], + total_mel_lens[bucket_i], + final_text_list[bucket_i], + ) = ( + [], + [], + [], + [], + [], + [], + ) + + # add residual + for bucket_i, bucket_frames in enumerate(batch_accum): + if bucket_frames > 0: + prompts_all.append( + ( + utts[bucket_i], + ref_rms_list[bucket_i], + padded_mel_batch(ref_mels[bucket_i]), + ref_mel_lens[bucket_i], + total_mel_lens[bucket_i], + final_text_list[bucket_i], + ) + ) + # not only leave easy work for last workers + random.seed(666) + random.shuffle(prompts_all) + + return prompts_all + + +def get_inference_prompt_cosy_voice_huggingface( + dataset, + speed=1.0, + tokenizer="pinyin", + polyphone=True, + target_sample_rate=24000, + n_fft=1024, + win_length=1024, + n_mel_channels=100, + hop_length=256, + mel_spec_type="bigvgan", + target_rms=0.1, + use_truth_duration=False, + infer_batch_size=1, + num_buckets=200, + min_secs=3, + max_secs=40, + interpolate_token=False, +): + prompts_all = [] + + min_tokens = min_secs * target_sample_rate // hop_length + max_tokens = max_secs * target_sample_rate // hop_length + + batch_accum = [0] * num_buckets + utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = ( + [[] for _ in range(num_buckets)] for _ in range(6) + ) + + mel_spectrogram = MelSpec( + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + n_mel_channels=n_mel_channels, + target_sample_rate=target_sample_rate, + mel_spec_type=mel_spec_type, + ) + + for i in range(len(dataset)): + utt = dataset[i]["id"] + ref_audio_org, ref_sr = ( + dataset[i]["prompt_audio"]["array"], + dataset[i]["prompt_audio"]["sampling_rate"], + ) + ref_audio_org = torch.from_numpy(ref_audio_org).unsqueeze(0).float() + audio_tokens = dataset[i]["target_audio_cosy2_tokens"] + prompt_audio_tokens = dataset[i]["prompt_audio_cosy2_tokens"] + + ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org))) + if ref_rms < target_rms: + ref_audio_org = ref_audio_org * target_rms / ref_rms + + if ref_sr != target_sample_rate: + resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate) + ref_audio = resampler(ref_audio_org) + else: + ref_audio = ref_audio_org + input_tokens = prompt_audio_tokens + audio_tokens + + if interpolate_token: + input_tokens = interpolate_tokens(input_tokens) + text_list = input_tokens + + # Duration, mel frame length + ref_mel_len = ref_audio.shape[-1] // hop_length + + total_mel_len = len(input_tokens) + if not interpolate_token: + total_mel_len = int(total_mel_len / 4 * 15) + + # to mel spectrogram + ref_mel = mel_spectrogram(ref_audio) + ref_mel = ref_mel.squeeze(0) + + # deal with batch + assert infer_batch_size > 0, "infer_batch_size should be greater than 0." + if total_mel_len > max_tokens: + print( + f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]." + ) + continue + assert ( + min_tokens <= total_mel_len <= max_tokens + ), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]." + bucket_i = math.floor( + (total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets + ) + + utts[bucket_i].append(utt) + ref_rms_list[bucket_i].append(ref_rms) + ref_mels[bucket_i].append(ref_mel) + ref_mel_lens[bucket_i].append(ref_mel_len) + total_mel_lens[bucket_i].append(total_mel_len) + # final_text_list[bucket_i].extend(text_list) + final_text_list[bucket_i].append(text_list) + + batch_accum[bucket_i] += total_mel_len + + if batch_accum[bucket_i] >= infer_batch_size: + prompts_all.append( + ( + utts[bucket_i], + ref_rms_list[bucket_i], + padded_mel_batch(ref_mels[bucket_i]), + ref_mel_lens[bucket_i], + total_mel_lens[bucket_i], + final_text_list[bucket_i], + ) + ) + batch_accum[bucket_i] = 0 + ( + utts[bucket_i], + ref_rms_list[bucket_i], + ref_mels[bucket_i], + ref_mel_lens[bucket_i], + total_mel_lens[bucket_i], + final_text_list[bucket_i], + ) = ( + [], + [], + [], + [], + [], + [], + ) + + # add residual + for bucket_i, bucket_frames in enumerate(batch_accum): + if bucket_frames > 0: + prompts_all.append( + ( + utts[bucket_i], + ref_rms_list[bucket_i], + padded_mel_batch(ref_mels[bucket_i]), + ref_mel_lens[bucket_i], + total_mel_lens[bucket_i], + final_text_list[bucket_i], + ) + ) + # not only leave easy work for last workers + random.seed(666) + random.shuffle(prompts_all) + + return prompts_all + + +def inference_speech_token( + cosyvoice, + tts_text, + prompt_text, + prompt_speech_16k, + stream=False, + speed=1.0, + text_frontend=True, +): + tokens = [] + prompt_text = cosyvoice.frontend.text_normalize( + prompt_text, split=False, text_frontend=text_frontend + ) + for i in cosyvoice.frontend.text_normalize( + tts_text, split=True, text_frontend=text_frontend + ): + + tts_text_token, tts_text_token_len = cosyvoice.frontend._extract_text_token(i) + ( + prompt_text_token, + prompt_text_token_len, + ) = cosyvoice.frontend._extract_text_token(prompt_text) + speech_token, speech_token_len = cosyvoice.frontend._extract_speech_token( + prompt_speech_16k + ) + + for i in cosyvoice.model.llm.inference( + text=tts_text_token.to(cosyvoice.model.device), + text_len=torch.tensor([tts_text_token.shape[1]], dtype=torch.int32).to( + cosyvoice.model.device + ), + prompt_text=prompt_text_token.to(cosyvoice.model.device), + prompt_text_len=torch.tensor( + [prompt_text_token.shape[1]], dtype=torch.int32 + ).to(cosyvoice.model.device), + prompt_speech_token=speech_token.to(cosyvoice.model.device), + prompt_speech_token_len=torch.tensor( + [speech_token.shape[1]], dtype=torch.int32 + ).to(cosyvoice.model.device), + embedding=None, + ): + tokens.append(i) + return tokens, speech_token + + +def get_inference_prompt_cosy_voice( + metainfo, + speed=1.0, + tokenizer="pinyin", + polyphone=True, + target_sample_rate=24000, + n_fft=1024, + win_length=1024, + n_mel_channels=100, + hop_length=256, + mel_spec_type="bigvgan", + target_rms=0.1, + use_truth_duration=False, + infer_batch_size=1, + num_buckets=200, + min_secs=3, + max_secs=40, + interpolate_token=False, +): + + import sys + + # please change the path to the cosyvoice accordingly + sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS") + sys.path.append("/workspace/CosyVoice") + from cosyvoice.cli.cosyvoice import CosyVoice2 + + # please download the cosyvoice model first + cosyvoice = CosyVoice2( + "/workspace/CosyVoice2-0.5B", load_jit=False, load_trt=False, fp16=False + ) + + prompts_all = [] + + min_tokens = min_secs * target_sample_rate // hop_length + max_tokens = max_secs * target_sample_rate // hop_length + + batch_accum = [0] * num_buckets + utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = ( + [[] for _ in range(num_buckets)] for _ in range(6) + ) + + mel_spectrogram = MelSpec( + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + n_mel_channels=n_mel_channels, + target_sample_rate=target_sample_rate, + mel_spec_type=mel_spec_type, + ) + + for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm( + metainfo, desc="Processing prompts..." + ): + # Audio + ref_audio_org, ref_sr = torchaudio.load(prompt_wav) + + # cosy voice + if ref_sr != 16000: + resampler = torchaudio.transforms.Resample(ref_sr, 16000) + ref_audio_16k = resampler(ref_audio_org) + else: + ref_audio_16k = ref_audio_org + audio_tokens, prompt_audio_tokens = inference_speech_token( + cosyvoice, gt_text, prompt_text, ref_audio_16k, stream=False + ) + + ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org))) + if ref_rms < target_rms: + ref_audio_org = ref_audio_org * target_rms / ref_rms + assert ( + ref_audio_org.shape[-1] > 5000 + ), f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue." + if ref_sr != target_sample_rate: + resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate) + ref_audio = resampler(ref_audio_org) + else: + ref_audio = ref_audio_org + + # Text + # if len(prompt_text[-1].encode("utf-8")) == 1: + # prompt_text = prompt_text + " " + # text = [prompt_text + gt_text] + # if tokenizer == "pinyin": + # text_list = convert_char_to_pinyin(text, polyphone=polyphone) + # else: + # text_list = text + + # concat two tensors: prompt audio tokens with audio tokens --> shape 1, prompt_audio_tokens + audio_tokens + # prompt_audio_tokens shape 1, prompt_audio_tokens + # audio_tokens shape 1, audio_tokens + prompt_audio_tokens = prompt_audio_tokens.squeeze().cpu().tolist() + input_tokens = prompt_audio_tokens + audio_tokens + + # convert it into a list + # input_tokens_list = input_tokens.squeeze().cpu().tolist() + if interpolate_token: + input_tokens = interpolate_tokens(input_tokens) + text_list = input_tokens + + # Duration, mel frame length + ref_mel_len = ref_audio.shape[-1] // hop_length + if use_truth_duration: + gt_audio, gt_sr = torchaudio.load(gt_wav) + if gt_sr != target_sample_rate: + resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate) + gt_audio = resampler(gt_audio) + total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed) + + # # test vocoder resynthesis + # ref_audio = gt_audio + else: + ref_text_len = len(prompt_text.encode("utf-8")) + gen_text_len = len(gt_text.encode("utf-8")) + total_mel_len_compute = ref_mel_len + int( + ref_mel_len / ref_text_len * gen_text_len / speed + ) + total_mel_len = len(input_tokens) + if not interpolate_token: + total_mel_len = int(total_mel_len / 4 * 15) + print( + f"total_mel_len_compute: {total_mel_len_compute}, total_mel_len: {total_mel_len}" + ) + + # to mel spectrogram + ref_mel = mel_spectrogram(ref_audio) + ref_mel = ref_mel.squeeze(0) + + # deal with batch + assert infer_batch_size > 0, "infer_batch_size should be greater than 0." + assert ( + min_tokens <= total_mel_len <= max_tokens + ), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]." + bucket_i = math.floor( + (total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets + ) + + utts[bucket_i].append(utt) + ref_rms_list[bucket_i].append(ref_rms) + ref_mels[bucket_i].append(ref_mel) + ref_mel_lens[bucket_i].append(ref_mel_len) + total_mel_lens[bucket_i].append(total_mel_len) + # final_text_list[bucket_i].extend(text_list) + final_text_list[bucket_i].append(text_list) + + batch_accum[bucket_i] += total_mel_len + + if batch_accum[bucket_i] >= infer_batch_size: + prompts_all.append( + ( + utts[bucket_i], + ref_rms_list[bucket_i], + padded_mel_batch(ref_mels[bucket_i]), + ref_mel_lens[bucket_i], + total_mel_lens[bucket_i], + final_text_list[bucket_i], + ) + ) + batch_accum[bucket_i] = 0 + ( + utts[bucket_i], + ref_rms_list[bucket_i], + ref_mels[bucket_i], + ref_mel_lens[bucket_i], + total_mel_lens[bucket_i], + final_text_list[bucket_i], + ) = ( + [], + [], + [], + [], + [], + [], + ) + + # add residual + for bucket_i, bucket_frames in enumerate(batch_accum): + if bucket_frames > 0: + prompts_all.append( + ( + utts[bucket_i], + ref_rms_list[bucket_i], + padded_mel_batch(ref_mels[bucket_i]), + ref_mel_lens[bucket_i], + total_mel_lens[bucket_i], + final_text_list[bucket_i], + ) + ) + # not only leave easy work for last workers + random.seed(666) + random.shuffle(prompts_all) + + return prompts_all + + +def padded_mel_batch(ref_mels): + max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax() + padded_ref_mels = [] + for mel in ref_mels: + padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0) + padded_ref_mels.append(padded_ref_mel) + padded_ref_mels = torch.stack(padded_ref_mels) + padded_ref_mels = padded_ref_mels.permute(0, 2, 1) + return padded_ref_mels + + +def get_seedtts_testset_metainfo(metalst): + f = open(metalst) + lines = f.readlines() + f.close() + metainfo = [] + for line in lines: + assert len(line.strip().split("|")) == 4 + utt, prompt_text, prompt_wav, gt_text = line.strip().split("|") + utt = Path(utt).stem + gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav") + if not os.path.isabs(prompt_wav): + prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav) + metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav)) + return metainfo + + +def main(): + args = get_parser() + + accelerator = Accelerator() + device = f"cuda:{accelerator.process_index}" + if args.manifest_file: + metainfo = get_seedtts_testset_metainfo(args.manifest_file) + if not args.use_cosyvoice_semantic_token: + prompts_all = get_inference_prompt( + metainfo, + speed=1.0, + tokenizer="pinyin", + target_sample_rate=24_000, + n_mel_channels=100, + hop_length=256, + mel_spec_type="bigvgan", + target_rms=0.1, + use_truth_duration=False, + infer_batch_size=1, + ) + else: + prompts_all = get_inference_prompt_cosy_voice( + metainfo, + speed=1.0, + tokenizer="pinyin", + target_sample_rate=24_000, + n_mel_channels=100, + hop_length=256, + mel_spec_type="bigvgan", + target_rms=0.1, + use_truth_duration=False, + infer_batch_size=1, + interpolate_token=args.interpolate_token, + ) + else: + assert args.use_cosyvoice_semantic_token + dataset = datasets.load_dataset( + "yuekai/seed_tts_cosy2", + split=args.split_name, + trust_remote_code=True, + ) + prompts_all = get_inference_prompt_cosy_voice_huggingface( + dataset, + speed=1.0, + tokenizer="pinyin", + target_sample_rate=24_000, + n_mel_channels=100, + hop_length=256, + mel_spec_type="bigvgan", + target_rms=0.1, + use_truth_duration=False, + infer_batch_size=1, + interpolate_token=args.interpolate_token, + ) + + vocoder = BigVGANInference.from_pretrained( + "./bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False + ) + vocoder = vocoder.eval().to(device) + + model = get_model(args).eval().to(device) + checkpoint = torch.load(args.model_path, map_location="cpu") + if "ema_model_state_dict" in checkpoint or "model_state_dict" in checkpoint: + model = load_F5_TTS_pretrained_checkpoint(model, args.model_path) + else: + _ = load_checkpoint( + args.model_path, + model=model, + ) + + os.makedirs(args.output_dir, exist_ok=True) + + accelerator.wait_for_everyone() + start = time.time() + + with accelerator.split_between_processes(prompts_all) as prompts: + for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process): + ( + utts, + ref_rms_list, + ref_mels, + ref_mel_lens, + total_mel_lens, + final_text_list, + ) = prompt + ref_mels = ref_mels.to(device) + ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device) + total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device) + + if args.use_cosyvoice_semantic_token: + # concat final_text_list + max_len = max([len(tokens) for tokens in final_text_list]) + # pad tokens to the same length + for i, tokens in enumerate(final_text_list): + final_text_list[i] = torch.tensor( + tokens + [-1] * (max_len - len(tokens)), dtype=torch.long + ) + final_text_list = torch.stack(final_text_list).to(device) + + # Inference + with torch.inference_mode(): + generated, _ = model.sample( + cond=ref_mels, + text=final_text_list, + duration=total_mel_lens, + lens=ref_mel_lens, + steps=args.nfe, + cfg_strength=2.0, + sway_sampling_coef=args.swaysampling, + no_ref_audio=False, + seed=args.seed, + ) + for i, gen in enumerate(generated): + gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0) + gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32) + + generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu() + target_rms = 0.1 + target_sample_rate = 24_000 + if ref_rms_list[i] < target_rms: + generated_wave = generated_wave * ref_rms_list[i] / target_rms + torchaudio.save( + f"{args.output_dir}/{utts[i]}.wav", + generated_wave, + target_sample_rate, + ) + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + timediff = time.time() - start + print(f"Done batch inference in {timediff / 60 :.2f} minutes.") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main()