From ddecda0c9b216bef9c22b52d6f8b76d6f2ddf182 Mon Sep 17 00:00:00 2001
From: jinzr
Date: Mon, 4 Dec 2023 09:58:00 +0800
Subject: [PATCH] Update infer.py

Run inference on both the VCTK test and valid sets instead of the test
set only.

* infer_dataset() and its nested _save_worker() take a new `subset`
  argument; ground-truth and predicted wavs are written to
  params.save_wav_dir / subset / so the sets do not share one directory.
* main() builds a {subset name: dataloader} mapping, creates the
  per-subset output directory, and calls infer_dataset() once per set.

The mapping must be a dict (`{"test": test_dl, "valid": valid_dl}`):
a set literal has no .items() and would raise AttributeError in the loop.
---
 egs/vctk/TTS/vits/infer.py | 33 ++++++++++++++++++++++++---------
 1 file changed, 24 insertions(+), 9 deletions(-)

diff --git a/egs/vctk/TTS/vits/infer.py b/egs/vctk/TTS/vits/infer.py
index 95153a533..17c08ad40 100755
--- a/egs/vctk/TTS/vits/infer.py
+++ b/egs/vctk/TTS/vits/infer.py
@@ -78,6 +78,7 @@ def get_parser():
 
 def infer_dataset(
     dl: torch.utils.data.DataLoader,
+    subset: str,
     params: AttributeDict,
     model: nn.Module,
     tokenizer: Tokenizer,
@@ -99,6 +100,7 @@ def infer_dataset(
 
     # Background worker save audios to disk.
     def _save_worker(
+        subset: str,
         batch_size: int,
         cut_ids: List[str],
         audio: torch.Tensor,
@@ -108,12 +110,12 @@ def infer_dataset(
     ):
         for i in range(batch_size):
             torchaudio.save(
-                str(params.save_wav_dir / f"{cut_ids[i]}_gt.wav"),
+                str(params.save_wav_dir / subset / f"{cut_ids[i]}_gt.wav"),
                 audio[i : i + 1, : audio_lens[i]],
                 sample_rate=params.sampling_rate,
             )
             torchaudio.save(
-                str(params.save_wav_dir / f"{cut_ids[i]}_pred.wav"),
+                str(params.save_wav_dir / subset / f"{cut_ids[i]}_pred.wav"),
                 audio_pred[i : i + 1, : audio_lens_pred[i]],
                 sample_rate=params.sampling_rate,
             )
@@ -165,6 +167,7 @@ def infer_dataset(
             futures.append(
                 executor.submit(
                     _save_worker,
+                    subset,
                     batch_size,
                     cut_ids,
                     audio,
@@ -241,13 +244,25 @@ def main():
     test_cuts = vctk.test_cuts()
     test_dl = vctk.test_dataloaders(test_cuts)
 
-    infer_dataset(
-        dl=test_dl,
-        params=params,
-        model=model,
-        tokenizer=tokenizer,
-        speaker_map=speaker_map,
-    )
+    valid_cuts = vctk.valid_cuts()
+    valid_dl = vctk.valid_dataloaders(valid_cuts)
+
+    infer_sets = {"test": test_dl, "valid": valid_dl}
+
+    for subset, dl in infer_sets.items():
+        save_wav_dir = params.res_dir / "wav" / subset
+        save_wav_dir.mkdir(parents=True, exist_ok=True)
+
+        logging.info(f"Processing {subset} set, saving to {save_wav_dir}")
+
+        infer_dataset(
+            dl=dl,
+            subset=subset,
+            params=params,
+            model=model,
+            tokenizer=tokenizer,
+            speaker_map=speaker_map,
+        )
 
     logging.info(f"Wav files are saved to {params.save_wav_dir}")
     logging.info("Done!")