Update vits.py

This commit is contained in:
zr_jin 2024-10-22 12:34:07 +08:00
parent 3ac1331b27
commit 32cdbdfebb

View File

@ -623,6 +623,7 @@ class VITS(nn.Module):
text_lengths: torch.Tensor, text_lengths: torch.Tensor,
sids: Optional[torch.Tensor] = None, sids: Optional[torch.Tensor] = None,
spembs: Optional[torch.Tensor] = None, spembs: Optional[torch.Tensor] = None,
lids: Optional[torch.Tensor] = None,
durations: Optional[torch.Tensor] = None, durations: Optional[torch.Tensor] = None,
noise_scale: float = 0.667, noise_scale: float = 0.667,
noise_scale_dur: float = 0.8, noise_scale_dur: float = 0.8,
@ -637,6 +638,7 @@ class VITS(nn.Module):
text_lengths (Tensor): Input text index tensor (B,). text_lengths (Tensor): Input text index tensor (B,).
sids (Tensor): Speaker index tensor (B,). sids (Tensor): Speaker index tensor (B,).
spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
lids (Tensor): Language index tensor (B,).
noise_scale (float): Noise scale value for flow. noise_scale (float): Noise scale value for flow.
noise_scale_dur (float): Noise scale value for duration predictor. noise_scale_dur (float): Noise scale value for duration predictor.
alpha (float): Alpha parameter to control the speed of generated speech. alpha (float): Alpha parameter to control the speed of generated speech.
@ -653,6 +655,7 @@ class VITS(nn.Module):
text_lengths=text_lengths, text_lengths=text_lengths,
sids=sids, sids=sids,
spembs=spembs, spembs=spembs,
lids=lids,
noise_scale=noise_scale, noise_scale=noise_scale,
noise_scale_dur=noise_scale_dur, noise_scale_dur=noise_scale_dur,
alpha=alpha, alpha=alpha,