mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-27 02:34:21 +00:00
Add speaker embedding in text encoder
This commit is contained in:
parent
e5c04a216c
commit
a4e4f8080a
@ -77,6 +77,7 @@ class VITSGenerator(torch.nn.Module):
|
||||
noise_initial_mas: float = 0.01,
|
||||
noise_scale_mas: float = 2e-6,
|
||||
use_transformer_in_flows: bool = True,
|
||||
use_spk_conditioned_txt_encoder: bool = False,
|
||||
):
|
||||
"""Initialize VITS generator module.
|
||||
|
||||
@ -138,6 +139,13 @@ class VITSGenerator(torch.nn.Module):
|
||||
"""
|
||||
super().__init__()
|
||||
self.segment_size = segment_size
|
||||
self.use_spk_conditioned_txt_encoder = use_spk_conditioned_txt_encoder
|
||||
|
||||
if self.use_spk_conditioned_txt_encoder and global_channels > 0:
|
||||
self.text_encoder_global_channels = global_channels
|
||||
else:
|
||||
self.text_encoder_global_channels = 0
|
||||
|
||||
self.text_encoder = TextEncoder(
|
||||
vocabs=vocabs,
|
||||
d_model=hidden_channels,
|
||||
@ -146,6 +154,7 @@ class VITSGenerator(torch.nn.Module):
|
||||
cnn_module_kernel=text_encoder_cnn_module_kernel,
|
||||
num_layers=text_encoder_blocks,
|
||||
dropout=text_encoder_dropout_rate,
|
||||
global_channels=self.text_encoder_global_channels,
|
||||
)
|
||||
self.decoder = HiFiGANGenerator(
|
||||
in_channels=hidden_channels,
|
||||
@ -286,9 +295,6 @@ class VITSGenerator(torch.nn.Module):
|
||||
- Tensor: Posterior encoder projected scale (B, H, T_feats).
|
||||
|
||||
"""
|
||||
# forward text encoder
|
||||
x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths)
|
||||
|
||||
# calculate global conditioning
|
||||
g = None
|
||||
if self.spks is not None:
|
||||
@ -309,6 +315,9 @@ class VITSGenerator(torch.nn.Module):
|
||||
else:
|
||||
g = g + g_
|
||||
|
||||
# forward text encoder
|
||||
x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths, g)
|
||||
|
||||
# forward posterior encoder
|
||||
z, m_q, logs_q, y_mask = self.posterior_encoder(feats, feats_lengths, g=g)
|
||||
|
||||
@ -376,7 +385,7 @@ class VITSGenerator(torch.nn.Module):
|
||||
logw_ = torch.log(w + 1e-6) * x_mask
|
||||
else:
|
||||
logw_ = torch.log(w + 1e-6) * x_mask
|
||||
logw = self.dp(x, x_mask, g=g)
|
||||
logw = self.duration_predictor(x, x_mask, g=g)
|
||||
dur_nll = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(x_mask)
|
||||
|
||||
# expand the length to match with the feature sequence
|
||||
|
@ -193,7 +193,7 @@ class FeatureMatchLoss(torch.nn.Module):
|
||||
from generator's outputs.
|
||||
feats (Union[List[List[Tensor]], List[Tensor]]): List of list of
|
||||
discriminator outputs or list of discriminator outputs calcuated
|
||||
from groundtruth..
|
||||
from groundtruth.
|
||||
|
||||
Returns:
|
||||
Tensor: Feature matching loss value.
|
||||
@ -333,30 +333,3 @@ class KLDivergenceLossWithoutFlow(torch.nn.Module):
|
||||
prior_norm = D.Normal(m_p, torch.exp(logs_p))
|
||||
loss = D.kl_divergence(posterior_norm, prior_norm).mean()
|
||||
return loss
|
||||
|
||||
|
||||
class DurationDiscLoss(torch.nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
disc_real_outputs: List[torch.Tensor],
|
||||
disc_generated_outputs: List[torch.Tensor],
|
||||
):
|
||||
loss = 0
|
||||
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
||||
dr = dr.float()
|
||||
dg = dg.float()
|
||||
r_loss = torch.mean((1 - dr) ** 2)
|
||||
g_loss = torch.mean(dg**2)
|
||||
loss += r_loss + g_loss
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
class DurationGenLoss(torch.nn.Module):
|
||||
def forward(self, disc_outputs: List[torch.Tensor]):
|
||||
loss = 0
|
||||
for dg in disc_outputs:
|
||||
dg = dg.float()
|
||||
loss += torch.mean((1 - dg) ** 2)
|
||||
|
||||
return loss
|
||||
|
@ -49,6 +49,8 @@ class TextEncoder(torch.nn.Module):
|
||||
cnn_module_kernel: int = 5,
|
||||
num_layers: int = 6,
|
||||
dropout: float = 0.1,
|
||||
global_channels: int = 0,
|
||||
cond_layer_idx: int = 2,
|
||||
):
|
||||
"""Initialize TextEncoder module.
|
||||
|
||||
@ -63,6 +65,8 @@ class TextEncoder(torch.nn.Module):
|
||||
"""
|
||||
super().__init__()
|
||||
self.d_model = d_model
|
||||
self.global_channels = global_channels
|
||||
self.cond_layer_idx = cond_layer_idx
|
||||
|
||||
# define modules
|
||||
self.emb = torch.nn.Embedding(vocabs, d_model)
|
||||
@ -76,14 +80,14 @@ class TextEncoder(torch.nn.Module):
|
||||
cnn_module_kernel=cnn_module_kernel,
|
||||
num_layers=num_layers,
|
||||
dropout=dropout,
|
||||
global_channels=global_channels,
|
||||
cond_layer_idx=cond_layer_idx,
|
||||
)
|
||||
|
||||
self.proj = torch.nn.Conv1d(d_model, d_model * 2, 1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lengths: torch.Tensor,
|
||||
self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Calculate forward propagation.
|
||||
|
||||
@ -107,7 +111,7 @@ class TextEncoder(torch.nn.Module):
|
||||
pad_mask = make_pad_mask(x_lengths)
|
||||
|
||||
# encoder assume the channel last (B, T_text, embed_dim)
|
||||
x = self.encoder(x, key_padding_mask=pad_mask)
|
||||
x = self.encoder(x, key_padding_mask=pad_mask, g=g)
|
||||
|
||||
# convert the channel first (B, embed_dim, T_text)
|
||||
x = x.transpose(1, 2)
|
||||
@ -137,11 +141,18 @@ class Transformer(nn.Module):
|
||||
cnn_module_kernel: int = 5,
|
||||
num_layers: int = 6,
|
||||
dropout: float = 0.1,
|
||||
global_channels: int = 0,
|
||||
cond_layer_idx: int = 2,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.num_layers = num_layers
|
||||
self.d_model = d_model
|
||||
self.global_channels = global_channels
|
||||
|
||||
speaker_embedder = None
|
||||
if self.global_channels != 0:
|
||||
speaker_embedder = nn.Linear(self.global_channels, self.d_model)
|
||||
|
||||
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
|
||||
|
||||
@ -151,12 +162,13 @@ class Transformer(nn.Module):
|
||||
dim_feedforward=dim_feedforward,
|
||||
cnn_module_kernel=cnn_module_kernel,
|
||||
dropout=dropout,
|
||||
speaker_embedder=speaker_embedder,
|
||||
)
|
||||
self.encoder = TransformerEncoder(encoder_layer, num_layers)
|
||||
self.encoder = TransformerEncoder(encoder_layer, num_layers, cond_layer_idx)
|
||||
self.after_norm = nn.LayerNorm(d_model)
|
||||
|
||||
def forward(
|
||||
self, x: Tensor, key_padding_mask: Tensor
|
||||
self, x: Tensor, key_padding_mask: Tensor, g: Optional[Tensor] = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
@ -169,7 +181,9 @@ class Transformer(nn.Module):
|
||||
x, pos_emb = self.encoder_pos(x)
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
|
||||
x = self.encoder(x, pos_emb, key_padding_mask=key_padding_mask) # (T, N, C)
|
||||
x = self.encoder(
|
||||
x, pos_emb, key_padding_mask=key_padding_mask, g=g
|
||||
) # (T, N, C)
|
||||
|
||||
x = self.after_norm(x)
|
||||
|
||||
@ -195,6 +209,7 @@ class TransformerEncoderLayer(nn.Module):
|
||||
dim_feedforward: int,
|
||||
cnn_module_kernel: int,
|
||||
dropout: float = 0.1,
|
||||
speaker_embedder: Optional[nn.Module] = None,
|
||||
) -> None:
|
||||
super(TransformerEncoderLayer, self).__init__()
|
||||
|
||||
@ -227,11 +242,15 @@ class TransformerEncoderLayer(nn.Module):
|
||||
self.ff_scale = 0.5
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.speaker_embedder = speaker_embedder
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src: Tensor,
|
||||
pos_emb: Tensor,
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
g: Optional[Tensor] = None,
|
||||
apply_speaker_embedding: bool = False,
|
||||
) -> Tensor:
|
||||
"""
|
||||
Pass the input through the transformer encoder layer.
|
||||
@ -241,6 +260,11 @@ class TransformerEncoderLayer(nn.Module):
|
||||
pos_emb: Positional embedding tensor, of shape (1, seq_len*2-1, pos_dim).
|
||||
key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len)
|
||||
"""
|
||||
if g is not None and apply_speaker_embedding:
|
||||
g = self.speaker_embedder(g.transpose(1, 2))
|
||||
g = g.transpose(1, 2)
|
||||
src = src + g # * src_mask
|
||||
|
||||
# macaron style feed-forward module
|
||||
src = src + self.ff_scale * self.dropout(
|
||||
self.feed_forward_macaron(self.norm_ff_macaron(src))
|
||||
@ -273,19 +297,23 @@ class TransformerEncoder(nn.Module):
|
||||
num_layers: the number of sub-encoder-layers in the encoder.
|
||||
"""
|
||||
|
||||
def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
|
||||
def __init__(
|
||||
self, encoder_layer: nn.Module, num_layers: int, cond_layer_idx: int = 0
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
|
||||
)
|
||||
self.num_layers = num_layers
|
||||
self.cond_layer_idx = cond_layer_idx
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src: Tensor,
|
||||
pos_emb: Tensor,
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
g: Optional[Tensor] = None,
|
||||
) -> Tensor:
|
||||
r"""Pass the input through the encoder layers in turn.
|
||||
|
||||
@ -295,12 +323,17 @@ class TransformerEncoder(nn.Module):
|
||||
key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len)
|
||||
"""
|
||||
output = src
|
||||
|
||||
for layer_index, mod in enumerate(self.layers):
|
||||
apply_speaker_embedding = False
|
||||
if layer_index == self.cond_layer_idx:
|
||||
apply_speaker_embedding = True
|
||||
|
||||
output = mod(
|
||||
output,
|
||||
pos_emb,
|
||||
key_padding_mask=key_padding_mask,
|
||||
g=g,
|
||||
apply_speaker_embedding=apply_speaker_embedding,
|
||||
)
|
||||
|
||||
return output
|
||||
|
@ -375,8 +375,10 @@ def train_one_epoch(
|
||||
params=params,
|
||||
optimizer_g=optimizer_g,
|
||||
optimizer_d=optimizer_d,
|
||||
optimizer_dur=optimizer_dur,
|
||||
scheduler_g=scheduler_g,
|
||||
scheduler_d=scheduler_d,
|
||||
scheduler_dur=scheduler_dur,
|
||||
sampler=train_dl.sampler,
|
||||
scaler=scaler,
|
||||
rank=0,
|
||||
@ -393,15 +395,42 @@ def train_one_epoch(
|
||||
loss_info = MetricsTracker()
|
||||
loss_info["samples"] = batch_size
|
||||
|
||||
if model.module.generator.use_noised_mas:
|
||||
# MAS with Gaussian Noise
|
||||
model.module.generator.noise_current_mas = max(
|
||||
model.module.generator.noise_initial_mas
|
||||
- model.module.generator.noise_scale_mas * params.batch_idx_train,
|
||||
0.0,
|
||||
)
|
||||
if isinstance(model, DDP):
|
||||
if model.module.generator.use_noised_mas:
|
||||
# MAS with Gaussian Noise
|
||||
model.module.generator.noise_current_mas = max(
|
||||
model.module.generator.noise_initial_mas
|
||||
- model.module.generator.noise_scale_mas * params.batch_idx_train,
|
||||
0.0,
|
||||
)
|
||||
else:
|
||||
if model.generator.use_noised_mas:
|
||||
# MAS with Gaussian Noise
|
||||
model.generator.noise_current_mas = max(
|
||||
model.generator.noise_initial_mas
|
||||
- model.generator.noise_scale_mas * params.batch_idx_train,
|
||||
0.0,
|
||||
)
|
||||
|
||||
try:
|
||||
with autocast(enabled=params.use_fp16):
|
||||
# forward duration discriminator
|
||||
dur_loss, stats_dur = model(
|
||||
text=tokens,
|
||||
text_lengths=tokens_lens,
|
||||
feats=features,
|
||||
feats_lengths=features_lens,
|
||||
speech=audio,
|
||||
speech_lengths=audio_lens,
|
||||
forward_type="duration_discriminator",
|
||||
)
|
||||
for k, v in stats_dur.items():
|
||||
loss_info[k] = v * batch_size
|
||||
|
||||
optimizer_dur.zero_grad()
|
||||
scaler.scale(dur_loss).backward()
|
||||
scaler.step(optimizer_dur)
|
||||
|
||||
with autocast(enabled=params.use_fp16):
|
||||
# forward discriminator
|
||||
loss_d, dur_loss, stats_d = model(
|
||||
@ -411,13 +440,9 @@ def train_one_epoch(
|
||||
feats_lengths=features_lens,
|
||||
speech=audio,
|
||||
speech_lengths=audio_lens,
|
||||
forward_generator=False,
|
||||
forward_type="discriminator",
|
||||
)
|
||||
|
||||
optimizer_dur.zero_grad()
|
||||
scaler.scale(dur_loss).backward()
|
||||
scaler.step(optimizer_dur)
|
||||
|
||||
for k, v in stats_d.items():
|
||||
loss_info[k] = v * batch_size
|
||||
# update discriminator
|
||||
@ -434,7 +459,7 @@ def train_one_epoch(
|
||||
feats_lengths=features_lens,
|
||||
speech=audio,
|
||||
speech_lengths=audio_lens,
|
||||
forward_generator=True,
|
||||
forward_type="generator",
|
||||
return_sample=params.batch_idx_train % params.log_interval == 0,
|
||||
)
|
||||
for k, v in stats_g.items():
|
||||
@ -603,15 +628,29 @@ def compute_validation_loss(
|
||||
loss_info = MetricsTracker()
|
||||
loss_info["samples"] = batch_size
|
||||
|
||||
# forward discriminator
|
||||
loss_d, dur_loss, stats_d = model(
|
||||
# forward duration discriminator
|
||||
loss_dur, stats_dur = model(
|
||||
text=tokens,
|
||||
text_lengths=tokens_lens,
|
||||
feats=features,
|
||||
feats_lengths=features_lens,
|
||||
speech=audio,
|
||||
speech_lengths=audio_lens,
|
||||
forward_generator=False,
|
||||
forward_type="duration_discriminator",
|
||||
)
|
||||
assert loss_dur.requires_grad is False
|
||||
for k, v in stats_dur.items():
|
||||
loss_info[k] = v * batch_size
|
||||
|
||||
# forward discriminator
|
||||
loss_d, stats_d = model(
|
||||
text=tokens,
|
||||
text_lengths=tokens_lens,
|
||||
feats=features,
|
||||
feats_lengths=features_lens,
|
||||
speech=audio,
|
||||
speech_lengths=audio_lens,
|
||||
forward_type="discriminator",
|
||||
)
|
||||
assert loss_d.requires_grad is False
|
||||
for k, v in stats_d.items():
|
||||
@ -625,7 +664,7 @@ def compute_validation_loss(
|
||||
feats_lengths=features_lens,
|
||||
speech=audio,
|
||||
speech_lengths=audio_lens,
|
||||
forward_generator=True,
|
||||
forward_type="generator",
|
||||
)
|
||||
assert loss_g.requires_grad is False
|
||||
for k, v in stats_g.items():
|
||||
@ -684,21 +723,33 @@ def scan_pessimistic_batches_for_oom(
|
||||
batch, tokenizer, device
|
||||
)
|
||||
try:
|
||||
# for discriminator
|
||||
# for duration discriminator
|
||||
with autocast(enabled=params.use_fp16):
|
||||
loss_d, dur_loss, stats_d = model(
|
||||
dur_loss, stats_dur = model(
|
||||
text=tokens,
|
||||
text_lengths=tokens_lens,
|
||||
feats=features,
|
||||
feats_lengths=features_lens,
|
||||
speech=audio,
|
||||
speech_lengths=audio_lens,
|
||||
forward_generator=False,
|
||||
forward_type="duration_discriminator",
|
||||
)
|
||||
|
||||
optimizer_dur.zero_grad()
|
||||
dur_loss.backward()
|
||||
|
||||
# for discriminator
|
||||
with autocast(enabled=params.use_fp16):
|
||||
loss_d, stats_d = model(
|
||||
text=tokens,
|
||||
text_lengths=tokens_lens,
|
||||
feats=features,
|
||||
feats_lengths=features_lens,
|
||||
speech=audio,
|
||||
speech_lengths=audio_lens,
|
||||
forward_type="discriminator",
|
||||
)
|
||||
|
||||
optimizer_d.zero_grad()
|
||||
loss_d.backward()
|
||||
# for generator
|
||||
@ -710,7 +761,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
feats_lengths=features_lens,
|
||||
speech=audio,
|
||||
speech_lengths=audio_lens,
|
||||
forward_generator=True,
|
||||
forward_type="generator",
|
||||
)
|
||||
optimizer_g.zero_grad()
|
||||
loss_g.backward()
|
||||
@ -918,8 +969,10 @@ def run(rank, world_size, args):
|
||||
model=model,
|
||||
optimizer_g=optimizer_g,
|
||||
optimizer_d=optimizer_d,
|
||||
optimizer_dur=optimizer_dur,
|
||||
scheduler_g=scheduler_g,
|
||||
scheduler_d=scheduler_d,
|
||||
scheduler_dur=scheduler_dur,
|
||||
sampler=train_dl.sampler,
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
|
@ -203,8 +203,10 @@ def save_checkpoint(
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
optimizer_g: Optional[Optimizer] = None,
|
||||
optimizer_d: Optional[Optimizer] = None,
|
||||
optimizer_dur: Optional[Optimizer] = None,
|
||||
scheduler_g: Optional[LRSchedulerType] = None,
|
||||
scheduler_d: Optional[LRSchedulerType] = None,
|
||||
scheduler_dur: Optional[LRSchedulerType] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
rank: int = 0,
|
||||
@ -251,7 +253,13 @@ def save_checkpoint(
|
||||
"model": model.state_dict(),
|
||||
"optimizer_g": optimizer_g.state_dict() if optimizer_g is not None else None,
|
||||
"optimizer_d": optimizer_d.state_dict() if optimizer_d is not None else None,
|
||||
"optimizer_dur": optimizer_dur.state_dict()
|
||||
if optimizer_d is not None
|
||||
else None,
|
||||
"scheduler_g": scheduler_g.state_dict() if scheduler_g is not None else None,
|
||||
"scheduler_dur": scheduler_dur.state_dict()
|
||||
if scheduler_d is not None
|
||||
else None,
|
||||
"scheduler_d": scheduler_d.state_dict() if scheduler_d is not None else None,
|
||||
"grad_scaler": scaler.state_dict() if scaler is not None else None,
|
||||
"sampler": sampler.state_dict() if sampler is not None else None,
|
||||
|
@ -20,8 +20,6 @@ from hifigan import (
|
||||
)
|
||||
from loss import (
|
||||
DiscriminatorAdversarialLoss,
|
||||
DurationDiscLoss,
|
||||
DurationGenLoss,
|
||||
FeatureMatchLoss,
|
||||
GeneratorAdversarialLoss,
|
||||
KLDivergenceLoss,
|
||||
@ -236,10 +234,6 @@ class VITS(nn.Module):
|
||||
)
|
||||
self.kl_loss = KLDivergenceLoss()
|
||||
|
||||
# Vits2 duration disc
|
||||
self.dur_disc_loss = DurationDiscLoss()
|
||||
self.dur_gen_loss = DurationGenLoss()
|
||||
|
||||
# coefficients
|
||||
self.lambda_adv = lambda_adv
|
||||
self.lambda_mel = lambda_mel
|
||||
@ -273,7 +267,7 @@ class VITS(nn.Module):
|
||||
sids: Optional[torch.Tensor] = None,
|
||||
spembs: Optional[torch.Tensor] = None,
|
||||
lids: Optional[torch.Tensor] = None,
|
||||
forward_generator: bool = True,
|
||||
forward_type: str = "generator",
|
||||
) -> Tuple[torch.Tensor, Dict[str, Any]]:
|
||||
"""Perform generator forward.
|
||||
|
||||
@ -287,13 +281,13 @@ class VITS(nn.Module):
|
||||
sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
|
||||
spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
|
||||
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
|
||||
forward_generator (bool): Whether to forward generator.
|
||||
forward_type (str): Which type of forward to do
|
||||
|
||||
Returns:
|
||||
- loss (Tensor): Loss scalar tensor.
|
||||
- stats (Dict[str, float]): Statistics to be monitored.
|
||||
"""
|
||||
if forward_generator:
|
||||
if forward_type == "generator":
|
||||
return self._forward_generator(
|
||||
text=text,
|
||||
text_lengths=text_lengths,
|
||||
@ -306,7 +300,7 @@ class VITS(nn.Module):
|
||||
spembs=spembs,
|
||||
lids=lids,
|
||||
)
|
||||
else:
|
||||
elif forward_type == "discriminator":
|
||||
return self._forward_discrminator(
|
||||
text=text,
|
||||
text_lengths=text_lengths,
|
||||
@ -318,6 +312,20 @@ class VITS(nn.Module):
|
||||
spembs=spembs,
|
||||
lids=lids,
|
||||
)
|
||||
elif forward_type == "duration_discriminator":
|
||||
return self._forward_discrminator_duration(
|
||||
text=text,
|
||||
text_lengths=text_lengths,
|
||||
feats=feats,
|
||||
feats_lengths=feats_lengths,
|
||||
speech=speech,
|
||||
speech_lengths=speech_lengths,
|
||||
sids=sids,
|
||||
spembs=spembs,
|
||||
lids=lids,
|
||||
)
|
||||
else:
|
||||
raise Exception(f"Forward type {forward_type} does not exist")
|
||||
|
||||
def _forward_generator(
|
||||
self,
|
||||
@ -374,8 +382,6 @@ class VITS(nn.Module):
|
||||
self._cache = outs
|
||||
|
||||
# parse outputs
|
||||
# speech_hat_, dur_nll, _, start_idxs, _, z_mask, outs_ = outs
|
||||
# _, z_p, m_p, logs_p, _, logs_q = outs_
|
||||
(
|
||||
speech_hat_,
|
||||
dur_nll,
|
||||
@ -412,7 +418,7 @@ class VITS(nn.Module):
|
||||
feat_match_loss = self.feat_match_loss(p_hat, p)
|
||||
|
||||
y_dur_hat_r, y_dur_hat_g = self.dur_disc(hidden_x, x_mask, logw_, logw)
|
||||
dur_gen_loss = self.dur_gen_loss(y_dur_hat_g)
|
||||
dur_gen_loss = self.generator_adv_loss(y_dur_hat_g)
|
||||
|
||||
mel_loss = mel_loss * self.lambda_mel
|
||||
kl_loss = kl_loss * self.lambda_kl
|
||||
@ -514,8 +520,8 @@ class VITS(nn.Module):
|
||||
start_idxs,
|
||||
x_mask,
|
||||
y_mask,
|
||||
(z, z_p, m_p, logs_p, m_q, logs_q),
|
||||
(hidden_x, logw, logw_),
|
||||
_,
|
||||
_,
|
||||
) = outs
|
||||
|
||||
speech_ = get_segments(
|
||||
@ -533,14 +539,6 @@ class VITS(nn.Module):
|
||||
real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p)
|
||||
loss = real_loss + fake_loss
|
||||
|
||||
# Duration Discriminator
|
||||
y_dur_hat_r, y_dur_hat_g = self.dur_disc(
|
||||
hidden_x.detach(), x_mask.detach(), logw_.detach(), logw.detach()
|
||||
)
|
||||
|
||||
with autocast(enabled=False):
|
||||
dur_loss = self.dur_disc_loss(y_dur_hat_r, y_dur_hat_g)
|
||||
|
||||
stats = dict(
|
||||
discriminator_loss=loss.item(),
|
||||
discriminator_real_loss=real_loss.item(),
|
||||
@ -551,7 +549,92 @@ class VITS(nn.Module):
|
||||
if reuse_cache or not self.training:
|
||||
self._cache = None
|
||||
|
||||
return loss, dur_loss, stats
|
||||
return loss, stats
|
||||
|
||||
def _forward_discrminator_duration(
|
||||
self,
|
||||
text: torch.Tensor,
|
||||
text_lengths: torch.Tensor,
|
||||
feats: torch.Tensor,
|
||||
feats_lengths: torch.Tensor,
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
sids: Optional[torch.Tensor] = None,
|
||||
spembs: Optional[torch.Tensor] = None,
|
||||
lids: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Dict[str, Any]]:
|
||||
"""Perform discriminator forward.
|
||||
|
||||
Args:
|
||||
text (Tensor): Text index tensor (B, T_text).
|
||||
text_lengths (Tensor): Text length tensor (B,).
|
||||
feats (Tensor): Feature tensor (B, T_feats, aux_channels).
|
||||
feats_lengths (Tensor): Feature length tensor (B,).
|
||||
speech (Tensor): Speech waveform tensor (B, T_wav).
|
||||
speech_lengths (Tensor): Speech length tensor (B,).
|
||||
sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
|
||||
spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
|
||||
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
|
||||
|
||||
Returns:
|
||||
* loss (Tensor): Loss scalar tensor.
|
||||
* stats (Dict[str, float]): Statistics to be monitored.
|
||||
"""
|
||||
# setup
|
||||
feats = feats.transpose(1, 2)
|
||||
speech = speech.unsqueeze(1)
|
||||
|
||||
# calculate generator outputs
|
||||
reuse_cache = True
|
||||
if not self.cache_generator_outputs or self._cache is None:
|
||||
reuse_cache = False
|
||||
outs = self.generator(
|
||||
text=text,
|
||||
text_lengths=text_lengths,
|
||||
feats=feats,
|
||||
feats_lengths=feats_lengths,
|
||||
sids=sids,
|
||||
spembs=spembs,
|
||||
lids=lids,
|
||||
)
|
||||
else:
|
||||
outs = self._cache
|
||||
|
||||
# store cache
|
||||
if self.cache_generator_outputs and not reuse_cache:
|
||||
self._cache = outs
|
||||
|
||||
(
|
||||
speech_hat_,
|
||||
dur_nll,
|
||||
attn,
|
||||
start_idxs,
|
||||
x_mask,
|
||||
y_mask,
|
||||
(z, z_p, m_p, logs_p, m_q, logs_q),
|
||||
(hidden_x, logw, logw_),
|
||||
) = outs
|
||||
|
||||
# Duration Discriminator
|
||||
y_dur_hat_r, y_dur_hat_g = self.dur_disc(
|
||||
hidden_x.detach(), x_mask.detach(), logw_.detach(), logw.detach()
|
||||
)
|
||||
|
||||
with autocast(enabled=False):
|
||||
real_dur_loss, fake_dur_loss = self.discriminator_adv_loss(
|
||||
y_dur_hat_g, y_dur_hat_r
|
||||
)
|
||||
dur_loss = real_dur_loss + fake_dur_loss
|
||||
|
||||
stats = dict(
|
||||
discriminator_dur_loss=dur_loss.item(),
|
||||
)
|
||||
|
||||
# reset cache
|
||||
if reuse_cache or not self.training:
|
||||
self._cache = None
|
||||
|
||||
return dur_loss, stats
|
||||
|
||||
def inference(
|
||||
self,
|
||||
|
Loading…
x
Reference in New Issue
Block a user