Update infer.py

This commit is contained in:
jinzr 2023-12-04 09:58:00 +08:00
parent 52fe6f9bfd
commit ddecda0c9b

View File

@ -78,6 +78,7 @@ def get_parser():
def infer_dataset(
dl: torch.utils.data.DataLoader,
subset: str,
params: AttributeDict,
model: nn.Module,
tokenizer: Tokenizer,
@ -99,6 +100,7 @@ def infer_dataset(
# Background worker save audios to disk.
def _save_worker(
subset: str,
batch_size: int,
cut_ids: List[str],
audio: torch.Tensor,
@ -108,12 +110,12 @@ def infer_dataset(
):
for i in range(batch_size):
torchaudio.save(
str(params.save_wav_dir / f"{cut_ids[i]}_gt.wav"),
str(params.save_wav_dir / subset / f"{cut_ids[i]}_gt.wav"),
audio[i : i + 1, : audio_lens[i]],
sample_rate=params.sampling_rate,
)
torchaudio.save(
str(params.save_wav_dir / f"{cut_ids[i]}_pred.wav"),
str(params.save_wav_dir / subset / f"{cut_ids[i]}_pred.wav"),
audio_pred[i : i + 1, : audio_lens_pred[i]],
sample_rate=params.sampling_rate,
)
@ -165,6 +167,7 @@ def infer_dataset(
futures.append(
executor.submit(
_save_worker,
subset,
batch_size,
cut_ids,
audio,
@ -241,13 +244,25 @@ def main():
test_cuts = vctk.test_cuts()
test_dl = vctk.test_dataloaders(test_cuts)
infer_dataset(
dl=test_dl,
params=params,
model=model,
tokenizer=tokenizer,
speaker_map=speaker_map,
)
valid_cuts = vctk.valid_cuts()
valid_dl = vctk.valid_dataloaders(valid_cuts)
infer_sets = {"test", test_dl, "valid", valid_dl}
for subset, dl in infer_sets.items():
save_wav_dir = params.res_dir / "wav" / subset
save_wav_dir.mkdir(parents=True, exist_ok=True)
logging.info(f"Processing {subset} set, saving to {save_wav_dir}")
infer_dataset(
dl=dl,
subset=subset,
params=params,
model=model,
tokenizer=tokenizer,
speaker_map=speaker_map,
)
logging.info(f"Wav files are saved to {params.save_wav_dir}")
logging.info("Done!")