Update infer.py: look up per-cut speaker IDs and pass them to model.inference_batch for multi-speaker VCTK inference

Author: jinzr
Date:   2023-12-01 00:08:23 +08:00
Parent: cf7ad8131d
Commit: b7efcbf154


@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 #
-# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
+# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao,
+#                                            Zengrui Jin,)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -30,7 +31,7 @@ import argparse
 import logging
 from concurrent.futures import ThreadPoolExecutor
 from pathlib import Path
-from typing import List
+from typing import Dict, List

 import k2
 import torch
@@ -80,6 +81,7 @@ def infer_dataset(
     params: AttributeDict,
     model: nn.Module,
     tokenizer: Tokenizer,
+    speaker_map: Dict[str, int],
 ) -> None:
     """Decode dataset.
     The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`.
@@ -139,13 +141,20 @@
         tokens_lens = tokens_lens.to(device)
         # tensor of shape (B, T)
         tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
+        speakers = (
+            torch.Tensor([speaker_map[sid] for sid in batch["speakers"]])
+            .int()
+            .to(device)
+        )

         audio = batch["audio"]
         audio_lens = batch["audio_lens"].tolist()
         cut_ids = [cut.id for cut in batch["cut"]]

         audio_pred, _, durations = model.inference_batch(
-            text=tokens, text_lengths=tokens_lens
+            text=tokens,
+            text_lengths=tokens_lens,
+            sids=speakers,
         )
         audio_pred = audio_pred.detach().cpu()
         # convert to samples
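For context on the added block: the new `speakers` tensor converts each cut's string speaker ID into the integer index that `inference_batch` expects via `sids`. A minimal sketch of the same conversion, using a toy `speaker_map` and batch (both hypothetical, for illustration only):

    import torch

    # hypothetical stand-ins for the real speaker_map and mini-batch
    speaker_map = {"p225": 0, "p226": 1}
    batch = {"speakers": ["p226", "p225", "p226"]}

    # same construction as in the diff: string IDs -> int32 tensor
    speakers = torch.Tensor(
        [speaker_map[sid] for sid in batch["speakers"]]
    ).int()
    print(speakers)  # tensor([1, 0, 1], dtype=torch.int32)

Note that `torch.Tensor([...])` first builds a float32 tensor and `.int()` casts it; `torch.tensor([...], dtype=torch.int32)` would produce the same result in one step.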
@@ -206,6 +215,12 @@ def main():
     params.oov_id = tokenizer.oov_id
     params.vocab_size = tokenizer.vocab_size

+    # we need cut ids to display recognition results.
+    args.return_cuts = True
+    vctk = VctkTtsDataModule(args)
+
+    speaker_map = vctk.speakers()
+    params.num_spks = len(speaker_map)

     logging.info(f"Device: {device}")
     logging.info(params)
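The new `infer_dataset` signature types `speaker_map` as `Dict[str, int]`, so `vctk.speakers()` is expected to return such a mapping. A sketch of one common way to derive it, assuming only an iterable of VCTK speaker IDs (the helper name is hypothetical, not the data module's actual implementation):

    from typing import Dict, Iterable

    def build_speaker_map(speaker_ids: Iterable[str]) -> Dict[str, int]:
        # assign each unique speaker ID a stable integer index
        return {sid: i for i, sid in enumerate(sorted(set(speaker_ids)))}

    # e.g. build_speaker_map(["p226", "p225"]) -> {"p225": 0, "p226": 1}

Sorting before enumerating keeps the indices stable across runs, which matters if the model's speaker embedding table is sized and indexed from this map, as `params.num_spks = len(speaker_map)` suggests.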
@ -223,18 +238,15 @@ def main():
logging.info(f"Number of parameters in discriminator: {num_param_d}")
logging.info(f"Total number of parameters: {num_param_g + num_param_d}")
# we need cut ids to display recognition results.
args.return_cuts = True
ljspeech = VctkTtsDataModule(args)
test_cuts = ljspeech.test_cuts()
test_dl = ljspeech.test_dataloaders(test_cuts)
test_cuts = vctk.test_cuts()
test_dl = vctk.test_dataloaders(test_cuts)
infer_dataset(
dl=test_dl,
params=params,
model=model,
tokenizer=tokenizer,
speaker_map=speaker_map,
)
logging.info(f"Wav files are saved to {params.save_wav_dir}")