# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
#               2025                 (authors: Yuekai Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from https://github.com/xingchensong/S3Tokenizer/blob/main/s3tokenizer/cli.py
""" Example Usage
split=test_zh
llm_path=f5-tts/exp_zh/checkpoint-805000
huggingface-cli download --local-dir f5-tts-small-wenetspeech4tts-basic yuekai/f5-tts-semantic-token-small-wenetspeech4tts-basic
model_path=f5-tts-small-wenetspeech4tts-basic/epoch-10-avg-5.pt
huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir ./bigvgan_v2_24khz_100band_256x
vocoder=./bigvgan_v2_24khz_100band_256x
torchrun --nproc_per_node=2 \
f5-tts/infer_dist.py \
--output_dir $output_dir \
--batch_size 1 \
--num_workers 2 \
--llm-model-name-or-path $llm_path \
--flow-matching-model-path $model_path \
--decoder-dim 768 --nhead 12 --num-decoder-layers 18 \
--use-cosyvoice-semantic-token True \
--vocoder-dir $vocoder \
--split-name $split -top-k 50 -top-p 0.95 -temperature 0.8 \
--tokenizer-dir Qwen/Qwen2.5-0.5B-Instruct
"""
import argparse
import json
import os

import torch
import torch.distributed as dist
import whisper
from datasets import load_dataset
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
from transformers import AutoTokenizer
from web_demo import get_model
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward

# https://github.com/FunAudioLLM/CosyVoice/tree/main/third_party
# sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
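
# CUDA cannot be re-initialized in forked dataloader workers, so use the
# "spawn" start method; a RuntimeError means it was already set, which is fine.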
try:
    torch.multiprocessing.set_start_method("spawn")
except RuntimeError:
    pass


def get_args():
    parser = argparse.ArgumentParser(
        description="Decode the VoiceBench benchmark with a speech LLM"
    )
    parser.add_argument(
        "--split-name",
        type=str,
        default="test",
        help="huggingface dataset split name",
    )
    parser.add_argument(
        "--subset-name",
        type=str,
        default="commoneval",
        help="subset name",
    )
    parser.add_argument(
        "--output-dir", required=True, type=str, help="dir to save result"
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=1,
        help="batch size (per-device) for inference",
    )
    parser.add_argument(
        "--num-workers", type=int, default=2, help="workers for dataloader"
    )
    parser.add_argument(
        "--prefetch", type=int, default=2, help="prefetch for dataloader"
    )
    parser.add_argument(
        "--checkpoint-path",
        type=str,
        default=None,
        help="Checkpoint name or path, default to %(default)r",
    )
    # parser.add_argument(
    #     "--top-k",
    #     type=int,
    #     default=50,
    #     help="top k for sampling",
    # )
    # parser.add_argument(
    #     "--top-p",
    #     type=float,
    #     default=0.95,
    #     help="top p for sampling",
    # )
    # parser.add_argument(
    #     "--temperature",
    #     type=float,
    #     default=0.8,
    #     help="temperature for sampling",
    # )
    add_model_arguments(parser)
    args = parser.parse_args()
    return args
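

# torchrun exports RANK, LOCAL_RANK and WORLD_SIZE to every worker process.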
def init_distributed():
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    rank = int(os.environ.get("RANK", 0))
    print(
        f"Inference on multiple gpus, this gpu {local_rank}, "
        f"rank {rank}, world_size {world_size}"
    )
    torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl")
    return world_size, local_rank, rank


def preprocess(messages, tokenizer):
    """Tokenize chat messages and pad them into a batch of input ids."""
    texts = []
    TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
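    # Rendered example for the prompt built in main() (writing
    # DEFAULT_SPEECH_TOKEN as <speech> here):
    #   <|im_start|>user\n<speech><|im_end|>\n<|im_start|>assistant\n
    # The <|im_end|> of the last turn is omitted so the assistant turn stays
    # open for generation.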
    for msg in messages:
        texts.append(
            tokenizer.apply_chat_template(
                msg,
                tokenize=True,
                add_generation_prompt=False,
                chat_template=TEMPLATE,
                padding="longest",
                truncation=False,
            )
        )
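    # Pad every sequence to the longest one in the batch, on the side the
    # tokenizer expects, then mask the padding positions out.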
    max_len_texts = max(len(text) for text in texts)
    if tokenizer.padding_side == "right":
        texts = [
            text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
            for text in texts
        ]
    else:
        texts = [
            [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
            for text in texts
        ]
    input_ids = torch.tensor(texts, dtype=torch.int)
    attention_mask = input_ids.ne(tokenizer.pad_token_id)
    return input_ids, attention_mask
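

# VoiceBench audio clips have different lengths and are not padded here, so
# the dataloader below runs with batch_size == 1 (hence the assert).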
def custom_collate(batch):
    assert len(batch) == 1
    audio = batch[0]["audio"]
    assert audio["sampling_rate"] == 16000
    result = {"audio": audio["array"]}
    for key in batch[0].keys():
        if key != "audio":
            result[key] = batch[0][key]
    return result


def main():
    args = get_args()
    os.makedirs(args.output_dir, exist_ok=True)

    assert torch.cuda.is_available()
    world_size, local_rank, rank = init_distributed()
    device = torch.device(f"cuda:{local_rank}")

    dataset = load_dataset(
        "hlt-lab/voicebench",
        args.subset_name,
        split=args.split_name,
        trust_remote_code=True,
    )
    model, tokenizer = get_model(args)
    # tokenizer = AutoTokenizer.from_pretrained(args.llm_path_or_name)

    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
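    # Each rank decodes a disjoint shard of the dataset and writes its own
    # jsonl file; note DistributedSampler may repeat a few samples so that
    # every rank receives the same number of items.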
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=sampler,
        shuffle=False,
        num_workers=args.num_workers,
        prefetch_factor=args.prefetch,
        collate_fn=custom_collate,
    )
    total_steps = len(dataset)
    if rank == 0:
        progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
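
    # The prompt is the same for every utterance: a user turn holding a single
    # speech placeholder token, followed by an open assistant turn.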
    message = [
        {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
        {"role": "assistant", "content": ""},
    ]
    input_ids, attention_mask = preprocess([message], tokenizer)

    results_jsonl_file = open(
        os.path.join(
            args.output_dir,
            f"results-{args.subset_name}-{args.split_name}-{rank}-audio.jsonl",
        ),
        "w",
    )
    for batch in dataloader:
        audio = batch["audio"]
        audio = torch.from_numpy(audio).to(device).to(torch.float32)
        fbank = whisper.log_mel_spectrogram(audio, device=device)
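        # log_mel_spectrogram returns (n_mels, n_frames); add the batch
        # dimension the Whisper encoder expects -> (1, n_mels, n_frames).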
        fbank = fbank.unsqueeze(0)
        generated_ids = model.decode(
            fbank, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
        )
        hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

        result_dict = {}
        for key in batch.keys():
            if key != "audio":
                result_dict[key] = batch[key]
        result_dict["response"] = hyps[0]
        json.dump(result_dict, results_jsonl_file)
        results_jsonl_file.write("\n")
        if rank == 0:
            progress_bar.update(world_size * args.batch_size)

    if rank == 0:
        progress_bar.close()
    results_jsonl_file.close()

    dist.barrier()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()