mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 14:44:18 +00:00
Update infer.py
This commit is contained in:
parent
52fe6f9bfd
commit
ddecda0c9b
@ -78,6 +78,7 @@ def get_parser():
|
|||||||
|
|
||||||
def infer_dataset(
|
def infer_dataset(
|
||||||
dl: torch.utils.data.DataLoader,
|
dl: torch.utils.data.DataLoader,
|
||||||
|
subset: str,
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
tokenizer: Tokenizer,
|
tokenizer: Tokenizer,
|
||||||
@ -99,6 +100,7 @@ def infer_dataset(
|
|||||||
|
|
||||||
# Background worker save audios to disk.
|
# Background worker save audios to disk.
|
||||||
def _save_worker(
|
def _save_worker(
|
||||||
|
subset: str,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
cut_ids: List[str],
|
cut_ids: List[str],
|
||||||
audio: torch.Tensor,
|
audio: torch.Tensor,
|
||||||
@ -108,12 +110,12 @@ def infer_dataset(
|
|||||||
):
|
):
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
torchaudio.save(
|
torchaudio.save(
|
||||||
str(params.save_wav_dir / f"{cut_ids[i]}_gt.wav"),
|
str(params.save_wav_dir / subset / f"{cut_ids[i]}_gt.wav"),
|
||||||
audio[i : i + 1, : audio_lens[i]],
|
audio[i : i + 1, : audio_lens[i]],
|
||||||
sample_rate=params.sampling_rate,
|
sample_rate=params.sampling_rate,
|
||||||
)
|
)
|
||||||
torchaudio.save(
|
torchaudio.save(
|
||||||
str(params.save_wav_dir / f"{cut_ids[i]}_pred.wav"),
|
str(params.save_wav_dir / subset / f"{cut_ids[i]}_pred.wav"),
|
||||||
audio_pred[i : i + 1, : audio_lens_pred[i]],
|
audio_pred[i : i + 1, : audio_lens_pred[i]],
|
||||||
sample_rate=params.sampling_rate,
|
sample_rate=params.sampling_rate,
|
||||||
)
|
)
|
||||||
@ -165,6 +167,7 @@ def infer_dataset(
|
|||||||
futures.append(
|
futures.append(
|
||||||
executor.submit(
|
executor.submit(
|
||||||
_save_worker,
|
_save_worker,
|
||||||
|
subset,
|
||||||
batch_size,
|
batch_size,
|
||||||
cut_ids,
|
cut_ids,
|
||||||
audio,
|
audio,
|
||||||
@ -241,13 +244,25 @@ def main():
|
|||||||
test_cuts = vctk.test_cuts()
|
test_cuts = vctk.test_cuts()
|
||||||
test_dl = vctk.test_dataloaders(test_cuts)
|
test_dl = vctk.test_dataloaders(test_cuts)
|
||||||
|
|
||||||
infer_dataset(
|
valid_cuts = vctk.valid_cuts()
|
||||||
dl=test_dl,
|
valid_dl = vctk.valid_dataloaders(valid_cuts)
|
||||||
params=params,
|
|
||||||
model=model,
|
infer_sets = {"test", test_dl, "valid", valid_dl}
|
||||||
tokenizer=tokenizer,
|
|
||||||
speaker_map=speaker_map,
|
for subset, dl in infer_sets.items():
|
||||||
)
|
save_wav_dir = params.res_dir / "wav" / subset
|
||||||
|
save_wav_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
logging.info(f"Processing {subset} set, saving to {save_wav_dir}")
|
||||||
|
|
||||||
|
infer_dataset(
|
||||||
|
dl=dl,
|
||||||
|
subset=subset,
|
||||||
|
params=params,
|
||||||
|
model=model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
speaker_map=speaker_map,
|
||||||
|
)
|
||||||
|
|
||||||
logging.info(f"Wav files are saved to {params.save_wav_dir}")
|
logging.info(f"Wav files are saved to {params.save_wav_dir}")
|
||||||
logging.info("Done!")
|
logging.info("Done!")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user