Add duration discrimination loss

parent b9fdebaff2
commit e5c04a216c
@@ -20,6 +20,7 @@ from flow import (
     ElementwiseAffineFlow,
     FlipFlow,
     LogFlow,
+    Transpose,
 )
@@ -191,3 +192,68 @@ class StochasticDurationPredictor(torch.nn.Module):
         z0, z1 = z.split(1, 1)
         logw = z0
         return logw
+
+
+class DurationPredictor(torch.nn.Module):
+    def __init__(
+        self,
+        input_channels: int = 192,
+        output_channels: int = 192,
+        kernel_size: int = 3,
+        dropout_rate: float = 0.5,
+        global_channels: int = -1,
+        eps: float = 1e-5,
+    ):
+        super().__init__()
+
+        self.input_channels = input_channels
+        self.output_channels = output_channels
+        self.kernel_size = kernel_size
+        self.dropout_rate = dropout_rate
+        self.gin_channels = global_channels
+
+        self.dropout = torch.nn.Dropout(dropout_rate)
+        self.conv_1 = torch.nn.Conv1d(
+            input_channels, output_channels, kernel_size, padding=kernel_size // 2
+        )
+        self.norm_1 = torch.nn.Sequential(
+            Transpose(1, 2),
+            torch.nn.LayerNorm(
+                output_channels,
+                eps=eps,
+                elementwise_affine=True,
+            ),
+            Transpose(1, 2),
+        )
+        self.conv_2 = torch.nn.Conv1d(
+            output_channels, output_channels, kernel_size, padding=kernel_size // 2
+        )
+        self.norm_2 = torch.nn.Sequential(
+            Transpose(1, 2),
+            torch.nn.LayerNorm(
+                output_channels,
+                eps=eps,
+                elementwise_affine=True,
+            ),
+            Transpose(1, 2),
+        )
+        self.proj = torch.nn.Conv1d(output_channels, 1, 1)
+
+        if global_channels > 0:
+            self.cond = torch.nn.Conv1d(global_channels, input_channels, 1)
+
+    def forward(self, x, x_mask, g=None):
+        x = torch.detach(x)
+        if g is not None:
+            g = torch.detach(g)
+            x = x + self.cond(g)
+        x = self.conv_1(x * x_mask)
+        x = torch.relu(x)
+        x = self.norm_1(x)
+        x = self.dropout(x)
+        x = self.conv_2(x * x_mask)
+        x = torch.relu(x)
+        x = self.norm_2(x)
+        x = self.dropout(x)
+        x = self.proj(x * x_mask)
+        return x * x_mask
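A minimal usage sketch for the new deterministic DurationPredictor above (the batch size, channel count, token count, and dummy tensors are illustrative assumptions; only the module and its forward signature come from the hunk):

import torch

from duration_predictor import DurationPredictor

# Hypothetical smoke test: 2 utterances, 192-channel hidden text representation,
# 17 text tokens; x_mask marks the valid token positions.
predictor = DurationPredictor(input_channels=192, output_channels=192)
x = torch.randn(2, 192, 17)
x_mask = torch.ones(2, 1, 17)
logw = predictor(x, x_mask)    # (2, 1, 17): one predicted log-duration per token
w = torch.exp(logw) * x_mask   # durations in frames
print(logw.shape, w.shape)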
@@ -16,7 +16,7 @@ from typing import List, Optional, Tuple
 import numpy as np
 import torch
 import torch.nn.functional as F
-from duration_predictor import StochasticDurationPredictor
+from duration_predictor import DurationPredictor, StochasticDurationPredictor
 from hifigan import HiFiGANGenerator
 from posterior_encoder import PosteriorEncoder
 from residual_coupling import ResidualAffineCouplingBlock
@@ -71,6 +71,8 @@ class VITSGenerator(torch.nn.Module):
         stochastic_duration_predictor_dropout_rate: float = 0.5,
         stochastic_duration_predictor_flows: int = 4,
         stochastic_duration_predictor_dds_conv_layers: int = 3,
+        duration_predictor_output_channels: int = 256,
+        use_stochastic_duration_predictor: bool = True,
         use_noised_mas: bool = True,
         noise_initial_mas: float = 0.01,
         noise_scale_mas: float = 2e-6,
@@ -184,14 +186,23 @@ class VITSGenerator(torch.nn.Module):
             use_transformer_in_flows=use_transformer_in_flows,
         )
         # TODO(kan-bayashi): Add deterministic version as an option
-        self.duration_predictor = StochasticDurationPredictor(
-            channels=hidden_channels,
-            kernel_size=stochastic_duration_predictor_kernel_size,
-            dropout_rate=stochastic_duration_predictor_dropout_rate,
-            flows=stochastic_duration_predictor_flows,
-            dds_conv_layers=stochastic_duration_predictor_dds_conv_layers,
-            global_channels=global_channels,
-        )
+        if use_stochastic_duration_predictor:
+            self.duration_predictor = StochasticDurationPredictor(
+                channels=hidden_channels,
+                kernel_size=stochastic_duration_predictor_kernel_size,
+                dropout_rate=stochastic_duration_predictor_dropout_rate,
+                flows=stochastic_duration_predictor_flows,
+                dds_conv_layers=stochastic_duration_predictor_dds_conv_layers,
+                global_channels=global_channels,
+            )
+        else:
+            self.duration_predictor = DurationPredictor(
+                input_channels=hidden_channels,
+                output_channels=duration_predictor_output_channels,
+                kernel_size=stochastic_duration_predictor_kernel_size,
+                dropout_rate=stochastic_duration_predictor_dropout_rate,
+                global_channels=global_channels,
+            )

         self.upsample_factor = int(np.prod(decoder_upsample_scales))
@@ -200,6 +211,7 @@ class VITSGenerator(torch.nn.Module):
         self.noise_current_mas = noise_initial_mas
         self.noise_scale_mas = noise_scale_mas
         self.noise_initial_mas = noise_initial_mas
+        self.use_stochastic_duration_predictor = use_stochastic_duration_predictor

         self.spks = None
         if spks is not None and spks > 1:
@@ -354,8 +366,18 @@ class VITSGenerator(torch.nn.Module):

         # forward duration predictor
         w = attn.sum(2)  # (B, 1, T_text)
-        dur_nll = self.duration_predictor(x, x_mask, w=w, g=g)
-        dur_nll = dur_nll / torch.sum(x_mask)
+        if self.use_stochastic_duration_predictor:
+            dur_nll = self.duration_predictor(x, x_mask, w=w, g=g)
+            dur_nll = dur_nll / torch.sum(x_mask)
+            logw = self.duration_predictor(
+                x, x_mask, g=g, inverse=True, noise_scale=1.0
+            )
+            logw_ = torch.log(w + 1e-6) * x_mask
+        else:
+            logw_ = torch.log(w + 1e-6) * x_mask
+            logw = self.duration_predictor(x, x_mask, g=g)
+            dur_nll = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(x_mask)

         # expand the length to match with the feature sequence
         # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
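In the stochastic branch above, the predictor keeps its original negative log-likelihood objective and is additionally run in inverse mode to obtain logw for the new duration discriminator; in the deterministic branch, it regresses log-durations with a masked mean-squared error against the durations produced by the alignment. A small self-contained sketch of that MSE term (the toy duration values are assumptions):

import torch

# w holds frames-per-token from the alignment, shape (B, 1, T_text).
w = torch.tensor([[[2.0, 3.0, 1.0]]])
x_mask = torch.ones(1, 1, 3)
logw_ = torch.log(w + 1e-6) * x_mask   # target log-durations
logw = torch.zeros(1, 1, 3)            # stand-in for the predictor output
dur_nll = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(x_mask)
print(dur_nll)                          # masked MSE over the valid tokens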
@@ -381,6 +403,7 @@ class VITSGenerator(torch.nn.Module):
             x_mask,
             y_mask,
             (z, z_p, m_p, logs_p, m_q, logs_q),
+            (x, logw, logw_),
         )

     def inference(
@@ -16,6 +16,7 @@ from typing import Any, Dict, List, Optional
 import numpy as np
 import torch
 import torch.nn.functional as F
+from flow import Transpose


 class HiFiGANGenerator(torch.nn.Module):
@@ -931,3 +932,136 @@ class HiFiGANMultiScaleMultiPeriodDiscriminator(torch.nn.Module):
         msd_outs = self.msd(x)
         mpd_outs = self.mpd(x)
         return msd_outs + mpd_outs
+
+
+class DurationDiscriminator(torch.nn.Module):  # vits2
+    def __init__(
+        self,
+        channels: int = 192,
+        hidden_channels: int = 192,
+        kernel_size: int = 3,
+        dropout_rate: float = 0.5,
+        eps: float = 1e-5,
+        global_channels: int = -1,
+    ):
+        super().__init__()
+
+        self.channels = channels
+        self.hidden_channels = hidden_channels
+        self.kernel_size = kernel_size
+        self.dropout_rate = dropout_rate
+        self.global_channels = global_channels
+
+        self.dropout = torch.nn.Dropout(dropout_rate)
+        self.conv_1 = torch.nn.Conv1d(
+            channels, hidden_channels, kernel_size, padding=kernel_size // 2
+        )
+        self.norm_1 = torch.nn.Sequential(
+            Transpose(1, 2),
+            torch.nn.LayerNorm(
+                hidden_channels,
+                eps=eps,
+                elementwise_affine=True,
+            ),
+            Transpose(1, 2),
+        )
+
+        self.conv_2 = torch.nn.Conv1d(
+            hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2
+        )
+
+        self.norm_2 = torch.nn.Sequential(
+            Transpose(1, 2),
+            torch.nn.LayerNorm(
+                hidden_channels,
+                eps=eps,
+                elementwise_affine=True,
+            ),
+            Transpose(1, 2),
+        )
+
+        self.dur_proj = torch.nn.Conv1d(1, hidden_channels, 1)
+
+        self.pre_out_conv_1 = torch.nn.Conv1d(
+            2 * hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2
+        )
+
+        self.pre_out_norm_1 = torch.nn.Sequential(
+            Transpose(1, 2),
+            torch.nn.LayerNorm(
+                hidden_channels,
+                eps=eps,
+                elementwise_affine=True,
+            ),
+            Transpose(1, 2),
+        )
+
+        self.pre_out_conv_2 = torch.nn.Conv1d(
+            hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2
+        )
+
+        self.pre_out_norm_2 = torch.nn.Sequential(
+            Transpose(1, 2),
+            torch.nn.LayerNorm(
+                hidden_channels,
+                eps=eps,
+                elementwise_affine=True,
+            ),
+            Transpose(1, 2),
+        )
+
+        if global_channels > 0:
+            self.cond_layer = torch.nn.Conv1d(global_channels, channels, 1)
+
+        self.output_layer = torch.nn.Sequential(
+            torch.nn.Linear(hidden_channels, 1), torch.nn.Sigmoid()
+        )
+
+    def forward_probability(
+        self, x: torch.Tensor, x_mask: torch.Tensor, dur: torch.Tensor
+    ):
+        dur = self.dur_proj(dur)
+        x = torch.cat([x, dur], dim=1)
+        x = self.pre_out_conv_1(x * x_mask)
+        x = torch.relu(x)
+        x = self.pre_out_norm_1(x)
+        x = self.dropout(x)
+        x = self.pre_out_conv_2(x * x_mask)
+        x = torch.relu(x)
+        x = self.pre_out_norm_2(x)
+        x = self.dropout(x)
+        x = x * x_mask
+        x = x.transpose(1, 2)
+        output_prob = self.output_layer(x)
+        return output_prob
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        x_mask: torch.Tensor,
+        dur_r: torch.Tensor,
+        dur_hat: torch.Tensor,
+        g: Optional[torch.Tensor] = None,
+    ):
+        x = torch.detach(x)
+        if g is not None:
+            g = torch.detach(g)
+            x = x + self.cond_layer(g)
+
+        x = self.conv_1(x * x_mask)
+        x = torch.relu(x)
+        x = self.norm_1(x)
+        x = self.dropout(x)
+
+        x = self.conv_2(x * x_mask)
+        x = torch.relu(x)
+        x = self.norm_2(x)
+        x = self.dropout(x)
+
+        output_probs = []
+        for dur in [dur_r, dur_hat]:
+            output_prob = self.forward_probability(x, x_mask, dur)
+            output_probs.append(output_prob)
+
+        return output_probs
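A minimal usage sketch for the DurationDiscriminator added above: it conditions on the detached text encoding and scores a real and a predicted duration sequence. The shapes and dummy tensors are illustrative assumptions, not part of the commit:

import torch

from hifigan import DurationDiscriminator

disc = DurationDiscriminator(channels=192, hidden_channels=192)
hidden_x = torch.randn(2, 192, 17)   # text-side hidden states (B, C, T_text)
x_mask = torch.ones(2, 1, 17)        # valid-token mask
dur_real = torch.rand(2, 1, 17)      # log-durations from the alignment
dur_pred = torch.rand(2, 1, 17)      # log-durations from the predictor
prob_real, prob_pred = disc(hidden_x, x_mask, dur_real, dur_pred)
print(prob_real.shape)               # (2, 17, 1), sigmoid scores in (0, 1)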
@@ -333,3 +333,30 @@ class KLDivergenceLossWithoutFlow(torch.nn.Module):
         prior_norm = D.Normal(m_p, torch.exp(logs_p))
         loss = D.kl_divergence(posterior_norm, prior_norm).mean()
         return loss
+
+
+class DurationDiscLoss(torch.nn.Module):
+    def forward(
+        self,
+        disc_real_outputs: List[torch.Tensor],
+        disc_generated_outputs: List[torch.Tensor],
+    ):
+        loss = 0
+        for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+            dr = dr.float()
+            dg = dg.float()
+            r_loss = torch.mean((1 - dr) ** 2)
+            g_loss = torch.mean(dg**2)
+            loss += r_loss + g_loss
+
+        return loss
+
+
+class DurationGenLoss(torch.nn.Module):
+    def forward(self, disc_outputs: List[torch.Tensor]):
+        loss = 0
+        for dg in disc_outputs:
+            dg = dg.float()
+            loss += torch.mean((1 - dg) ** 2)
+
+        return loss
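Both new losses follow the least-squares GAN formulation: DurationDiscLoss pushes the discriminator's scores toward 1 on real durations and toward 0 on generated ones, while DurationGenLoss pushes the scores on generated durations toward 1. A toy check (the constant score tensors are assumptions):

import torch

from loss import DurationDiscLoss, DurationGenLoss

prob_real = [torch.full((2, 17, 1), 0.9)]  # scores on real durations
prob_fake = [torch.full((2, 17, 1), 0.1)]  # scores on predicted durations
d_loss = DurationDiscLoss()(prob_real, prob_fake)  # small: the discriminator is doing well
g_loss = DurationGenLoss()(prob_fake)              # large: the generator gets pushed toward 1
print(d_loss.item(), g_loss.item())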
@@ -314,8 +314,10 @@ def train_one_epoch(
     tokenizer: Tokenizer,
     optimizer_g: Optimizer,
     optimizer_d: Optimizer,
+    optimizer_dur: Optimizer,
     scheduler_g: LRSchedulerType,
     scheduler_d: LRSchedulerType,
+    scheduler_dur: LRSchedulerType,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
     scaler: GradScaler,
@@ -402,7 +404,7 @@ def train_one_epoch(
         try:
             with autocast(enabled=params.use_fp16):
                 # forward discriminator
-                loss_d, stats_d = model(
+                loss_d, dur_loss, stats_d = model(
                     text=tokens,
                     text_lengths=tokens_lens,
                     feats=features,
@@ -411,6 +413,11 @@ def train_one_epoch(
                     speech_lengths=audio_lens,
                     forward_generator=False,
                 )
+
+                optimizer_dur.zero_grad()
+                scaler.scale(dur_loss).backward()
+                scaler.step(optimizer_dur)
+
             for k, v in stats_d.items():
                 loss_info[k] = v * batch_size
             # update discriminator
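With this hunk the per-batch update order becomes: duration discriminator, then the waveform discriminator, then the generator, all driven by the same GradScaler. A CPU-friendly sketch of that ordering with toy modules (everything below is a stand-in; the real recipe uses the VITS model, its losses, and fp16 autocast):

import torch
from torch.cuda.amp import GradScaler

# Toy stand-ins that only illustrate the order of the three updates.
dur_disc = torch.nn.Linear(4, 1)
disc = torch.nn.Linear(4, 1)
gen = torch.nn.Linear(4, 4)

opt_dur = torch.optim.AdamW(dur_disc.parameters(), lr=2e-4)
opt_d = torch.optim.AdamW(disc.parameters(), lr=2e-4)
opt_g = torch.optim.AdamW(gen.parameters(), lr=2e-4)
scaler = GradScaler(enabled=False)  # fp16 disabled so the sketch runs on CPU

x = torch.randn(8, 4)

# 1) duration discriminator update (dur_loss in the hunk above)
dur_loss = dur_disc(x).pow(2).mean()
opt_dur.zero_grad()
scaler.scale(dur_loss).backward()
scaler.step(opt_dur)

# 2) discriminator update (loss_d)
loss_d = disc(x).pow(2).mean()
opt_d.zero_grad()
scaler.scale(loss_d).backward()
scaler.step(opt_d)

# 3) generator update, then a single scaler.update() for the batch
loss_g = gen(x).pow(2).mean()
opt_g.zero_grad()
scaler.scale(loss_g).backward()
scaler.step(opt_g)
scaler.update()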
@@ -597,7 +604,7 @@ def compute_validation_loss(
             loss_info["samples"] = batch_size

             # forward discriminator
-            loss_d, stats_d = model(
+            loss_d, dur_loss, stats_d = model(
                 text=tokens,
                 text_lengths=tokens_lens,
                 feats=features,
@@ -661,6 +668,7 @@ def scan_pessimistic_batches_for_oom(
     tokenizer: Tokenizer,
     optimizer_g: torch.optim.Optimizer,
     optimizer_d: torch.optim.Optimizer,
+    optimizer_dur: torch.optim.Optimizer,
     params: AttributeDict,
 ):
     from lhotse.dataset import find_pessimistic_batches
@@ -678,7 +686,7 @@ def scan_pessimistic_batches_for_oom(
         try:
             # for discriminator
             with autocast(enabled=params.use_fp16):
-                loss_d, stats_d = model(
+                loss_d, dur_loss, stats_d = model(
                     text=tokens,
                     text_lengths=tokens_lens,
                     feats=features,
@@ -687,6 +695,10 @@ def scan_pessimistic_batches_for_oom(
                     speech_lengths=audio_lens,
                     forward_generator=False,
                 )
+
+            optimizer_dur.zero_grad()
+            dur_loss.backward()
+
             optimizer_d.zero_grad()
             loss_d.backward()
             # for generator
@@ -760,12 +772,17 @@ def run(rank, world_size, args):
     model = get_model(params)
     generator = model.generator
     discriminator = model.discriminator
+    dur_disc = model.dur_disc

     num_param_g = sum([p.numel() for p in generator.parameters()])
     logging.info(f"Number of parameters in generator: {num_param_g}")
     num_param_d = sum([p.numel() for p in discriminator.parameters()])
     logging.info(f"Number of parameters in discriminator: {num_param_d}")
-    logging.info(f"Total number of parameters: {num_param_g + num_param_d}")
+    num_param_dur = sum([p.numel() for p in dur_disc.parameters()])
+    logging.info(f"Number of parameters in duration discriminator: {num_param_dur}")
+    logging.info(
+        f"Total number of parameters: {num_param_g + num_param_d + num_param_dur}"
+    )

     assert params.start_epoch > 0, params.start_epoch
     checkpoints = load_checkpoint_if_available(params=params, model=model)
@@ -781,9 +798,15 @@ def run(rank, world_size, args):
     optimizer_d = torch.optim.AdamW(
         discriminator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9
     )
+    optimizer_dur = torch.optim.AdamW(
+        dur_disc.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9
+    )

     scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875)
     scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999875)
+    scheduler_dur = torch.optim.lr_scheduler.ExponentialLR(
+        optimizer_dur, gamma=0.999875
+    )

     if checkpoints is not None:
         # load state_dict for optimizers
@@ -793,6 +816,9 @@ def run(rank, world_size, args):
         if "optimizer_d" in checkpoints:
             logging.info("Loading optimizer_d state dict")
             optimizer_d.load_state_dict(checkpoints["optimizer_d"])
+        if "optimizer_dur" in checkpoints:
+            logging.info("Loading optimizer_dur state dict")
+            optimizer_dur.load_state_dict(checkpoints["optimizer_dur"])

         # load state_dict for schedulers
         if "scheduler_g" in checkpoints:
@@ -801,6 +827,9 @@ def run(rank, world_size, args):
         if "scheduler_d" in checkpoints:
             logging.info("Loading scheduler_d state dict")
             scheduler_d.load_state_dict(checkpoints["scheduler_d"])
+        if "scheduler_dur" in checkpoints:
+            logging.info("Loading scheduler_dur state dict")
+            scheduler_dur.load_state_dict(checkpoints["scheduler_dur"])

     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
@@ -812,7 +841,6 @@ def run(rank, world_size, args):
         register_inf_check_hooks(model)
-
     ljspeech = LJSpeechTtsDataModule(args)

     train_cuts = ljspeech.train_cuts()

     def remove_short_and_long_utt(c: Cut):
@@ -840,6 +868,7 @@ def run(rank, world_size, args):
             tokenizer=tokenizer,
             optimizer_g=optimizer_g,
             optimizer_d=optimizer_d,
+            optimizer_dur=optimizer_dur,
             params=params,
         )
@@ -865,8 +894,10 @@ def run(rank, world_size, args):
             tokenizer=tokenizer,
             optimizer_g=optimizer_g,
             optimizer_d=optimizer_d,
+            optimizer_dur=optimizer_dur,
             scheduler_g=scheduler_g,
             scheduler_d=scheduler_d,
+            scheduler_dur=scheduler_dur,
             train_dl=train_dl,
             valid_dl=valid_dl,
             scaler=scaler,
@@ -905,6 +936,7 @@ def run(rank, world_size, args):
         # step per epoch
         scheduler_g.step()
         scheduler_d.step()
+        scheduler_dur.step()

     logging.info("Done!")
@@ -11,6 +11,7 @@ import torch
 import torch.nn as nn
 from generator import VITSGenerator
 from hifigan import (
+    DurationDiscriminator,
     HiFiGANMultiPeriodDiscriminator,
     HiFiGANMultiScaleDiscriminator,
     HiFiGANMultiScaleMultiPeriodDiscriminator,
@@ -19,6 +20,8 @@ from hifigan import (
 )
 from loss import (
     DiscriminatorAdversarialLoss,
+    DurationDiscLoss,
+    DurationGenLoss,
     FeatureMatchLoss,
     GeneratorAdversarialLoss,
     KLDivergenceLoss,
@@ -87,6 +90,8 @@ class VITS(nn.Module):
             "stochastic_duration_predictor_dropout_rate": 0.5,
             "stochastic_duration_predictor_flows": 4,
             "stochastic_duration_predictor_dds_conv_layers": 3,
+            "duration_predictor_output_channels": 256,
+            "use_stochastic_duration_predictor": True,
             "use_noised_mas": True,
             "noise_initial_mas": 0.01,
             "noise_scale_mas": 2e-06,
@@ -130,6 +135,13 @@ class VITS(nn.Module):
             "use_weight_norm": True,
             "use_spectral_norm": False,
         },
+        "duration_discriminator_params": {
+            "channels": 192,
+            "hidden_channels": 192,
+            "kernel_size": 3,
+            "dropout_rate": 0.1,
+            "global_channels": -1,
+        },
     },
     # loss related
     generator_adv_loss_params: Dict[str, Any] = {
@@ -155,6 +167,7 @@ class VITS(nn.Module):
         lambda_feat_match: float = 2.0,
         lambda_dur: float = 1.0,
         lambda_kl: float = 1.0,
+        lambda_dur_gen: float = 1.0,
         cache_generator_outputs: bool = True,
     ):
         """Initialize VITS module.
@@ -194,6 +207,13 @@ class VITS(nn.Module):
         # where idim represents #vocabularies and odim represents
         # the input acoustic feature dimension.
         generator_params.update(vocabs=vocab_size, aux_channels=feature_dim)
+
+        self.dur_disc = DurationDiscriminator(
+            **discriminator_params["duration_discriminator_params"]
+        )
+
+        discriminator_params.pop("duration_discriminator_params")
+
         self.generator = generator_class(
             **generator_params,
         )
@@ -216,12 +236,17 @@ class VITS(nn.Module):
         )
         self.kl_loss = KLDivergenceLoss()

+        # Vits2 duration disc
+        self.dur_disc_loss = DurationDiscLoss()
+        self.dur_gen_loss = DurationGenLoss()
+
         # coefficients
         self.lambda_adv = lambda_adv
         self.lambda_mel = lambda_mel
         self.lambda_kl = lambda_kl
         self.lambda_feat_match = lambda_feat_match
         self.lambda_dur = lambda_dur
+        self.lambda_dur_gen = lambda_dur_gen

         # cache
         self.cache_generator_outputs = cache_generator_outputs
@@ -349,8 +374,18 @@ class VITS(nn.Module):
             self._cache = outs

         # parse outputs
-        speech_hat_, dur_nll, _, start_idxs, _, z_mask, outs_ = outs
-        _, z_p, m_p, logs_p, _, logs_q = outs_
+        # speech_hat_, dur_nll, _, start_idxs, _, z_mask, outs_ = outs
+        # _, z_p, m_p, logs_p, _, logs_q = outs_
+        (
+            speech_hat_,
+            dur_nll,
+            attn,
+            start_idxs,
+            x_mask,
+            y_mask,
+            (z, z_p, m_p, logs_p, m_q, logs_q),
+            (hidden_x, logw, logw_),
+        ) = outs
         speech_ = get_segments(
             x=speech,
             start_idxs=start_idxs * self.generator.upsample_factor,
@@ -371,17 +406,29 @@ class VITS(nn.Module):
             mel_loss, (mel_hat_, mel_) = self.mel_loss(
                 speech_hat_, speech_, return_mel=True
             )
-            kl_loss = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask)
+            kl_loss = self.kl_loss(z_p, logs_q, m_p, logs_p, y_mask)
             dur_loss = torch.sum(dur_nll.float())
             adv_loss = self.generator_adv_loss(p_hat)
             feat_match_loss = self.feat_match_loss(p_hat, p)

+            y_dur_hat_r, y_dur_hat_g = self.dur_disc(hidden_x, x_mask, logw_, logw)
+            dur_gen_loss = self.dur_gen_loss(y_dur_hat_g)
+
             mel_loss = mel_loss * self.lambda_mel
             kl_loss = kl_loss * self.lambda_kl
             dur_loss = dur_loss * self.lambda_dur
             adv_loss = adv_loss * self.lambda_adv
             feat_match_loss = feat_match_loss * self.lambda_feat_match
-            loss = mel_loss + kl_loss + dur_loss + adv_loss + feat_match_loss
+            dur_gen_loss = dur_gen_loss * self.lambda_dur_gen
+
+            loss = (
+                mel_loss
+                + kl_loss
+                + dur_loss
+                + adv_loss
+                + feat_match_loss
+                + dur_gen_loss
+            )

         stats = dict(
             generator_loss=loss.item(),
@@ -390,6 +437,7 @@ class VITS(nn.Module):
             generator_dur_loss=dur_loss.item(),
             generator_adv_loss=adv_loss.item(),
             generator_feat_match_loss=feat_match_loss.item(),
+            generator_dur_gen_loss=dur_gen_loss.item(),
         )

         if return_sample:
@@ -459,8 +507,17 @@ class VITS(nn.Module):
         if self.cache_generator_outputs and not reuse_cache:
             self._cache = outs

-        # parse outputs
-        speech_hat_, _, _, start_idxs, *_ = outs
+        (
+            speech_hat_,
+            dur_nll,
+            attn,
+            start_idxs,
+            x_mask,
+            y_mask,
+            (z, z_p, m_p, logs_p, m_q, logs_q),
+            (hidden_x, logw, logw_),
+        ) = outs

         speech_ = get_segments(
             x=speech,
             start_idxs=start_idxs * self.generator.upsample_factor,
@@ -476,6 +533,14 @@ class VITS(nn.Module):
             real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p)
             loss = real_loss + fake_loss

+        # Duration Discriminator
+        y_dur_hat_r, y_dur_hat_g = self.dur_disc(
+            hidden_x.detach(), x_mask.detach(), logw_.detach(), logw.detach()
+        )
+
+        with autocast(enabled=False):
+            dur_loss = self.dur_disc_loss(y_dur_hat_r, y_dur_hat_g)
+
         stats = dict(
             discriminator_loss=loss.item(),
             discriminator_real_loss=real_loss.item(),
@@ -486,7 +551,7 @@ class VITS(nn.Module):
         if reuse_cache or not self.training:
             self._cache = None

-        return loss, stats
+        return loss, dur_loss, stats

     def inference(
         self,
@@ -103,9 +103,9 @@ class WaveNet(torch.nn.Module):
         # define output layers
         if self.use_last_conv:
             self.last_conv = torch.nn.Sequential(
-                torch.nn.ReLU(inplace=True),
+                torch.nn.ReLU(inplace=False),
                 Conv1d1x1(skip_channels, skip_channels, bias=True),
-                torch.nn.ReLU(inplace=True),
+                torch.nn.ReLU(inplace=False),
                 Conv1d1x1(skip_channels, out_channels, bias=True),
             )
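The inplace=True to inplace=False switch in WaveNet looks defensive: an in-place ReLU can overwrite a tensor that autograd still needs, which only surfaces later as a RuntimeError during backward. A minimal, unrelated repro of that failure mode (not taken from the recipe):

import torch

x = torch.randn(4, requires_grad=True)
y = torch.sigmoid(x)   # sigmoid saves its output for the backward pass
torch.relu_(y)         # in-place ReLU bumps y's version counter
try:
    y.sum().backward()
except RuntimeError as err:
    print("in-place op broke autograd:", err)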