Update infer.py

jinzr 2023-12-01 00:08:23 +08:00
parent cf7ad8131d
commit b7efcbf154

@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 #
-# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
+# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao,
+#                                            Zengrui Jin,)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -30,7 +31,7 @@ import argparse
 import logging
 from concurrent.futures import ThreadPoolExecutor
 from pathlib import Path
-from typing import List
+from typing import Dict, List
 
 import k2
 import torch
@@ -80,6 +81,7 @@ def infer_dataset(
     params: AttributeDict,
     model: nn.Module,
     tokenizer: Tokenizer,
+    speaker_map: Dict[str, int],
 ) -> None:
     """Decode dataset.
     The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`.
@@ -139,13 +141,20 @@
             tokens_lens = tokens_lens.to(device)
             # tensor of shape (B, T)
             tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
+            speakers = (
+                torch.Tensor([speaker_map[sid] for sid in batch["speakers"]])
+                .int()
+                .to(device)
+            )
 
             audio = batch["audio"]
             audio_lens = batch["audio_lens"].tolist()
             cut_ids = [cut.id for cut in batch["cut"]]
 
             audio_pred, _, durations = model.inference_batch(
-                text=tokens, text_lengths=tokens_lens
+                text=tokens,
+                text_lengths=tokens_lens,
+                sids=speakers,
             )
             audio_pred = audio_pred.detach().cpu()
             # convert to samples
@@ -206,6 +215,12 @@ def main():
     params.oov_id = tokenizer.oov_id
     params.vocab_size = tokenizer.vocab_size
 
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
+    vctk = VctkTtsDataModule(args)
+
+    speaker_map = vctk.speakers()
+    params.num_spks = len(speaker_map)
 
     logging.info(f"Device: {device}")
     logging.info(params)
@@ -223,18 +238,15 @@
     logging.info(f"Number of parameters in discriminator: {num_param_d}")
     logging.info(f"Total number of parameters: {num_param_g + num_param_d}")
 
-    # we need cut ids to display recognition results.
-    args.return_cuts = True
-    ljspeech = VctkTtsDataModule(args)
-
-    test_cuts = ljspeech.test_cuts()
-    test_dl = ljspeech.test_dataloaders(test_cuts)
+    test_cuts = vctk.test_cuts()
+    test_dl = vctk.test_dataloaders(test_cuts)
 
     infer_dataset(
         dl=test_dl,
         params=params,
         model=model,
         tokenizer=tokenizer,
+        speaker_map=speaker_map,
     )
 
     logging.info(f"Wav files are saved to {params.save_wav_dir}")