icefall/egs/vctk/TTS/vits/infer.py
zr_jin 735fb9a73d
A TTS recipe VITS on VCTK dataset (#1380)
* init

* isort formatted

* minor updates

* Create shared

* Update prepare_tokens_vctk.py

* Update prepare_tokens_vctk.py

* Update prepare_tokens_vctk.py

* Update prepare.sh

* updated

* Update train.py

* Update train.py

* Update tts_datamodule.py

* Update train.py

* Update train.py

* Update train.py

* Update train.py

* Update train.py

* Update train.py

* fixed formatting issue

* Update infer.py

* removed redundant files

* Create monotonic_align

* removed redundant files

* created symlinks

* Update prepare.sh

* minor adjustments

* Create requirements_tts.txt

* Update requirements_tts.txt

added version constraints

* Update infer.py

* Update infer.py

* Update infer.py

* updated docs

* Update export-onnx.py

* Update export-onnx.py

* Update test_onnx.py

* updated requirements.txt

* Update test_onnx.py

* Update test_onnx.py

* docs updated

* docs fixed

* minor updates
2023-12-06 09:59:19 +08:00

273 lines
7.7 KiB
Python
Executable File

#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao,
# Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script performs model inference on the test and validation sets.
Usage:
./vits/infer.py \
--epoch 1000 \
--exp-dir ./vits/exp \
--max-duration 500
"""
import argparse
import logging
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Dict, List
import k2
import torch
import torch.nn as nn
import torchaudio
from tokenizer import Tokenizer
from train import get_model, get_params
from tts_datamodule import VctkTtsDataModule
from icefall.checkpoint import load_checkpoint
from icefall.utils import AttributeDict, setup_logger
def get_parser():
    """Build the command-line parser for this inference script.

    Only the script-specific options are registered here; the data-module
    options are added separately by the caller.

    Returns:
      An ``argparse.ArgumentParser`` that shows default values in ``--help``.
    """
    arg_parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # Which checkpoint (epoch-<N>.pt) to load from the experiment dir.
    arg_parser.add_argument(
        "--epoch",
        type=int,
        default=1000,
        help=(
            "It specifies the checkpoint to use for decoding.\n"
            "Note: Epoch counts from 1.\n"
        ),
    )

    arg_parser.add_argument(
        "--exp-dir",
        type=str,
        default="vits/exp",
        help="The experiment dir",
    )

    arg_parser.add_argument(
        "--tokens",
        type=str,
        default="data/tokens.txt",
        help="Path to vocabulary.",
    )

    return arg_parser
def infer_dataset(
    dl: torch.utils.data.DataLoader,
    subset: str,
    params: AttributeDict,
    model: nn.Module,
    tokenizer: Tokenizer,
    speaker_map: Dict[str, int],
) -> None:
    """Synthesize audio for every cut in a dataset and save it to disk.

    For each cut, the ground-truth waveform and the generated waveform are
    written as ``<cut_id>_gt.wav`` and ``<cut_id>_pred.wav`` under
    ``params.save_wav_dir / subset``. That directory is created if needed.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to synthesize.
      subset:
        Name of the subset being processed (e.g. "test" or "valid"). It is
        used as the sub-directory of `params.save_wav_dir` for the wav files.
      params:
        It is returned by :func:`get_params`. This function reads
        `save_wav_dir`, `sampling_rate` and `frame_shift` from it.
      model:
        The neural model. Its `inference_batch` method produces the audio.
      tokenizer:
        Used to convert the batch's text tokens to token ids.
      speaker_map:
        Maps each speaker id in `batch["speakers"]` to the integer speaker
        index passed to the model as `sids`.
    """
    save_dir = params.save_wav_dir / subset
    # Create the output directory here so this function does not depend on the
    # caller having done it. Otherwise torchaudio.save would fail inside the
    # worker thread, and that would only surface at the very end.
    save_dir.mkdir(parents=True, exist_ok=True)

    # Background worker that saves audios to disk, so file writing overlaps
    # with inference on the following batches.
    def _save_worker(
        batch_size: int,
        cut_ids: List[str],
        audio: torch.Tensor,
        audio_pred: torch.Tensor,
        audio_lens: List[int],
        audio_lens_pred: List[int],
    ) -> None:
        for i in range(batch_size):
            torchaudio.save(
                str(save_dir / f"{cut_ids[i]}_gt.wav"),
                audio[i : i + 1, : audio_lens[i]],
                sample_rate=params.sampling_rate,
            )
            torchaudio.save(
                str(save_dir / f"{cut_ids[i]}_pred.wav"),
                audio_pred[i : i + 1, : audio_lens_pred[i]],
                sample_rate=params.sampling_rate,
            )

    device = next(model.parameters()).device
    num_cuts = 0
    log_interval = 5

    try:
        num_batches = len(dl)
    except TypeError:
        # The dataloader may not define a length; log "?" instead.
        num_batches = "?"

    futures = []
    with ThreadPoolExecutor(max_workers=1) as executor:
        for batch_idx, batch in enumerate(dl):
            batch_size = len(batch["tokens"])

            tokens = tokenizer.tokens_to_token_ids(batch["tokens"])
            tokens = k2.RaggedTensor(tokens)
            row_splits = tokens.shape.row_splits(1)
            tokens_lens = row_splits[1:] - row_splits[:-1]
            tokens = tokens.to(device)
            tokens_lens = tokens_lens.to(device)
            # tensor of shape (B, T)
            tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)

            # Build the int32 speaker-index tensor directly on the target
            # device, instead of going through a float32 torch.Tensor and
            # casting with .int().
            speakers = torch.tensor(
                [speaker_map[sid] for sid in batch["speakers"]],
                dtype=torch.int32,
                device=device,
            )

            audio = batch["audio"]
            audio_lens = batch["audio_lens"].tolist()
            cut_ids = [cut.id for cut in batch["cut"]]

            audio_pred, _, durations = model.inference_batch(
                text=tokens,
                text_lengths=tokens_lens,
                sids=speakers,
            )
            audio_pred = audio_pred.detach().cpu()

            # Summed durations times frame_shift gives the predicted length
            # of each utterance in samples.
            audio_lens_pred = (
                (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist()
            )

            futures.append(
                executor.submit(
                    _save_worker,
                    batch_size,
                    cut_ids,
                    audio,
                    audio_pred,
                    audio_lens,
                    audio_lens_pred,
                )
            )

            num_cuts += batch_size

            if batch_idx % log_interval == 0:
                batch_str = f"{batch_idx}/{num_batches}"
                logging.info(
                    f"batch {batch_str}, cuts processed until now is {num_cuts}"
                )

    # Re-raise any exception that happened inside the save worker.
    for f in futures:
        f.result()
# Inference only: no autograd graph is needed.
@torch.no_grad()
def main():
    """Synthesize the VCTK test and valid sets with a trained VITS checkpoint.

    Loads ``<exp-dir>/epoch-<N>.pt`` and runs :func:`infer_dataset` on each
    subset. Ground-truth/predicted wav pairs are written under
    ``<exp-dir>/infer/epoch-<N>/wav/<subset>``, and the log goes to
    ``<exp-dir>/infer/epoch-<N>/log-infer-epoch-<N>``.
    """
    arg_parser = get_parser()
    VctkTtsDataModule.add_arguments(arg_parser)
    cli_args = arg_parser.parse_args()
    cli_args.exp_dir = Path(cli_args.exp_dir)

    params = get_params()
    params.update(vars(cli_args))

    # Output layout: <exp_dir>/infer/epoch-<N>/{wav/, log-infer-*}
    params.suffix = f"epoch-{params.epoch}"
    params.res_dir = params.exp_dir / "infer" / params.suffix
    params.save_wav_dir = params.res_dir / "wav"
    params.save_wav_dir.mkdir(parents=True, exist_ok=True)

    setup_logger(f"{params.res_dir}/log-infer-{params.suffix}")
    logging.info("Infer started")

    device = (
        torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
    )

    tokenizer = Tokenizer(params.tokens)
    params.blank_id = tokenizer.blank_id
    params.oov_id = tokenizer.oov_id
    params.vocab_size = tokenizer.vocab_size

    # Batches must carry their cuts: the cut ids name the output wav files.
    cli_args.return_cuts = True
    vctk = VctkTtsDataModule(cli_args)

    speaker_map = vctk.speakers()
    params.num_spks = len(speaker_map)

    logging.info(f"Device: {device}")
    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)
    load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    model.to(device)
    model.eval()

    n_gen = sum(p.numel() for p in model.generator.parameters())
    logging.info(f"Number of parameters in generator: {n_gen}")
    n_disc = sum(p.numel() for p in model.discriminator.parameters())
    logging.info(f"Number of parameters in discriminator: {n_disc}")
    logging.info(f"Total number of parameters: {n_gen + n_disc}")

    # Dict literal entries evaluate in order: test cuts/dataloader, then valid.
    infer_sets = {
        "test": vctk.test_dataloaders(vctk.test_cuts()),
        "valid": vctk.valid_dataloaders(vctk.valid_cuts()),
    }

    for subset, dl in infer_sets.items():
        subset_dir = params.save_wav_dir / subset
        subset_dir.mkdir(parents=True, exist_ok=True)
        logging.info(f"Processing {subset} set, saving to {subset_dir}")
        infer_dataset(
            dl=dl,
            subset=subset,
            params=params,
            model=model,
            tokenizer=tokenizer,
            speaker_map=speaker_map,
        )

    logging.info(f"Wav files are saved to {params.save_wav_dir}")
    logging.info("Done!")


if __name__ == "__main__":
    main()