mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-03 22:24:19 +00:00
Update infer.py
This commit is contained in:
parent
52fe6f9bfd
commit
ddecda0c9b
@ -78,6 +78,7 @@ def get_parser():
|
||||
|
||||
def infer_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
subset: str,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
tokenizer: Tokenizer,
|
||||
@ -99,6 +100,7 @@ def infer_dataset(
|
||||
|
||||
# Background worker save audios to disk.
|
||||
def _save_worker(
|
||||
subset: str,
|
||||
batch_size: int,
|
||||
cut_ids: List[str],
|
||||
audio: torch.Tensor,
|
||||
@ -108,12 +110,12 @@ def infer_dataset(
|
||||
):
|
||||
for i in range(batch_size):
|
||||
torchaudio.save(
|
||||
str(params.save_wav_dir / f"{cut_ids[i]}_gt.wav"),
|
||||
str(params.save_wav_dir / subset / f"{cut_ids[i]}_gt.wav"),
|
||||
audio[i : i + 1, : audio_lens[i]],
|
||||
sample_rate=params.sampling_rate,
|
||||
)
|
||||
torchaudio.save(
|
||||
str(params.save_wav_dir / f"{cut_ids[i]}_pred.wav"),
|
||||
str(params.save_wav_dir / subset / f"{cut_ids[i]}_pred.wav"),
|
||||
audio_pred[i : i + 1, : audio_lens_pred[i]],
|
||||
sample_rate=params.sampling_rate,
|
||||
)
|
||||
@ -165,6 +167,7 @@ def infer_dataset(
|
||||
futures.append(
|
||||
executor.submit(
|
||||
_save_worker,
|
||||
subset,
|
||||
batch_size,
|
||||
cut_ids,
|
||||
audio,
|
||||
@ -241,13 +244,25 @@ def main():
|
||||
test_cuts = vctk.test_cuts()
|
||||
test_dl = vctk.test_dataloaders(test_cuts)
|
||||
|
||||
infer_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
speaker_map=speaker_map,
|
||||
)
|
||||
valid_cuts = vctk.valid_cuts()
|
||||
valid_dl = vctk.valid_dataloaders(valid_cuts)
|
||||
|
||||
infer_sets = {"test", test_dl, "valid", valid_dl}
|
||||
|
||||
for subset, dl in infer_sets.items():
|
||||
save_wav_dir = params.res_dir / "wav" / subset
|
||||
save_wav_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logging.info(f"Processing {subset} set, saving to {save_wav_dir}")
|
||||
|
||||
infer_dataset(
|
||||
dl=dl,
|
||||
subset=subset,
|
||||
params=params,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
speaker_map=speaker_map,
|
||||
)
|
||||
|
||||
logging.info(f"Wav files are saved to {params.save_wav_dir}")
|
||||
logging.info("Done!")
|
||||
|
Loading…
x
Reference in New Issue
Block a user