Update infer.py

This commit is contained in:
jinzr 2023-12-04 09:58:00 +08:00
parent 52fe6f9bfd
commit ddecda0c9b


@@ -78,6 +78,7 @@ def get_parser():
 def infer_dataset(
     dl: torch.utils.data.DataLoader,
+    subset: str,
     params: AttributeDict,
     model: nn.Module,
     tokenizer: Tokenizer,
@@ -99,6 +100,7 @@ def infer_dataset(
     # Background worker save audios to disk.
     def _save_worker(
+        subset: str,
         batch_size: int,
         cut_ids: List[str],
         audio: torch.Tensor,
@@ -108,12 +110,12 @@ def infer_dataset(
     ):
         for i in range(batch_size):
             torchaudio.save(
-                str(params.save_wav_dir / f"{cut_ids[i]}_gt.wav"),
+                str(params.save_wav_dir / subset / f"{cut_ids[i]}_gt.wav"),
                 audio[i : i + 1, : audio_lens[i]],
                 sample_rate=params.sampling_rate,
             )
             torchaudio.save(
-                str(params.save_wav_dir / f"{cut_ids[i]}_pred.wav"),
+                str(params.save_wav_dir / subset / f"{cut_ids[i]}_pred.wav"),
                 audio_pred[i : i + 1, : audio_lens_pred[i]],
                 sample_rate=params.sampling_rate,
             )
@@ -165,6 +167,7 @@ def infer_dataset(
             futures.append(
                 executor.submit(
                     _save_worker,
+                    subset,
                     batch_size,
                     cut_ids,
                     audio,
@@ -241,8 +244,20 @@ def main():
     test_cuts = vctk.test_cuts()
     test_dl = vctk.test_dataloaders(test_cuts)
 
-    infer_dataset(
-        dl=test_dl,
-        params=params,
-        model=model,
-        tokenizer=tokenizer,
+    valid_cuts = vctk.valid_cuts()
+    valid_dl = vctk.valid_dataloaders(valid_cuts)
+
+    infer_sets = {"test": test_dl, "valid": valid_dl}
+
+    for subset, dl in infer_sets.items():
+        save_wav_dir = params.res_dir / "wav" / subset
+        save_wav_dir.mkdir(parents=True, exist_ok=True)
+
+        logging.info(f"Processing {subset} set, saving to {save_wav_dir}")
+
+        infer_dataset(
+            dl=dl,
+            subset=subset,
+            params=params,
+            model=model,
+            tokenizer=tokenizer,
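
Note on the pattern this commit introduces: the new subset argument is threaded from the loop in main() into infer_dataset() and down to the background _save_worker, so ground-truth and predicted wavs for each subset end up under their own directory. Below is a minimal, self-contained sketch of that pattern; the save_batch() and run() helpers, the exp/wav path, the batch layout, and the 22050 Hz sample rate are hypothetical stand-ins for illustration, not the actual icefall API.

from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

import torch
import torchaudio


def save_batch(save_wav_dir: Path, subset: str, cut_ids, audio, audio_lens, sample_rate):
    # Background worker: write every utterance of one batch under <save_wav_dir>/<subset>/.
    out_dir = save_wav_dir / subset
    out_dir.mkdir(parents=True, exist_ok=True)
    for i, cut_id in enumerate(cut_ids):
        torchaudio.save(
            str(out_dir / f"{cut_id}_gt.wav"),
            audio[i : i + 1, : int(audio_lens[i])],
            sample_rate=sample_rate,
        )


def run(infer_sets, save_wav_dir=Path("exp/wav"), sample_rate=22050):
    # infer_sets maps a subset name ("test", "valid") to an iterable of
    # (cut_ids, audio, audio_lens) batches, mirroring the loop added in main().
    with ThreadPoolExecutor(max_workers=1) as executor:
        futures = []
        for subset, batches in infer_sets.items():
            for cut_ids, audio, audio_lens in batches:
                futures.append(
                    executor.submit(
                        save_batch, save_wav_dir, subset, cut_ids, audio, audio_lens, sample_rate
                    )
                )
        # Wait for the background saves to finish before returning.
        for f in futures:
            f.result()


if __name__ == "__main__":
    # Tiny fake batch so the sketch runs on its own.
    audio = torch.zeros(1, 22050)
    run({"test": [(["utt1"], audio, torch.tensor([22050]))]})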