diff --git a/egs/ljspeech/TTS/vits/infer.py b/egs/ljspeech/TTS/vits/infer.py index a7c4a4c09..91a35e360 100755 --- a/egs/ljspeech/TTS/vits/infer.py +++ b/egs/ljspeech/TTS/vits/infer.py @@ -225,6 +225,7 @@ def main(): tokenizer=tokenizer, ) + logging.info(f"Wav files are saved to {params.save_wav_dir}") logging.info("Done!") diff --git a/egs/ljspeech/TTS/vits/vits.py b/egs/ljspeech/TTS/vits/vits.py index aa26a012d..d5e20a578 100644 --- a/egs/ljspeech/TTS/vits/vits.py +++ b/egs/ljspeech/TTS/vits/vits.py @@ -573,6 +573,7 @@ class VITS(nn.Module): self, text: torch.Tensor, text_lengths: torch.Tensor, + sids: Optional[torch.Tensor] = None, durations: Optional[torch.Tensor] = None, noise_scale: float = 0.667, noise_scale_dur: float = 0.8, @@ -585,6 +586,7 @@ class VITS(nn.Module): Args: text (Tensor): Input text index tensor (B, T_text). text_lengths (Tensor): Input text index tensor (B,). + sids (Tensor): Speaker index tensor (B,). noise_scale (float): Noise scale value for flow. noise_scale_dur (float): Noise scale value for duration predictor. alpha (float): Alpha parameter to control the speed of generated speech. @@ -599,6 +601,7 @@ class VITS(nn.Module): wav, att_w, dur = self.generator.inference( text=text, text_lengths=text_lengths, + sids=sids, noise_scale=noise_scale, noise_scale_dur=noise_scale_dur, alpha=alpha,