Update infer.py

This commit is contained in:
jinzr 2023-12-04 09:58:00 +08:00
parent 52fe6f9bfd
commit ddecda0c9b

View File

@ -78,6 +78,7 @@ def get_parser():
def infer_dataset( def infer_dataset(
dl: torch.utils.data.DataLoader, dl: torch.utils.data.DataLoader,
subset: str,
params: AttributeDict, params: AttributeDict,
model: nn.Module, model: nn.Module,
tokenizer: Tokenizer, tokenizer: Tokenizer,
@ -99,6 +100,7 @@ def infer_dataset(
# Background worker that saves audio to disk. # Background worker that saves audio to disk.
def _save_worker( def _save_worker(
subset: str,
batch_size: int, batch_size: int,
cut_ids: List[str], cut_ids: List[str],
audio: torch.Tensor, audio: torch.Tensor,
@ -108,12 +110,12 @@ def infer_dataset(
): ):
for i in range(batch_size): for i in range(batch_size):
torchaudio.save( torchaudio.save(
str(params.save_wav_dir / f"{cut_ids[i]}_gt.wav"), str(params.save_wav_dir / subset / f"{cut_ids[i]}_gt.wav"),
audio[i : i + 1, : audio_lens[i]], audio[i : i + 1, : audio_lens[i]],
sample_rate=params.sampling_rate, sample_rate=params.sampling_rate,
) )
torchaudio.save( torchaudio.save(
str(params.save_wav_dir / f"{cut_ids[i]}_pred.wav"), str(params.save_wav_dir / subset / f"{cut_ids[i]}_pred.wav"),
audio_pred[i : i + 1, : audio_lens_pred[i]], audio_pred[i : i + 1, : audio_lens_pred[i]],
sample_rate=params.sampling_rate, sample_rate=params.sampling_rate,
) )
@ -165,6 +167,7 @@ def infer_dataset(
futures.append( futures.append(
executor.submit( executor.submit(
_save_worker, _save_worker,
subset,
batch_size, batch_size,
cut_ids, cut_ids,
audio, audio,
@ -241,13 +244,25 @@ def main():
test_cuts = vctk.test_cuts() test_cuts = vctk.test_cuts()
test_dl = vctk.test_dataloaders(test_cuts) test_dl = vctk.test_dataloaders(test_cuts)
infer_dataset( valid_cuts = vctk.valid_cuts()
dl=test_dl, valid_dl = vctk.valid_dataloaders(valid_cuts)
params=params,
model=model, infer_sets = {"test", test_dl, "valid", valid_dl}
tokenizer=tokenizer,
speaker_map=speaker_map, for subset, dl in infer_sets.items():
) save_wav_dir = params.res_dir / "wav" / subset
save_wav_dir.mkdir(parents=True, exist_ok=True)
logging.info(f"Processing {subset} set, saving to {save_wav_dir}")
infer_dataset(
dl=dl,
subset=subset,
params=params,
model=model,
tokenizer=tokenizer,
speaker_map=speaker_map,
)
logging.info(f"Wav files are saved to {params.save_wav_dir}") logging.info(f"Wav files are saved to {params.save_wav_dir}")
logging.info("Done!") logging.info("Done!")