mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-05 15:14:18 +00:00
Update infer.py
This commit is contained in:
parent
cf7ad8131d
commit
b7efcbf154
@@ -1,6 +1,7 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
#
|
#
|
||||||
# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
|
# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao,
|
||||||
|
# Zengrui Jin,)
|
||||||
#
|
#
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
#
|
#
|
||||||
@@ -30,7 +31,7 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import Dict, List
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
import torch
|
import torch
|
||||||
@@ -80,6 +81,7 @@ def infer_dataset(
|
|||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
tokenizer: Tokenizer,
|
tokenizer: Tokenizer,
|
||||||
|
speaker_map: Dict[str, int],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Decode dataset.
|
"""Decode dataset.
|
||||||
The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`.
|
The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`.
|
||||||
@@ -139,13 +141,20 @@ def infer_dataset(
|
|||||||
tokens_lens = tokens_lens.to(device)
|
tokens_lens = tokens_lens.to(device)
|
||||||
# tensor of shape (B, T)
|
# tensor of shape (B, T)
|
||||||
tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
|
tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
|
||||||
|
speakers = (
|
||||||
|
torch.Tensor([speaker_map[sid] for sid in batch["speakers"]])
|
||||||
|
.int()
|
||||||
|
.to(device)
|
||||||
|
)
|
||||||
|
|
||||||
audio = batch["audio"]
|
audio = batch["audio"]
|
||||||
audio_lens = batch["audio_lens"].tolist()
|
audio_lens = batch["audio_lens"].tolist()
|
||||||
cut_ids = [cut.id for cut in batch["cut"]]
|
cut_ids = [cut.id for cut in batch["cut"]]
|
||||||
|
|
||||||
audio_pred, _, durations = model.inference_batch(
|
audio_pred, _, durations = model.inference_batch(
|
||||||
text=tokens, text_lengths=tokens_lens
|
text=tokens,
|
||||||
|
text_lengths=tokens_lens,
|
||||||
|
sids=speakers,
|
||||||
)
|
)
|
||||||
audio_pred = audio_pred.detach().cpu()
|
audio_pred = audio_pred.detach().cpu()
|
||||||
# convert to samples
|
# convert to samples
|
||||||
@@ -206,6 +215,12 @@ def main():
|
|||||||
params.oov_id = tokenizer.oov_id
|
params.oov_id = tokenizer.oov_id
|
||||||
params.vocab_size = tokenizer.vocab_size
|
params.vocab_size = tokenizer.vocab_size
|
||||||
|
|
||||||
|
# we need cut ids to display recognition results.
|
||||||
|
args.return_cuts = True
|
||||||
|
vctk = VctkTtsDataModule(args)
|
||||||
|
speaker_map = vctk.speakers()
|
||||||
|
params.num_spks = len(speaker_map)
|
||||||
|
|
||||||
logging.info(f"Device: {device}")
|
logging.info(f"Device: {device}")
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
@@ -223,18 +238,15 @@ def main():
|
|||||||
logging.info(f"Number of parameters in discriminator: {num_param_d}")
|
logging.info(f"Number of parameters in discriminator: {num_param_d}")
|
||||||
logging.info(f"Total number of parameters: {num_param_g + num_param_d}")
|
logging.info(f"Total number of parameters: {num_param_g + num_param_d}")
|
||||||
|
|
||||||
# we need cut ids to display recognition results.
|
test_cuts = vctk.test_cuts()
|
||||||
args.return_cuts = True
|
test_dl = vctk.test_dataloaders(test_cuts)
|
||||||
ljspeech = VctkTtsDataModule(args)
|
|
||||||
|
|
||||||
test_cuts = ljspeech.test_cuts()
|
|
||||||
test_dl = ljspeech.test_dataloaders(test_cuts)
|
|
||||||
|
|
||||||
infer_dataset(
|
infer_dataset(
|
||||||
dl=test_dl,
|
dl=test_dl,
|
||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
|
speaker_map=speaker_map,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info(f"Wav files are saved to {params.save_wav_dir}")
|
logging.info(f"Wav files are saved to {params.save_wav_dir}")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user