Add duration discrimination loss

This commit is contained in:
Erwan 2024-02-09 11:08:30 +01:00
parent b9fdebaff2
commit e5c04a216c
7 changed files with 372 additions and 25 deletions

View File

@ -20,6 +20,7 @@ from flow import (
ElementwiseAffineFlow, ElementwiseAffineFlow,
FlipFlow, FlipFlow,
LogFlow, LogFlow,
Transpose,
) )
@ -191,3 +192,68 @@ class StochasticDurationPredictor(torch.nn.Module):
z0, z1 = z.split(1, 1) z0, z1 = z.split(1, 1)
logw = z0 logw = z0
return logw return logw
class DurationPredictor(torch.nn.Module):
    """Deterministic duration predictor (VITS2-style).

    Maps detached text-encoder features to log-durations through two
    conv + layer-norm + dropout stages and a final 1x1 projection.
    Inputs are detached so duration gradients do not flow back into
    the text encoder.
    """

    def __init__(
        self,
        input_channels: int = 192,
        output_channels: int = 192,
        kernel_size: int = 3,
        dropout_rate: float = 0.5,
        global_channels: int = -1,
        eps: float = 1e-5,
    ):
        """Initialize DurationPredictor.

        Args:
            input_channels: Number of input feature channels.
            output_channels: Number of hidden/output conv channels.
            kernel_size: Kernel size of the two convolutions.
            dropout_rate: Dropout probability after each norm.
            global_channels: Global conditioning channels (<= 0 disables
                conditioning).
            eps: Epsilon passed to LayerNorm.
        """
        super().__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels
        self.kernel_size = kernel_size
        self.dropout_rate = dropout_rate
        # NOTE(review): stored under ``gin_channels`` (not ``global_channels``)
        # — kept as-is for attribute/checkpoint compatibility.
        self.gin_channels = global_channels

        def channel_layer_norm(channels: int) -> torch.nn.Sequential:
            # LayerNorm normalizes the last dim, so swap (B, C, T) -> (B, T, C)
            # around it and back.
            return torch.nn.Sequential(
                Transpose(1, 2),
                torch.nn.LayerNorm(channels, eps=eps, elementwise_affine=True),
                Transpose(1, 2),
            )

        pad = kernel_size // 2
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.conv_1 = torch.nn.Conv1d(
            input_channels, output_channels, kernel_size, padding=pad
        )
        self.norm_1 = channel_layer_norm(output_channels)
        self.conv_2 = torch.nn.Conv1d(
            output_channels, output_channels, kernel_size, padding=pad
        )
        self.norm_2 = channel_layer_norm(output_channels)
        self.proj = torch.nn.Conv1d(output_channels, 1, 1)
        if global_channels > 0:
            self.cond = torch.nn.Conv1d(global_channels, input_channels, 1)

    def forward(self, x, x_mask, g=None):
        """Predict log-durations.

        Args:
            x (Tensor): Input features (B, input_channels, T_text).
            x_mask (Tensor): Mask tensor (B, 1, T_text).
            g (Optional[Tensor]): Global conditioning tensor
                (B, global_channels, 1), or None.

        Returns:
            Tensor: Masked log-duration tensor (B, 1, T_text).
        """
        x = torch.detach(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond(g)
        for conv, norm in ((self.conv_1, self.norm_1), (self.conv_2, self.norm_2)):
            x = conv(x * x_mask)
            x = torch.relu(x)
            x = norm(x)
            x = self.dropout(x)
        return self.proj(x * x_mask) * x_mask

View File

@ -16,7 +16,7 @@ from typing import List, Optional, Tuple
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from duration_predictor import StochasticDurationPredictor from duration_predictor import DurationPredictor, StochasticDurationPredictor
from hifigan import HiFiGANGenerator from hifigan import HiFiGANGenerator
from posterior_encoder import PosteriorEncoder from posterior_encoder import PosteriorEncoder
from residual_coupling import ResidualAffineCouplingBlock from residual_coupling import ResidualAffineCouplingBlock
@ -71,6 +71,8 @@ class VITSGenerator(torch.nn.Module):
stochastic_duration_predictor_dropout_rate: float = 0.5, stochastic_duration_predictor_dropout_rate: float = 0.5,
stochastic_duration_predictor_flows: int = 4, stochastic_duration_predictor_flows: int = 4,
stochastic_duration_predictor_dds_conv_layers: int = 3, stochastic_duration_predictor_dds_conv_layers: int = 3,
duration_predictor_output_channels: int = 256,
use_stochastic_duration_predictor: bool = True,
use_noised_mas: bool = True, use_noised_mas: bool = True,
noise_initial_mas: float = 0.01, noise_initial_mas: float = 0.01,
noise_scale_mas: float = 2e-6, noise_scale_mas: float = 2e-6,
@ -184,14 +186,23 @@ class VITSGenerator(torch.nn.Module):
use_transformer_in_flows=use_transformer_in_flows, use_transformer_in_flows=use_transformer_in_flows,
) )
# TODO(kan-bayashi): Add deterministic version as an option # TODO(kan-bayashi): Add deterministic version as an option
self.duration_predictor = StochasticDurationPredictor( if use_stochastic_duration_predictor:
channels=hidden_channels, self.duration_predictor = StochasticDurationPredictor(
kernel_size=stochastic_duration_predictor_kernel_size, channels=hidden_channels,
dropout_rate=stochastic_duration_predictor_dropout_rate, kernel_size=stochastic_duration_predictor_kernel_size,
flows=stochastic_duration_predictor_flows, dropout_rate=stochastic_duration_predictor_dropout_rate,
dds_conv_layers=stochastic_duration_predictor_dds_conv_layers, flows=stochastic_duration_predictor_flows,
global_channels=global_channels, dds_conv_layers=stochastic_duration_predictor_dds_conv_layers,
) global_channels=global_channels,
)
else:
self.duration_predictor = DurationPredictor(
input_channels=hidden_channels,
output_channels=duration_predictor_output_channels,
kernel_size=stochastic_duration_predictor_kernel_size,
dropout_rate=stochastic_duration_predictor_dropout_rate,
global_channels=global_channels,
)
self.upsample_factor = int(np.prod(decoder_upsample_scales)) self.upsample_factor = int(np.prod(decoder_upsample_scales))
@ -200,6 +211,7 @@ class VITSGenerator(torch.nn.Module):
self.noise_current_mas = noise_initial_mas self.noise_current_mas = noise_initial_mas
self.noise_scale_mas = noise_scale_mas self.noise_scale_mas = noise_scale_mas
self.noise_initial_mas = noise_initial_mas self.noise_initial_mas = noise_initial_mas
self.use_stochastic_duration_predictor = use_stochastic_duration_predictor
self.spks = None self.spks = None
if spks is not None and spks > 1: if spks is not None and spks > 1:
@ -354,8 +366,18 @@ class VITSGenerator(torch.nn.Module):
# forward duration predictor # forward duration predictor
w = attn.sum(2) # (B, 1, T_text) w = attn.sum(2) # (B, 1, T_text)
dur_nll = self.duration_predictor(x, x_mask, w=w, g=g)
dur_nll = dur_nll / torch.sum(x_mask) if self.use_stochastic_duration_predictor:
dur_nll = self.duration_predictor(x, x_mask, w=w, g=g)
dur_nll = dur_nll / torch.sum(x_mask)
logw = self.duration_predictor(
x, x_mask, g=g, inverse=True, noise_scale=1.0
)
logw_ = torch.log(w + 1e-6) * x_mask
else:
logw_ = torch.log(w + 1e-6) * x_mask
logw = self.duration_predictor(x, x_mask, g=g)
dur_nll = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(x_mask)
# expand the length to match with the feature sequence # expand the length to match with the feature sequence
# (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats) # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
@ -381,6 +403,7 @@ class VITSGenerator(torch.nn.Module):
x_mask, x_mask,
y_mask, y_mask,
(z, z_p, m_p, logs_p, m_q, logs_q), (z, z_p, m_p, logs_p, m_q, logs_q),
(x, logw, logw_),
) )
def inference( def inference(

View File

@ -16,6 +16,7 @@ from typing import Any, Dict, List, Optional
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from flow import Transpose
class HiFiGANGenerator(torch.nn.Module): class HiFiGANGenerator(torch.nn.Module):
@ -931,3 +932,136 @@ class HiFiGANMultiScaleMultiPeriodDiscriminator(torch.nn.Module):
msd_outs = self.msd(x) msd_outs = self.msd(x)
mpd_outs = self.mpd(x) mpd_outs = self.mpd(x)
return msd_outs + mpd_outs return msd_outs + mpd_outs
class DurationDiscriminator(torch.nn.Module):  # vits2
    """Duration discriminator module from VITS2.

    Given detached text-encoder features and a pair of duration sequences
    (reference and predicted), it outputs one per-frame probability tensor
    for each duration sequence via a shared conv trunk plus a duration
    projection head.
    """

    def __init__(
        self,
        channels: int = 192,
        hidden_channels: int = 192,
        kernel_size: int = 3,
        dropout_rate: float = 0.5,
        eps: float = 1e-5,
        global_channels: int = -1,
    ):
        """Initialize DurationDiscriminator.

        Args:
            channels: Number of input feature channels.
            hidden_channels: Number of hidden conv channels.
            kernel_size: Kernel size of the convolutions.
            dropout_rate: Dropout probability after each norm.
            eps: Epsilon passed to LayerNorm.
            global_channels: Global conditioning channels (<= 0 disables
                conditioning).
        """
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dropout_rate = dropout_rate
        self.global_channels = global_channels

        def channel_layer_norm() -> torch.nn.Sequential:
            # LayerNorm acts on the last dim; swap (B, C, T) -> (B, T, C)
            # around it and back.
            return torch.nn.Sequential(
                Transpose(1, 2),
                torch.nn.LayerNorm(
                    hidden_channels, eps=eps, elementwise_affine=True
                ),
                Transpose(1, 2),
            )

        pad = kernel_size // 2
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.conv_1 = torch.nn.Conv1d(
            channels, hidden_channels, kernel_size, padding=pad
        )
        self.norm_1 = channel_layer_norm()
        self.conv_2 = torch.nn.Conv1d(
            hidden_channels, hidden_channels, kernel_size, padding=pad
        )
        self.norm_2 = channel_layer_norm()
        # Lift the scalar duration track to hidden_channels before fusion.
        self.dur_proj = torch.nn.Conv1d(1, hidden_channels, 1)
        self.pre_out_conv_1 = torch.nn.Conv1d(
            2 * hidden_channels, hidden_channels, kernel_size, padding=pad
        )
        self.pre_out_norm_1 = channel_layer_norm()
        self.pre_out_conv_2 = torch.nn.Conv1d(
            hidden_channels, hidden_channels, kernel_size, padding=pad
        )
        self.pre_out_norm_2 = channel_layer_norm()
        if global_channels > 0:
            self.cond_layer = torch.nn.Conv1d(global_channels, channels, 1)
        self.output_layer = torch.nn.Sequential(
            torch.nn.Linear(hidden_channels, 1), torch.nn.Sigmoid()
        )

    def forward_probability(
        self, x: torch.Tensor, x_mask: torch.Tensor, dur: torch.Tensor
    ):
        """Score one duration sequence against the shared trunk features.

        Args:
            x: Trunk features (B, hidden_channels, T_text).
            x_mask: Mask tensor (B, 1, T_text).
            dur: Duration tensor (B, 1, T_text).

        Returns:
            Tensor: Per-frame probabilities (B, T_text, 1).
        """
        h = torch.cat([x, self.dur_proj(dur)], dim=1)
        for conv, norm in (
            (self.pre_out_conv_1, self.pre_out_norm_1),
            (self.pre_out_conv_2, self.pre_out_norm_2),
        ):
            h = torch.relu(conv(h * x_mask))
            h = self.dropout(norm(h))
        h = (h * x_mask).transpose(1, 2)
        return self.output_layer(h)

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        dur_r: torch.Tensor,
        dur_hat: torch.Tensor,
        g: Optional[torch.Tensor] = None,
    ):
        """Score real and predicted durations.

        Args:
            x: Text-encoder features (B, channels, T_text); detached here.
            x_mask: Mask tensor (B, 1, T_text).
            dur_r: Reference duration tensor (B, 1, T_text).
            dur_hat: Predicted duration tensor (B, 1, T_text).
            g: Optional global conditioning tensor.

        Returns:
            List[Tensor]: Probability tensors for [dur_r, dur_hat].
        """
        x = torch.detach(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond_layer(g)
        for conv, norm in ((self.conv_1, self.norm_1), (self.conv_2, self.norm_2)):
            x = torch.relu(conv(x * x_mask))
            x = self.dropout(norm(x))
        return [self.forward_probability(x, x_mask, dur) for dur in (dur_r, dur_hat)]

View File

@ -333,3 +333,30 @@ class KLDivergenceLossWithoutFlow(torch.nn.Module):
prior_norm = D.Normal(m_p, torch.exp(logs_p)) prior_norm = D.Normal(m_p, torch.exp(logs_p))
loss = D.kl_divergence(posterior_norm, prior_norm).mean() loss = D.kl_divergence(posterior_norm, prior_norm).mean()
return loss return loss
class DurationDiscLoss(torch.nn.Module):
    """LSGAN-style discriminator loss for the VITS2 duration discriminator.

    Pushes discriminator outputs on real durations towards 1 and outputs
    on generated durations towards 0.
    """

    def forward(
        self,
        disc_real_outputs: List[torch.Tensor],
        disc_generated_outputs: List[torch.Tensor],
    ):
        """Compute the summed least-squares loss over all output pairs.

        Args:
            disc_real_outputs: Discriminator outputs for real durations.
            disc_generated_outputs: Discriminator outputs for generated
                durations (paired element-wise with the real outputs).

        Returns:
            Tensor: Scalar total loss.
        """
        total = 0
        for real_out, fake_out in zip(disc_real_outputs, disc_generated_outputs):
            # Cast to float in case of fp16 autocast outputs.
            real_term = torch.mean((1 - real_out.float()) ** 2)
            fake_term = torch.mean(fake_out.float() ** 2)
            total = total + real_term + fake_term
        return total
class DurationGenLoss(torch.nn.Module):
    """LSGAN-style generator loss for the VITS2 duration discriminator.

    Pushes discriminator outputs on generated durations towards 1 so the
    duration predictor learns to fool the duration discriminator.
    """

    def forward(self, disc_outputs: List[torch.Tensor]):
        """Compute the summed least-squares generator loss.

        Args:
            disc_outputs: Discriminator outputs for generated durations.

        Returns:
            Tensor: Scalar total loss.
        """
        total = 0
        for out in disc_outputs:
            # Cast to float in case of fp16 autocast outputs.
            total = total + torch.mean((1 - out.float()) ** 2)
        return total

View File

@ -314,8 +314,10 @@ def train_one_epoch(
tokenizer: Tokenizer, tokenizer: Tokenizer,
optimizer_g: Optimizer, optimizer_g: Optimizer,
optimizer_d: Optimizer, optimizer_d: Optimizer,
optimizer_dur: Optimizer,
scheduler_g: LRSchedulerType, scheduler_g: LRSchedulerType,
scheduler_d: LRSchedulerType, scheduler_d: LRSchedulerType,
scheduler_dur: LRSchedulerType,
train_dl: torch.utils.data.DataLoader, train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler, scaler: GradScaler,
@ -402,7 +404,7 @@ def train_one_epoch(
try: try:
with autocast(enabled=params.use_fp16): with autocast(enabled=params.use_fp16):
# forward discriminator # forward discriminator
loss_d, stats_d = model( loss_d, dur_loss, stats_d = model(
text=tokens, text=tokens,
text_lengths=tokens_lens, text_lengths=tokens_lens,
feats=features, feats=features,
@ -411,6 +413,11 @@ def train_one_epoch(
speech_lengths=audio_lens, speech_lengths=audio_lens,
forward_generator=False, forward_generator=False,
) )
optimizer_dur.zero_grad()
scaler.scale(dur_loss).backward()
scaler.step(optimizer_dur)
for k, v in stats_d.items(): for k, v in stats_d.items():
loss_info[k] = v * batch_size loss_info[k] = v * batch_size
# update discriminator # update discriminator
@ -597,7 +604,7 @@ def compute_validation_loss(
loss_info["samples"] = batch_size loss_info["samples"] = batch_size
# forward discriminator # forward discriminator
loss_d, stats_d = model( loss_d, dur_loss, stats_d = model(
text=tokens, text=tokens,
text_lengths=tokens_lens, text_lengths=tokens_lens,
feats=features, feats=features,
@ -661,6 +668,7 @@ def scan_pessimistic_batches_for_oom(
tokenizer: Tokenizer, tokenizer: Tokenizer,
optimizer_g: torch.optim.Optimizer, optimizer_g: torch.optim.Optimizer,
optimizer_d: torch.optim.Optimizer, optimizer_d: torch.optim.Optimizer,
optimizer_dur: torch.optim.Optimizer,
params: AttributeDict, params: AttributeDict,
): ):
from lhotse.dataset import find_pessimistic_batches from lhotse.dataset import find_pessimistic_batches
@ -678,7 +686,7 @@ def scan_pessimistic_batches_for_oom(
try: try:
# for discriminator # for discriminator
with autocast(enabled=params.use_fp16): with autocast(enabled=params.use_fp16):
loss_d, stats_d = model( loss_d, dur_loss, stats_d = model(
text=tokens, text=tokens,
text_lengths=tokens_lens, text_lengths=tokens_lens,
feats=features, feats=features,
@ -687,6 +695,10 @@ def scan_pessimistic_batches_for_oom(
speech_lengths=audio_lens, speech_lengths=audio_lens,
forward_generator=False, forward_generator=False,
) )
optimizer_dur.zero_grad()
dur_loss.backward()
optimizer_d.zero_grad() optimizer_d.zero_grad()
loss_d.backward() loss_d.backward()
# for generator # for generator
@ -760,12 +772,17 @@ def run(rank, world_size, args):
model = get_model(params) model = get_model(params)
generator = model.generator generator = model.generator
discriminator = model.discriminator discriminator = model.discriminator
dur_disc = model.dur_disc
num_param_g = sum([p.numel() for p in generator.parameters()]) num_param_g = sum([p.numel() for p in generator.parameters()])
logging.info(f"Number of parameters in generator: {num_param_g}") logging.info(f"Number of parameters in generator: {num_param_g}")
num_param_d = sum([p.numel() for p in discriminator.parameters()]) num_param_d = sum([p.numel() for p in discriminator.parameters()])
logging.info(f"Number of parameters in discriminator: {num_param_d}") logging.info(f"Number of parameters in discriminator: {num_param_d}")
logging.info(f"Total number of parameters: {num_param_g + num_param_d}") num_param_dur = sum([p.numel() for p in dur_disc.parameters()])
logging.info(f"Number of parameters in duration discriminator: {num_param_dur}")
logging.info(
f"Total number of parameters: {num_param_g + num_param_d + num_param_dur}"
)
assert params.start_epoch > 0, params.start_epoch assert params.start_epoch > 0, params.start_epoch
checkpoints = load_checkpoint_if_available(params=params, model=model) checkpoints = load_checkpoint_if_available(params=params, model=model)
@ -781,9 +798,15 @@ def run(rank, world_size, args):
optimizer_d = torch.optim.AdamW( optimizer_d = torch.optim.AdamW(
discriminator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9 discriminator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9
) )
optimizer_dur = torch.optim.AdamW(
dur_disc.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9
)
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875) scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875)
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999875) scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999875)
scheduler_dur = torch.optim.lr_scheduler.ExponentialLR(
optimizer_dur, gamma=0.999875
)
if checkpoints is not None: if checkpoints is not None:
# load state_dict for optimizers # load state_dict for optimizers
@ -793,6 +816,9 @@ def run(rank, world_size, args):
if "optimizer_d" in checkpoints: if "optimizer_d" in checkpoints:
logging.info("Loading optimizer_d state dict") logging.info("Loading optimizer_d state dict")
optimizer_d.load_state_dict(checkpoints["optimizer_d"]) optimizer_d.load_state_dict(checkpoints["optimizer_d"])
if "optimizer_dur" in checkpoints:
logging.info("Loading optimizer_dur state dict")
optimizer_dur.load_state_dict(checkpoints["optimizer_dur"])
# load state_dict for schedulers # load state_dict for schedulers
if "scheduler_g" in checkpoints: if "scheduler_g" in checkpoints:
@ -801,6 +827,9 @@ def run(rank, world_size, args):
if "scheduler_d" in checkpoints: if "scheduler_d" in checkpoints:
logging.info("Loading scheduler_d state dict") logging.info("Loading scheduler_d state dict")
scheduler_d.load_state_dict(checkpoints["scheduler_d"]) scheduler_d.load_state_dict(checkpoints["scheduler_d"])
if "scheduler_dur" in checkpoints:
logging.info("Loading scheduler_dur state dict")
scheduler_dur.load_state_dict(checkpoints["scheduler_dur"])
if params.print_diagnostics: if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions( opts = diagnostics.TensorDiagnosticOptions(
@ -812,7 +841,6 @@ def run(rank, world_size, args):
register_inf_check_hooks(model) register_inf_check_hooks(model)
ljspeech = LJSpeechTtsDataModule(args) ljspeech = LJSpeechTtsDataModule(args)
train_cuts = ljspeech.train_cuts() train_cuts = ljspeech.train_cuts()
def remove_short_and_long_utt(c: Cut): def remove_short_and_long_utt(c: Cut):
@ -840,6 +868,7 @@ def run(rank, world_size, args):
tokenizer=tokenizer, tokenizer=tokenizer,
optimizer_g=optimizer_g, optimizer_g=optimizer_g,
optimizer_d=optimizer_d, optimizer_d=optimizer_d,
optimizer_dur=optimizer_dur,
params=params, params=params,
) )
@ -865,8 +894,10 @@ def run(rank, world_size, args):
tokenizer=tokenizer, tokenizer=tokenizer,
optimizer_g=optimizer_g, optimizer_g=optimizer_g,
optimizer_d=optimizer_d, optimizer_d=optimizer_d,
optimizer_dur=optimizer_dur,
scheduler_g=scheduler_g, scheduler_g=scheduler_g,
scheduler_d=scheduler_d, scheduler_d=scheduler_d,
scheduler_dur=scheduler_dur,
train_dl=train_dl, train_dl=train_dl,
valid_dl=valid_dl, valid_dl=valid_dl,
scaler=scaler, scaler=scaler,
@ -905,6 +936,7 @@ def run(rank, world_size, args):
# step per epoch # step per epoch
scheduler_g.step() scheduler_g.step()
scheduler_d.step() scheduler_d.step()
scheduler_dur.step()
logging.info("Done!") logging.info("Done!")

View File

@ -11,6 +11,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from generator import VITSGenerator from generator import VITSGenerator
from hifigan import ( from hifigan import (
DurationDiscriminator,
HiFiGANMultiPeriodDiscriminator, HiFiGANMultiPeriodDiscriminator,
HiFiGANMultiScaleDiscriminator, HiFiGANMultiScaleDiscriminator,
HiFiGANMultiScaleMultiPeriodDiscriminator, HiFiGANMultiScaleMultiPeriodDiscriminator,
@ -19,6 +20,8 @@ from hifigan import (
) )
from loss import ( from loss import (
DiscriminatorAdversarialLoss, DiscriminatorAdversarialLoss,
DurationDiscLoss,
DurationGenLoss,
FeatureMatchLoss, FeatureMatchLoss,
GeneratorAdversarialLoss, GeneratorAdversarialLoss,
KLDivergenceLoss, KLDivergenceLoss,
@ -87,6 +90,8 @@ class VITS(nn.Module):
"stochastic_duration_predictor_dropout_rate": 0.5, "stochastic_duration_predictor_dropout_rate": 0.5,
"stochastic_duration_predictor_flows": 4, "stochastic_duration_predictor_flows": 4,
"stochastic_duration_predictor_dds_conv_layers": 3, "stochastic_duration_predictor_dds_conv_layers": 3,
"duration_predictor_output_channels": 256,
"use_stochastic_duration_predictor": True,
"use_noised_mas": True, "use_noised_mas": True,
"noise_initial_mas": 0.01, "noise_initial_mas": 0.01,
"noise_scale_mas": 2e-06, "noise_scale_mas": 2e-06,
@ -130,6 +135,13 @@ class VITS(nn.Module):
"use_weight_norm": True, "use_weight_norm": True,
"use_spectral_norm": False, "use_spectral_norm": False,
}, },
"duration_discriminator_params": {
"channels": 192,
"hidden_channels": 192,
"kernel_size": 3,
"dropout_rate": 0.1,
"global_channels": -1,
},
}, },
# loss related # loss related
generator_adv_loss_params: Dict[str, Any] = { generator_adv_loss_params: Dict[str, Any] = {
@ -155,6 +167,7 @@ class VITS(nn.Module):
lambda_feat_match: float = 2.0, lambda_feat_match: float = 2.0,
lambda_dur: float = 1.0, lambda_dur: float = 1.0,
lambda_kl: float = 1.0, lambda_kl: float = 1.0,
lambda_dur_gen: float = 1.0,
cache_generator_outputs: bool = True, cache_generator_outputs: bool = True,
): ):
"""Initialize VITS module. """Initialize VITS module.
@ -194,6 +207,13 @@ class VITS(nn.Module):
# where idim represents #vocabularies and odim represents # where idim represents #vocabularies and odim represents
# the input acoustic feature dimension. # the input acoustic feature dimension.
generator_params.update(vocabs=vocab_size, aux_channels=feature_dim) generator_params.update(vocabs=vocab_size, aux_channels=feature_dim)
self.dur_disc = DurationDiscriminator(
**discriminator_params["duration_discriminator_params"]
)
discriminator_params.pop("duration_discriminator_params")
self.generator = generator_class( self.generator = generator_class(
**generator_params, **generator_params,
) )
@ -216,12 +236,17 @@ class VITS(nn.Module):
) )
self.kl_loss = KLDivergenceLoss() self.kl_loss = KLDivergenceLoss()
# Vits2 duration disc
self.dur_disc_loss = DurationDiscLoss()
self.dur_gen_loss = DurationGenLoss()
# coefficients # coefficients
self.lambda_adv = lambda_adv self.lambda_adv = lambda_adv
self.lambda_mel = lambda_mel self.lambda_mel = lambda_mel
self.lambda_kl = lambda_kl self.lambda_kl = lambda_kl
self.lambda_feat_match = lambda_feat_match self.lambda_feat_match = lambda_feat_match
self.lambda_dur = lambda_dur self.lambda_dur = lambda_dur
self.lambda_dur_gen = lambda_dur_gen
# cache # cache
self.cache_generator_outputs = cache_generator_outputs self.cache_generator_outputs = cache_generator_outputs
@ -349,8 +374,18 @@ class VITS(nn.Module):
self._cache = outs self._cache = outs
# parse outputs # parse outputs
speech_hat_, dur_nll, _, start_idxs, _, z_mask, outs_ = outs # speech_hat_, dur_nll, _, start_idxs, _, z_mask, outs_ = outs
_, z_p, m_p, logs_p, _, logs_q = outs_ # _, z_p, m_p, logs_p, _, logs_q = outs_
(
speech_hat_,
dur_nll,
attn,
start_idxs,
x_mask,
y_mask,
(z, z_p, m_p, logs_p, m_q, logs_q),
(hidden_x, logw, logw_),
) = outs
speech_ = get_segments( speech_ = get_segments(
x=speech, x=speech,
start_idxs=start_idxs * self.generator.upsample_factor, start_idxs=start_idxs * self.generator.upsample_factor,
@ -371,17 +406,29 @@ class VITS(nn.Module):
mel_loss, (mel_hat_, mel_) = self.mel_loss( mel_loss, (mel_hat_, mel_) = self.mel_loss(
speech_hat_, speech_, return_mel=True speech_hat_, speech_, return_mel=True
) )
kl_loss = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask) kl_loss = self.kl_loss(z_p, logs_q, m_p, logs_p, y_mask)
dur_loss = torch.sum(dur_nll.float()) dur_loss = torch.sum(dur_nll.float())
adv_loss = self.generator_adv_loss(p_hat) adv_loss = self.generator_adv_loss(p_hat)
feat_match_loss = self.feat_match_loss(p_hat, p) feat_match_loss = self.feat_match_loss(p_hat, p)
y_dur_hat_r, y_dur_hat_g = self.dur_disc(hidden_x, x_mask, logw_, logw)
dur_gen_loss = self.dur_gen_loss(y_dur_hat_g)
mel_loss = mel_loss * self.lambda_mel mel_loss = mel_loss * self.lambda_mel
kl_loss = kl_loss * self.lambda_kl kl_loss = kl_loss * self.lambda_kl
dur_loss = dur_loss * self.lambda_dur dur_loss = dur_loss * self.lambda_dur
adv_loss = adv_loss * self.lambda_adv adv_loss = adv_loss * self.lambda_adv
feat_match_loss = feat_match_loss * self.lambda_feat_match feat_match_loss = feat_match_loss * self.lambda_feat_match
loss = mel_loss + kl_loss + dur_loss + adv_loss + feat_match_loss dur_gen_loss = dur_gen_loss * self.lambda_dur_gen
loss = (
mel_loss
+ kl_loss
+ dur_loss
+ adv_loss
+ feat_match_loss
+ dur_gen_loss
)
stats = dict( stats = dict(
generator_loss=loss.item(), generator_loss=loss.item(),
@ -390,6 +437,7 @@ class VITS(nn.Module):
generator_dur_loss=dur_loss.item(), generator_dur_loss=dur_loss.item(),
generator_adv_loss=adv_loss.item(), generator_adv_loss=adv_loss.item(),
generator_feat_match_loss=feat_match_loss.item(), generator_feat_match_loss=feat_match_loss.item(),
generator_dur_gen_loss=dur_gen_loss.item(),
) )
if return_sample: if return_sample:
@ -459,8 +507,17 @@ class VITS(nn.Module):
if self.cache_generator_outputs and not reuse_cache: if self.cache_generator_outputs and not reuse_cache:
self._cache = outs self._cache = outs
# parse outputs (
speech_hat_, _, _, start_idxs, *_ = outs speech_hat_,
dur_nll,
attn,
start_idxs,
x_mask,
y_mask,
(z, z_p, m_p, logs_p, m_q, logs_q),
(hidden_x, logw, logw_),
) = outs
speech_ = get_segments( speech_ = get_segments(
x=speech, x=speech,
start_idxs=start_idxs * self.generator.upsample_factor, start_idxs=start_idxs * self.generator.upsample_factor,
@ -476,6 +533,14 @@ class VITS(nn.Module):
real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p) real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p)
loss = real_loss + fake_loss loss = real_loss + fake_loss
# Duration Discriminator
y_dur_hat_r, y_dur_hat_g = self.dur_disc(
hidden_x.detach(), x_mask.detach(), logw_.detach(), logw.detach()
)
with autocast(enabled=False):
dur_loss = self.dur_disc_loss(y_dur_hat_r, y_dur_hat_g)
stats = dict( stats = dict(
discriminator_loss=loss.item(), discriminator_loss=loss.item(),
discriminator_real_loss=real_loss.item(), discriminator_real_loss=real_loss.item(),
@ -486,7 +551,7 @@ class VITS(nn.Module):
if reuse_cache or not self.training: if reuse_cache or not self.training:
self._cache = None self._cache = None
return loss, stats return loss, dur_loss, stats
def inference( def inference(
self, self,

View File

@ -103,9 +103,9 @@ class WaveNet(torch.nn.Module):
# define output layers # define output layers
if self.use_last_conv: if self.use_last_conv:
self.last_conv = torch.nn.Sequential( self.last_conv = torch.nn.Sequential(
torch.nn.ReLU(inplace=True), torch.nn.ReLU(inplace=False),
Conv1d1x1(skip_channels, skip_channels, bias=True), Conv1d1x1(skip_channels, skip_channels, bias=True),
torch.nn.ReLU(inplace=True), torch.nn.ReLU(inplace=False),
Conv1d1x1(skip_channels, out_channels, bias=True), Conv1d1x1(skip_channels, out_channels, bias=True),
) )