mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 17:42:21 +00:00
348 lines
11 KiB
Python
348 lines
11 KiB
Python
# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
""" Example Usage
|
|
cpu:
|
|
|
|
s3tokenizer --data_dir xxx.scp \
|
|
--device "cpu" \
|
|
--output_dir "./" \
|
|
--batch_size 32
|
|
|
|
gpu:
|
|
|
|
torchrun --nproc_per_node=8 --nnodes=1 \
|
|
--rdzv_id=2024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
|
|
`which s3tokenizer` --data_dir xxx.scp \
|
|
--device "cuda" \
|
|
--output_dir "./" \
|
|
--batch_size 32
|
|
|
|
"""
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
from pathlib import Path
|
|
|
|
import s3tokenizer
|
|
import torch
|
|
import torch.distributed as dist
|
|
import torch.nn.functional as F
|
|
import torchaudio
|
|
from bigvganinference import BigVGANInference
|
|
from datasets import load_dataset
|
|
from lhotse.serialization import load_jsonl
|
|
from llm_tts import LLMTTS
|
|
from model.modules import MelSpec
|
|
from torch.utils.data import DataLoader, Dataset, DistributedSampler
|
|
from tqdm import tqdm
|
|
from train import (
|
|
add_model_arguments,
|
|
get_model,
|
|
get_tokenizer,
|
|
interpolate_tokens,
|
|
load_F5_TTS_pretrained_checkpoint,
|
|
)
|
|
|
|
from icefall.checkpoint import load_checkpoint
|
|
|
|
# Jinja chat template passed to ``tokenizer.apply_chat_template``. Each message
# is rendered as ``<|im_start|>`` + role + newline + content, then closed with
# ``<|im_end|>``; every message except the last is also followed by a newline.
TEMPLATE = (
    "{% for message in messages %}"
    "{{'<|im_start|>' + message['role'] + '\n' + message['content']}}"
    "{% if loop.last %}"
    "{{ '<|im_end|>'}}"
    "{% else %}"
    "{{ '<|im_end|>\n' }}"
    "{% endif %}"
    "{% endfor %}"
)
|
|
|
|
|
|
def get_args():
    """Parse command-line options for batched TTS inference.

    Besides the options declared here, ``train.add_model_arguments`` registers
    the flow matching model's architecture options (later consumed by
    ``train.get_model``).

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(
        description="Batched TTS inference: LLM speech-token generation, "
        "flow matching and BigVGAN vocoding"
    )
    parser.add_argument(
        "--s3-tokenizer-name",
        required=False,
        type=str,
        choices=[
            "speech_tokenizer_v1",
            "speech_tokenizer_v1_25hz",
            "speech_tokenizer_v2_25hz",
        ],
        # No default: when omitted, None is forwarded to LLMTTS unchanged.
        help="S3 speech tokenizer version passed to LLMTTS",
    )
    parser.add_argument(
        "--split-name",
        type=str,
        default="wenetspeech4tts",
        choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
        help="split of the yuekai/seed_tts_cosy2 Hugging Face dataset to synthesize",
    )
    parser.add_argument(
        "--output_dir",
        required=True,
        type=str,
        help="dir to save the generated wav files",
    )
    parser.add_argument(
        "--batch_size",
        required=True,
        type=int,
        help="batch size (per-device) for inference",
    )
    parser.add_argument(
        "--num_workers", type=int, default=4, help="workers for dataloader"
    )
    parser.add_argument(
        "--prefetch",
        type=int,
        default=5,
        help="prefetch factor for dataloader workers",
    )
    parser.add_argument(
        "--llm-model-name-or-path",
        required=True,
        type=str,
        help="LLM model name or directory, passed to LLMTTS as model_dir",
    )
    parser.add_argument(
        "--tokenizer-dir",
        required=True,
        type=str,
        help="tokenizer directory, passed to LLMTTS as tokenizer_dir",
    )
    parser.add_argument(
        "--vocoder-dir",
        required=True,
        type=str,
        help="BigVGAN vocoder directory (loaded with BigVGANInference.from_pretrained)",
    )
    parser.add_argument(
        "--flow-matching-model-path",
        required=True,
        type=str,
        help="flow matching model checkpoint (loaded with icefall load_checkpoint)",
    )
    add_model_arguments(parser)
    return parser.parse_args()
|
|
|
|
|
|
def padded_mel_batch(ref_mels):
    """Zero-pad 2-D mel spectrograms along the last axis, stack, and transpose.

    Args:
        ref_mels: non-empty sequence of 2-D tensors sharing their first
            dimension; presumably ``(n_mels, frames)`` as built by
            ``data_collator`` -- TODO confirm.

    Returns:
        Tensor of shape ``(batch, max_frames, dim0)``: every input is
        right-padded with zeros to the longest last dimension, then the last
        two axes are swapped.

    Raises:
        ValueError: if ``ref_mels`` is empty.
    """
    if len(ref_mels) == 0:
        raise ValueError("padded_mel_batch() needs at least one mel spectrogram")
    # A plain Python int keeps the F.pad amounts integral (instead of the
    # 0-dim tensors produced by building a LongTensor and calling amax()).
    max_mel_length = max(int(mel.shape[-1]) for mel in ref_mels)
    padded_ref_mels = [
        F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0) for mel in ref_mels
    ]
    return torch.stack(padded_ref_mels).permute(0, 2, 1)
|
|
|
|
|
|
def _build_prompt_input_ids(
    tokenizer, prompt_text, target_text, prompt_audio_codes, speech_generation_start_index
):
    """Return one sample's LLM prompt ids: chat-formatted text, then shifted speech codes."""
    message = [
        {
            "role": "user",
            "content": f"Convert the text to speech: {prompt_text + target_text}",
        },
        {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>"},
    ]
    input_ids = tokenizer.apply_chat_template(
        message,
        tokenize=True,
        chat_template=TEMPLATE,
    )
    # Speech codes are moved into the LLM vocabulary by a fixed offset;
    # presumably 151665 is the id of the first speech token -- TODO confirm
    # against the tokenizer in --tokenizer-dir.
    speech_ids = [code + 151665 for code in prompt_audio_codes]
    # Cut at <|SPEECH_GENERATION_START|>: the marker itself and everything after
    # it (including the closing <|im_end|>) are dropped, then the prompt speech
    # ids are appended. Raises ValueError if the marker is not in input_ids.
    idx = input_ids.index(speech_generation_start_index)
    return input_ids[:idx] + speech_ids


def _load_prompt_mel(
    prompt_audio,
    mel_spectrogram,
    resamplers,
    target_sample_rate=24000,
    hop_length=256,
    target_rms=0.1,
):
    """Return ``(mel, mel_len)`` for one prompt waveform after loudness/rate normalization.

    ``resamplers`` is a dict cache mapping source sample rate -> Resample module,
    shared across the samples of one batch.
    """
    # presumably prompt_audio["array"] is a 1-D numpy waveform -- TODO confirm
    audio = torch.from_numpy(prompt_audio["array"]).unsqueeze(0).float()
    sample_rate = prompt_audio["sampling_rate"]

    rms = torch.sqrt(torch.mean(torch.square(audio)))
    # Boost quiet prompts up to target_rms. All-zero audio (rms == 0) is left
    # untouched: scaling it would divide by zero and produce NaN/inf samples.
    if 0 < rms < target_rms:
        audio = audio * target_rms / rms

    if sample_rate != target_sample_rate:
        resampler = resamplers.get(sample_rate)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(sample_rate, target_sample_rate)
            resamplers[sample_rate] = resampler
        audio = resampler(audio)

    # Prompt duration in mel frames.
    mel_len = audio.shape[-1] // hop_length
    mel = mel_spectrogram(audio).squeeze(0)
    return mel, mel_len


def data_collator(batch, tokenizer, mel_spectrogram):
    """Collate dataset rows into one batch for LLM + flow-matching inference.

    For every row this builds the LLM prompt ids (text prompt followed by the
    prompt's speech codes) and the mel spectrogram of the prompt audio, which
    conditions the flow matching model.

    Args:
        batch: list of dataset rows with keys ``id``, ``prompt_text``,
            ``target_text``, ``prompt_audio_cosy2_tokens`` and ``prompt_audio``
            (a dict with ``array`` and ``sampling_rate``).
        tokenizer: text tokenizer providing ``convert_tokens_to_ids``,
            ``apply_chat_template`` and ``pad_token_id``.
        mel_spectrogram: callable mapping a ``(1, samples)`` waveform at 24 kHz
            to a mel spectrogram whose leading axis is squeezed away.

    Returns:
        dict with
          - ``input_ids``: int64 tensor ``(batch, max_len)``, left-padded with
            ``tokenizer.pad_token_id``;
          - ``attention_mask``: int64 tensor of the same shape, 1 on real
            tokens and 0 on left padding;
          - ``ids``: list of row ids;
          - ``ref_mel_batch``: prompt mels padded by ``padded_mel_batch``;
          - ``ref_mel_len_batch``: int64 tensor of prompt lengths in mel frames
            (``num_samples // 256`` after resampling).
    """
    speech_generation_start_index = tokenizer.convert_tokens_to_ids(
        "<|SPEECH_GENERATION_START|>"
    )
    resamplers = {}

    input_ids_list, ref_mel_list, ref_mel_len_list = [], [], []
    for item in batch:
        input_ids_list.append(
            _build_prompt_input_ids(
                tokenizer,
                item["prompt_text"],
                item["target_text"],
                item["prompt_audio_cosy2_tokens"],
                speech_generation_start_index,
            )
        )
        ref_mel, ref_mel_len = _load_prompt_mel(
            item["prompt_audio"], mel_spectrogram, resamplers
        )
        ref_mel_list.append(ref_mel)
        ref_mel_len_list.append(ref_mel_len)

    # Left-pad so every prompt ends at the final position; presumably what
    # model.inference_batch expects for batched generation -- TODO confirm.
    max_len = max(len(ids) for ids in input_ids_list)
    pad_id = tokenizer.pad_token_id
    padded_ids = [[pad_id] * (max_len - len(ids)) + ids for ids in input_ids_list]
    # Build the mask from the real lengths rather than comparing against the
    # pad id, so a genuine token that happens to equal pad_id stays visible.
    mask_rows = [[0] * (max_len - len(ids)) + [1] * len(ids) for ids in input_ids_list]

    return {
        "input_ids": torch.tensor(padded_ids, dtype=torch.int64),
        "attention_mask": torch.tensor(mask_rows, dtype=torch.int64),
        "ids": [item["id"] for item in batch],
        "ref_mel_batch": padded_mel_batch(ref_mel_list),
        "ref_mel_len_batch": torch.LongTensor(ref_mel_len_list),
    }
|
|
|
|
|
|
def init_distributed():
    """Join the NCCL process group using torchrun-style environment variables.

    Reads ``WORLD_SIZE`` (default 1), ``LOCAL_RANK`` (default 0) and ``RANK``
    (default 0), prints them, selects CUDA device ``LOCAL_RANK`` for this
    process and initializes the default process group with the NCCL backend.

    Returns:
        Tuple ``(world_size, local_rank, rank)``.
    """
    env = os.environ
    world_size = int(env.get("WORLD_SIZE", 1))
    local_rank = int(env.get("LOCAL_RANK", 0))
    rank = int(env.get("RANK", 0))

    print(
        f"Inference on multiple gpus, this gpu {local_rank}"
        f", rank {rank}, world_size {world_size}"
    )

    torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl")
    return world_size, local_rank, rank
|
|
|
|
|
|
@torch.no_grad()
def _synthesize_batch(batch, model, flow_matching_model, vocoder, device):
    """Run LLM -> flow matching -> vocoder on one collated batch.

    Runs under ``torch.no_grad()`` so no autograd graph is recorded (the
    vocoder forward pass would otherwise build one).

    Returns:
        list of ``(utt_id, waveform)`` pairs; each waveform is on the CPU.
    """
    generate_codes = model.inference_batch(
        batch["input_ids"], batch["attention_mask"]
    )

    # The length of interpolate_tokens' output is used as the total mel length
    # (prompt + generated), so presumably it yields one token per mel frame --
    # TODO confirm against train.interpolate_tokens.
    flow_matching_input_tokens = [interpolate_tokens(code) for code in generate_codes]
    total_mel_lens = torch.tensor(
        [len(tokens) for tokens in flow_matching_input_tokens], dtype=torch.long
    ).to(device)
    ref_mels = batch["ref_mel_batch"].to(device)
    ref_mel_lens = batch["ref_mel_len_batch"].to(device)

    # Right-pad the token sequences with -1 to a common length.
    max_len = max(len(tokens) for tokens in flow_matching_input_tokens)
    padded_tokens = torch.stack(
        [
            torch.tensor(tokens + [-1] * (max_len - len(tokens)), dtype=torch.long)
            for tokens in flow_matching_input_tokens
        ]
    ).to(device)

    generated, _ = flow_matching_model.sample(
        cond=ref_mels,
        text=padded_tokens,
        duration=total_mel_lens,
        lens=ref_mel_lens,
        steps=16,
        cfg_strength=2.0,
        sway_sampling_coef=-1,
        no_ref_audio=False,
        seed=0,
    )

    results = []
    for utt, gen, ref_len, total_len in zip(
        batch["ids"], generated, ref_mel_lens, total_mel_lens
    ):
        # Keep frames [ref_len, total_len): drop the reference-prompt region at
        # the start and any padding past this utterance's total length.
        gen_mel_spec = gen[ref_len:total_len, :].unsqueeze(0).permute(0, 2, 1)
        waveform = vocoder(gen_mel_spec.to(torch.float32)).squeeze(0).cpu()
        results.append((utt, waveform))
    return results


def main():
    """Run distributed batched TTS inference and save one wav per utterance.

    Each rank loads the LLM (LLMTTS), the BigVGAN vocoder and the flow matching
    model, synthesizes its shard of the ``yuekai/seed_tts_cosy2`` split chosen
    by ``--split-name``, and writes ``<output_dir>/<id>.wav`` at 24 kHz.

    Raises:
        RuntimeError: if no CUDA device is available.
    """
    # Local import: a functools.partial collate_fn is picklable (unlike a
    # lambda), so DataLoader workers also work with the "spawn" start method.
    from functools import partial

    args = get_args()
    os.makedirs(args.output_dir, exist_ok=True)

    if not torch.cuda.is_available():
        raise RuntimeError(
            "CUDA is required: inference uses an NCCL process group and one GPU per rank"
        )
    world_size, local_rank, rank = init_distributed()
    device = torch.device(f"cuda:{local_rank}")

    model = LLMTTS(
        model_dir=args.llm_model_name_or_path,
        tokenizer_dir=args.tokenizer_dir,
        s3_tokenizer_name=args.s3_tokenizer_name,
        device=device,
    )

    vocoder = BigVGANInference.from_pretrained(args.vocoder_dir, use_cuda_kernel=False)
    vocoder = vocoder.eval().to(device)

    flow_matching_model = get_model(args).eval().to(device)
    _ = load_checkpoint(
        args.flow_matching_model_path,
        model=flow_matching_model,
    )

    dataset = load_dataset(
        "yuekai/seed_tts_cosy2",
        split=args.split_name,
        trust_remote_code=True,
    )

    # Order does not matter for inference; shuffle=False keeps the assignment
    # of utterances to ranks deterministic and in dataset order.
    sampler = DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=False
    )

    target_sample_rate = 24_000
    mel_spectrogram = MelSpec(
        n_fft=1024,
        hop_length=256,
        win_length=1024,
        n_mel_channels=100,
        target_sample_rate=target_sample_rate,
        mel_spec_type="bigvgan",
    )

    loader_kwargs = {}
    if args.num_workers > 0:
        # DataLoader rejects prefetch_factor when there are no worker
        # processes (num_workers=0), so only pass it when workers exist.
        loader_kwargs["prefetch_factor"] = args.prefetch
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=sampler,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=partial(
            data_collator, tokenizer=model.tokenizer, mel_spectrogram=mel_spectrogram
        ),
        **loader_kwargs,
    )

    # DistributedSampler pads every rank to len(sampler) samples, so this total
    # matches the sum of rank 0's world_size-scaled updates below.
    progress_bar = tqdm(
        total=len(sampler) * world_size,
        desc="Processing",
        unit="wavs",
        disable=rank != 0,
    )

    for batch in dataloader:
        for utt, waveform in _synthesize_batch(
            batch, model, flow_matching_model, vocoder, device
        ):
            torchaudio.save(
                os.path.join(args.output_dir, f"{utt}.wav"),
                waveform,
                target_sample_rate,
            )
        progress_bar.update(world_size * len(batch["ids"]))

    progress_bar.close()

    dist.barrier()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
|