diff --git a/egs/vctk/TTS/vits/infer.py b/egs/vctk/TTS/vits/infer.py
index 42492da79..95153a533 100755
--- a/egs/vctk/TTS/vits/infer.py
+++ b/egs/vctk/TTS/vits/infer.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 #
-# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
+# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao,
+#                                            Zengrui Jin,)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -30,7 +31,7 @@ import argparse
 import logging
 from concurrent.futures import ThreadPoolExecutor
 from pathlib import Path
-from typing import List
+from typing import Dict, List
 
 import k2
 import torch
@@ -80,6 +81,7 @@ def infer_dataset(
     params: AttributeDict,
     model: nn.Module,
     tokenizer: Tokenizer,
+    speaker_map: Dict[str, int],
 ) -> None:
     """Decode dataset.
     The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`.
@@ -139,13 +141,20 @@ def infer_dataset(
         tokens_lens = tokens_lens.to(device)
         # tensor of shape (B, T)
         tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
+        speakers = (
+            torch.Tensor([speaker_map[sid] for sid in batch["speakers"]])
+            .int()
+            .to(device)
+        )
 
         audio = batch["audio"]
         audio_lens = batch["audio_lens"].tolist()
         cut_ids = [cut.id for cut in batch["cut"]]
 
         audio_pred, _, durations = model.inference_batch(
-            text=tokens, text_lengths=tokens_lens
+            text=tokens,
+            text_lengths=tokens_lens,
+            sids=speakers,
         )
         audio_pred = audio_pred.detach().cpu()
         # convert to samples
@@ -206,6 +215,12 @@ def main():
     params.oov_id = tokenizer.oov_id
     params.vocab_size = tokenizer.vocab_size
 
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
+    vctk = VctkTtsDataModule(args)
+    speaker_map = vctk.speakers()
+    params.num_spks = len(speaker_map)
+
     logging.info(f"Device: {device}")
     logging.info(params)
 
@@ -223,18 +238,15 @@ def main():
     logging.info(f"Number of parameters in discriminator: {num_param_d}")
     logging.info(f"Total number of parameters: {num_param_g + num_param_d}")
 
-    # we need cut ids to display recognition results.
-    args.return_cuts = True
-    ljspeech = VctkTtsDataModule(args)
-
-    test_cuts = ljspeech.test_cuts()
-    test_dl = ljspeech.test_dataloaders(test_cuts)
+    test_cuts = vctk.test_cuts()
+    test_dl = vctk.test_dataloaders(test_cuts)
 
     infer_dataset(
         dl=test_dl,
         params=params,
         model=model,
         tokenizer=tokenizer,
+        speaker_map=speaker_map,
     )
 
     logging.info(f"Wav files are saved to {params.save_wav_dir}")
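
For reference, here is a minimal, self-contained sketch of the speaker-id handling this patch introduces: per-utterance string speaker labels from the batch are looked up in the `speaker_map` returned by the data module and converted to an int32 tensor that is passed to the generator as `sids`. The helper name `speakers_to_sids` and the example speaker labels are illustrative, not part of the patch; the sketch also uses the idiomatic `torch.tensor(...)` in place of the patch's `torch.Tensor(...).int()` construction, which produces the same int32 result.

from typing import Dict, List

import torch


def speakers_to_sids(
    batch_speakers: List[str],
    speaker_map: Dict[str, int],
    device: torch.device,
) -> torch.Tensor:
    """Map string speaker labels to an int32 tensor of speaker indices,
    mirroring the `speakers` tensor built in infer_dataset() above."""
    return torch.tensor(
        [speaker_map[sid] for sid in batch_speakers]
    ).int().to(device)


if __name__ == "__main__":
    # Hypothetical speaker inventory, standing in for vctk.speakers().
    speaker_map = {"p225": 0, "p226": 1, "p227": 2}
    batch_speakers = ["p226", "p225", "p226"]

    sids = speakers_to_sids(batch_speakers, speaker_map, torch.device("cpu"))
    print(sids)  # tensor([1, 0, 1], dtype=torch.int32)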