Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-03 22:24:19 +00:00
Update infer.py
commit b7efcbf154
parent cf7ad8131d
--- a/infer.py
+++ b/infer.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 #
-# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
+# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao,
+#                                            Zengrui Jin,)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -30,7 +31,7 @@ import argparse
 import logging
 from concurrent.futures import ThreadPoolExecutor
 from pathlib import Path
-from typing import List
+from typing import Dict, List
 
 import k2
 import torch
@@ -80,6 +81,7 @@ def infer_dataset(
     params: AttributeDict,
     model: nn.Module,
     tokenizer: Tokenizer,
+    speaker_map: Dict[str, int],
 ) -> None:
     """Decode dataset.
     The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`.
@@ -139,13 +141,20 @@ def infer_dataset(
         tokens_lens = tokens_lens.to(device)
         # tensor of shape (B, T)
         tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
+        speakers = (
+            torch.Tensor([speaker_map[sid] for sid in batch["speakers"]])
+            .int()
+            .to(device)
+        )
 
         audio = batch["audio"]
         audio_lens = batch["audio_lens"].tolist()
         cut_ids = [cut.id for cut in batch["cut"]]
 
         audio_pred, _, durations = model.inference_batch(
-            text=tokens, text_lengths=tokens_lens
+            text=tokens,
+            text_lengths=tokens_lens,
+            sids=speakers,
         )
         audio_pred = audio_pred.detach().cpu()
         # convert to samples
@@ -206,6 +215,12 @@ def main():
     params.oov_id = tokenizer.oov_id
     params.vocab_size = tokenizer.vocab_size
 
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
+    vctk = VctkTtsDataModule(args)
+    speaker_map = vctk.speakers()
+    params.num_spks = len(speaker_map)
+
     logging.info(f"Device: {device}")
     logging.info(params)
 
@@ -223,18 +238,15 @@ def main():
     logging.info(f"Number of parameters in discriminator: {num_param_d}")
     logging.info(f"Total number of parameters: {num_param_g + num_param_d}")
 
-    # we need cut ids to display recognition results.
-    args.return_cuts = True
-    ljspeech = VctkTtsDataModule(args)
-
-    test_cuts = ljspeech.test_cuts()
-    test_dl = ljspeech.test_dataloaders(test_cuts)
+    test_cuts = vctk.test_cuts()
+    test_dl = vctk.test_dataloaders(test_cuts)
 
     infer_dataset(
         dl=test_dl,
         params=params,
         model=model,
         tokenizer=tokenizer,
+        speaker_map=speaker_map,
     )
 
     logging.info(f"Wav files are saved to {params.save_wav_dir}")
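In short, the commit threads multi-speaker support through the VCTK infer.py recipe: infer_dataset now takes a speaker_map (Dict[str, int]), builds a per-utterance tensor of integer speaker IDs, and passes it to model.inference_batch as sids; main() now reads the map from vctk.speakers() and sets params.num_spks, and the leftover ljspeech variable name is corrected to vctk. Below is a minimal standalone sketch of just the speaker-ID lookup pattern the diff adds; the map contents and batch labels are made up for illustration, standing in for whatever vctk.speakers() and batch["speakers"] actually contain.

    import torch

    # Hypothetical speaker map, standing in for vctk.speakers(); the diff
    # only tells us it maps speaker names to integer IDs.
    speaker_map = {"p225": 0, "p226": 1, "p228": 2}

    # Per-utterance speaker labels, as they would appear in batch["speakers"].
    batch_speakers = ["p226", "p225", "p226"]

    # Same pattern as the added lines: look up each label, then cast to an
    # int32 tensor suitable for the model's `sids` argument.
    speakers = torch.Tensor([speaker_map[sid] for sid in batch_speakers]).int()
    print(speakers)  # tensor([1, 0, 1], dtype=torch.int32)

Casting with .int() matters because the IDs index an embedding table, which expects integer indices rather than the float values torch.Tensor produces by default.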