Mirror of https://github.com/k2-fsa/icefall.git

Commit a4e4f8080a (parent e5c04a216c): Add speaker embedding in text encoder
@@ -77,6 +77,7 @@ class VITSGenerator(torch.nn.Module):
         noise_initial_mas: float = 0.01,
         noise_scale_mas: float = 2e-6,
         use_transformer_in_flows: bool = True,
+        use_spk_conditioned_txt_encoder: bool = False,
     ):
         """Initialize VITS generator module.

@@ -138,6 +139,13 @@ class VITSGenerator(torch.nn.Module):
         """
         super().__init__()
         self.segment_size = segment_size
+        self.use_spk_conditioned_txt_encoder = use_spk_conditioned_txt_encoder
+
+        if self.use_spk_conditioned_txt_encoder and global_channels > 0:
+            self.text_encoder_global_channels = global_channels
+        else:
+            self.text_encoder_global_channels = 0
+
         self.text_encoder = TextEncoder(
             vocabs=vocabs,
             d_model=hidden_channels,
@@ -146,6 +154,7 @@ class VITSGenerator(torch.nn.Module):
             cnn_module_kernel=text_encoder_cnn_module_kernel,
             num_layers=text_encoder_blocks,
             dropout=text_encoder_dropout_rate,
+            global_channels=self.text_encoder_global_channels,
         )
         self.decoder = HiFiGANGenerator(
             in_channels=hidden_channels,
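
Note: speaker conditioning of the text encoder is opt-in. The generator passes a non-zero `global_channels` to `TextEncoder` only when both the new flag and a positive `global_channels` are set; otherwise the encoder is built exactly as before. A minimal sketch of the gating, with hypothetical values that are not part of the diff:

# hypothetical construction values, for illustration only
global_channels = 256
use_spk_conditioned_txt_encoder = True

text_encoder_global_channels = (
    global_channels
    if use_spk_conditioned_txt_encoder and global_channels > 0
    else 0
)  # 0 leaves the text encoder unconditioned
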
@@ -286,9 +295,6 @@ class VITSGenerator(torch.nn.Module):
             - Tensor: Posterior encoder projected scale (B, H, T_feats).

         """
-        # forward text encoder
-        x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths)
-
         # calculate global conditioning
         g = None
         if self.spks is not None:
@@ -309,6 +315,9 @@ class VITSGenerator(torch.nn.Module):
             else:
                 g = g + g_

+        # forward text encoder
+        x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths, g)
+
         # forward posterior encoder
         z, m_q, logs_q, y_mask = self.posterior_encoder(feats, feats_lengths, g=g)

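
Note: moving the text-encoder call below the global-conditioning block means `g` (the per-utterance embedding of shape (B, global_channels, 1)) already exists when the encoder runs, so the prior and posterior branches see the same conditioning. The resulting order inside `VITSGenerator.forward` is, schematically:

# order sketch of VITSGenerator.forward after this change:
#   1) g = None; accumulate speaker-ID / x-vector / language embeddings into g
#   2) x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths, g)
#   3) z, m_q, logs_q, y_mask = self.posterior_encoder(feats, feats_lengths, g=g)
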
@@ -376,7 +385,7 @@ class VITSGenerator(torch.nn.Module):
             logw_ = torch.log(w + 1e-6) * x_mask
         else:
             logw_ = torch.log(w + 1e-6) * x_mask
-        logw = self.dp(x, x_mask, g=g)
+        logw = self.duration_predictor(x, x_mask, g=g)
         dur_nll = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(x_mask)

         # expand the length to match with the feature sequence
@@ -193,7 +193,7 @@ class FeatureMatchLoss(torch.nn.Module):
                 from generator's outputs.
             feats (Union[List[List[Tensor]], List[Tensor]]): List of list of
                 discriminator outputs or list of discriminator outputs calcuated
-                from groundtruth..
+                from groundtruth.

         Returns:
             Tensor: Feature matching loss value.
@@ -333,30 +333,3 @@ class KLDivergenceLossWithoutFlow(torch.nn.Module):
         prior_norm = D.Normal(m_p, torch.exp(logs_p))
         loss = D.kl_divergence(posterior_norm, prior_norm).mean()
         return loss
-
-
-class DurationDiscLoss(torch.nn.Module):
-    def forward(
-        self,
-        disc_real_outputs: List[torch.Tensor],
-        disc_generated_outputs: List[torch.Tensor],
-    ):
-        loss = 0
-        for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
-            dr = dr.float()
-            dg = dg.float()
-            r_loss = torch.mean((1 - dr) ** 2)
-            g_loss = torch.mean(dg**2)
-            loss += r_loss + g_loss
-
-        return loss
-
-
-class DurationGenLoss(torch.nn.Module):
-    def forward(self, disc_outputs: List[torch.Tensor]):
-        loss = 0
-        for dg in disc_outputs:
-            dg = dg.float()
-            loss += torch.mean((1 - dg) ** 2)
-
-        return loss
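
Note: the two deleted classes implemented the least-squares GAN objective for the duration discriminator. They are redundant because `DiscriminatorAdversarialLoss` and `GeneratorAdversarialLoss` (already imported in vits.py, and assuming their default "mse" loss type) compute the same terms, which is what the new duration forward passes reuse. A minimal standalone sketch of the removed LSGAN terms:

import torch

def lsgan_disc_loss(dr: torch.Tensor, dg: torch.Tensor) -> torch.Tensor:
    # real outputs are pushed toward 1, generated outputs toward 0
    return torch.mean((1 - dr) ** 2) + torch.mean(dg**2)

def lsgan_gen_loss(dg: torch.Tensor) -> torch.Tensor:
    # the generator tries to make the discriminator output 1
    return torch.mean((1 - dg) ** 2)
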
@@ -49,6 +49,8 @@ class TextEncoder(torch.nn.Module):
         cnn_module_kernel: int = 5,
         num_layers: int = 6,
         dropout: float = 0.1,
+        global_channels: int = 0,
+        cond_layer_idx: int = 2,
     ):
         """Initialize TextEncoder module.

@@ -63,6 +65,8 @@ class TextEncoder(torch.nn.Module):
         """
         super().__init__()
         self.d_model = d_model
+        self.global_channels = global_channels
+        self.cond_layer_idx = cond_layer_idx

         # define modules
         self.emb = torch.nn.Embedding(vocabs, d_model)
@@ -76,14 +80,14 @@ class TextEncoder(torch.nn.Module):
             cnn_module_kernel=cnn_module_kernel,
             num_layers=num_layers,
             dropout=dropout,
+            global_channels=global_channels,
+            cond_layer_idx=cond_layer_idx,
         )

         self.proj = torch.nn.Conv1d(d_model, d_model * 2, 1)

     def forward(
-        self,
-        x: torch.Tensor,
-        x_lengths: torch.Tensor,
+        self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """Calculate forward propagation.

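
Note: `g` is an optional third argument, so existing single-speaker call sites remain valid. A usage sketch, assuming a constructed `TextEncoder` instance and batch tensors (the names are illustrative, not from the diff):

# unconditioned call, identical to the previous behaviour
x, m_p, logs_p, x_mask = text_encoder(tokens, token_lengths)

# speaker-conditioned call; g has shape (B, global_channels, 1)
x, m_p, logs_p, x_mask = text_encoder(tokens, token_lengths, g)
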
@@ -107,7 +111,7 @@ class TextEncoder(torch.nn.Module):
         pad_mask = make_pad_mask(x_lengths)

         # encoder assume the channel last (B, T_text, embed_dim)
-        x = self.encoder(x, key_padding_mask=pad_mask)
+        x = self.encoder(x, key_padding_mask=pad_mask, g=g)

         # convert the channel first (B, embed_dim, T_text)
         x = x.transpose(1, 2)
@@ -137,11 +141,18 @@ class Transformer(nn.Module):
         cnn_module_kernel: int = 5,
         num_layers: int = 6,
         dropout: float = 0.1,
+        global_channels: int = 0,
+        cond_layer_idx: int = 2,
     ) -> None:
         super().__init__()

         self.num_layers = num_layers
         self.d_model = d_model
+        self.global_channels = global_channels
+
+        speaker_embedder = None
+        if self.global_channels != 0:
+            speaker_embedder = nn.Linear(self.global_channels, self.d_model)

         self.encoder_pos = RelPositionalEncoding(d_model, dropout)

@@ -151,12 +162,13 @@ class Transformer(nn.Module):
             dim_feedforward=dim_feedforward,
             cnn_module_kernel=cnn_module_kernel,
             dropout=dropout,
+            speaker_embedder=speaker_embedder,
         )
-        self.encoder = TransformerEncoder(encoder_layer, num_layers)
+        self.encoder = TransformerEncoder(encoder_layer, num_layers, cond_layer_idx)
         self.after_norm = nn.LayerNorm(d_model)

     def forward(
-        self, x: Tensor, key_padding_mask: Tensor
+        self, x: Tensor, key_padding_mask: Tensor, g: Optional[Tensor] = None
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -169,7 +181,9 @@ class Transformer(nn.Module):
         x, pos_emb = self.encoder_pos(x)
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

-        x = self.encoder(x, pos_emb, key_padding_mask=key_padding_mask)  # (T, N, C)
+        x = self.encoder(
+            x, pos_emb, key_padding_mask=key_padding_mask, g=g
+        )  # (T, N, C)

         x = self.after_norm(x)

@@ -195,6 +209,7 @@ class TransformerEncoderLayer(nn.Module):
         dim_feedforward: int,
         cnn_module_kernel: int,
         dropout: float = 0.1,
+        speaker_embedder: Optional[nn.Module] = None,
     ) -> None:
         super(TransformerEncoderLayer, self).__init__()

@@ -227,11 +242,15 @@ class TransformerEncoderLayer(nn.Module):
         self.ff_scale = 0.5
         self.dropout = nn.Dropout(dropout)

+        self.speaker_embedder = speaker_embedder
+
     def forward(
         self,
         src: Tensor,
         pos_emb: Tensor,
         key_padding_mask: Optional[Tensor] = None,
+        g: Optional[Tensor] = None,
+        apply_speaker_embedding: bool = False,
     ) -> Tensor:
         """
         Pass the input through the transformer encoder layer.
@@ -241,6 +260,11 @@ class TransformerEncoderLayer(nn.Module):
             pos_emb: Positional embedding tensor, of shape (1, seq_len*2-1, pos_dim).
             key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len)
         """
+        if g is not None and apply_speaker_embedding:
+            g = self.speaker_embedder(g.transpose(1, 2))
+            g = g.transpose(1, 2)
+            src = src + g  # * src_mask
+
         # macaron style feed-forward module
         src = src + self.ff_scale * self.dropout(
             self.feed_forward_macaron(self.norm_ff_macaron(src))
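
Note: the intent of this block is to project the global embedding to the model dimension once and add it to every time step of `src` at the conditioning layer; the committed code reaches the addition through a pair of transposes. A self-contained sketch of the intended computation, where the shapes (src: (T, B, d_model), g: (B, global_channels, 1)) and the sizes are assumptions for illustration:

import torch
import torch.nn as nn

T, B, d_model, global_channels = 50, 2, 192, 256
speaker_embedder = nn.Linear(global_channels, d_model)
src = torch.randn(T, B, d_model)
g = torch.randn(B, global_channels, 1)

g_proj = speaker_embedder(g.squeeze(-1))  # (B, d_model)
src = src + g_proj.unsqueeze(0)           # broadcast over the T axis
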
@@ -273,19 +297,23 @@ class TransformerEncoder(nn.Module):
         num_layers: the number of sub-encoder-layers in the encoder.
     """

-    def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
+    def __init__(
+        self, encoder_layer: nn.Module, num_layers: int, cond_layer_idx: int = 0
+    ) -> None:
         super().__init__()

         self.layers = nn.ModuleList(
             [copy.deepcopy(encoder_layer) for i in range(num_layers)]
         )
         self.num_layers = num_layers
+        self.cond_layer_idx = cond_layer_idx

     def forward(
         self,
         src: Tensor,
         pos_emb: Tensor,
         key_padding_mask: Optional[Tensor] = None,
+        g: Optional[Tensor] = None,
     ) -> Tensor:
         r"""Pass the input through the encoder layers in turn.

@@ -295,12 +323,17 @@ class TransformerEncoder(nn.Module):
             key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len)
         """
         output = src

         for layer_index, mod in enumerate(self.layers):
+            apply_speaker_embedding = False
+            if layer_index == self.cond_layer_idx:
+                apply_speaker_embedding = True
+
             output = mod(
                 output,
                 pos_emb,
                 key_padding_mask=key_padding_mask,
+                g=g,
+                apply_speaker_embedding=apply_speaker_embedding,
             )

         return output
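
Note: conditioning is applied at exactly one layer; the flag is true only for `layer_index == self.cond_layer_idx` (default 2 when constructed through `Transformer`), so the speaker vector is injected a few blocks into the encoder rather than at the input. A runnable check of which layers get the flag:

cond_layer_idx, num_layers = 2, 6
flags = [layer_index == cond_layer_idx for layer_index in range(num_layers)]
print(flags)  # [False, False, True, False, False, False]: only layer 2 is conditioned
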
@@ -375,8 +375,10 @@ def train_one_epoch(
                 params=params,
                 optimizer_g=optimizer_g,
                 optimizer_d=optimizer_d,
+                optimizer_dur=optimizer_dur,
                 scheduler_g=scheduler_g,
                 scheduler_d=scheduler_d,
+                scheduler_dur=scheduler_dur,
                 sampler=train_dl.sampler,
                 scaler=scaler,
                 rank=0,
@@ -393,15 +395,42 @@ def train_one_epoch(
         loss_info = MetricsTracker()
         loss_info["samples"] = batch_size

-        if model.module.generator.use_noised_mas:
-            # MAS with Gaussian Noise
-            model.module.generator.noise_current_mas = max(
-                model.module.generator.noise_initial_mas
-                - model.module.generator.noise_scale_mas * params.batch_idx_train,
-                0.0,
-            )
+        if isinstance(model, DDP):
+            if model.module.generator.use_noised_mas:
+                # MAS with Gaussian Noise
+                model.module.generator.noise_current_mas = max(
+                    model.module.generator.noise_initial_mas
+                    - model.module.generator.noise_scale_mas * params.batch_idx_train,
+                    0.0,
+                )
+        else:
+            if model.generator.use_noised_mas:
+                # MAS with Gaussian Noise
+                model.generator.noise_current_mas = max(
+                    model.generator.noise_initial_mas
+                    - model.generator.noise_scale_mas * params.batch_idx_train,
+                    0.0,
+                )

         try:
+            with autocast(enabled=params.use_fp16):
+                # forward duration discriminator
+                dur_loss, stats_dur = model(
+                    text=tokens,
+                    text_lengths=tokens_lens,
+                    feats=features,
+                    feats_lengths=features_lens,
+                    speech=audio,
+                    speech_lengths=audio_lens,
+                    forward_type="duration_discriminator",
+                )
+            for k, v in stats_dur.items():
+                loss_info[k] = v * batch_size
+
+            optimizer_dur.zero_grad()
+            scaler.scale(dur_loss).backward()
+            scaler.step(optimizer_dur)
+
             with autocast(enabled=params.use_fp16):
                 # forward discriminator
                 loss_d, dur_loss, stats_d = model(
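
Note: the noised-MAS schedule itself is unchanged, only made DDP-aware: noise_current = max(noise_initial - noise_scale * step, 0), a linear decay to zero. With the defaults from the generator above (0.01 initial, 2e-6 per step) the noise reaches zero at step 5000. A worked check:

noise_initial_mas, noise_scale_mas = 0.01, 2e-6
for step in (0, 2500, 5000, 10000):
    noise = max(noise_initial_mas - noise_scale_mas * step, 0.0)
    print(step, noise)  # 0.01, 0.005, 0.0, 0.0
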
@@ -411,13 +440,9 @@ def train_one_epoch(
                     feats_lengths=features_lens,
                     speech=audio,
                     speech_lengths=audio_lens,
-                    forward_generator=False,
+                    forward_type="discriminator",
                 )

-            optimizer_dur.zero_grad()
-            scaler.scale(dur_loss).backward()
-            scaler.step(optimizer_dur)
-
             for k, v in stats_d.items():
                 loss_info[k] = v * batch_size
             # update discriminator
@@ -434,7 +459,7 @@ def train_one_epoch(
                     feats_lengths=features_lens,
                     speech=audio,
                     speech_lengths=audio_lens,
-                    forward_generator=True,
+                    forward_type="generator",
                     return_sample=params.batch_idx_train % params.log_interval == 0,
                 )
             for k, v in stats_g.items():
@@ -603,15 +628,29 @@ def compute_validation_loss(
            loss_info = MetricsTracker()
            loss_info["samples"] = batch_size

-            # forward discriminator
-            loss_d, dur_loss, stats_d = model(
+            # forward duration discriminator
+            loss_dur, stats_dur = model(
                text=tokens,
                text_lengths=tokens_lens,
                feats=features,
                feats_lengths=features_lens,
                speech=audio,
                speech_lengths=audio_lens,
-                forward_generator=False,
+                forward_type="duration_discriminator",
+            )
+            assert loss_dur.requires_grad is False
+            for k, v in stats_dur.items():
+                loss_info[k] = v * batch_size
+
+            # forward discriminator
+            loss_d, stats_d = model(
+                text=tokens,
+                text_lengths=tokens_lens,
+                feats=features,
+                feats_lengths=features_lens,
+                speech=audio,
+                speech_lengths=audio_lens,
+                forward_type="discriminator",
            )
            assert loss_d.requires_grad is False
            for k, v in stats_d.items():
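
Note: validation now mirrors training with three independent passes, each returning its own loss and stats. A sketch of the pattern as a hypothetical helper (not part of the diff):

def validation_passes(model, batch_kwargs):
    # each pass selects a different branch of VITS.forward
    loss_dur, stats_dur = model(**batch_kwargs, forward_type="duration_discriminator")
    loss_d, stats_d = model(**batch_kwargs, forward_type="discriminator")
    loss_g, stats_g = model(**batch_kwargs, forward_type="generator")
    return (loss_dur, stats_dur), (loss_d, stats_d), (loss_g, stats_g)
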
@@ -625,7 +664,7 @@ def compute_validation_loss(
                feats_lengths=features_lens,
                speech=audio,
                speech_lengths=audio_lens,
-                forward_generator=True,
+                forward_type="generator",
            )
            assert loss_g.requires_grad is False
            for k, v in stats_g.items():
@@ -684,21 +723,33 @@ def scan_pessimistic_batches_for_oom(
            batch, tokenizer, device
        )
        try:
-            # for discriminator
+            # for duration discriminator
            with autocast(enabled=params.use_fp16):
-                loss_d, dur_loss, stats_d = model(
+                dur_loss, stats_dur = model(
                    text=tokens,
                    text_lengths=tokens_lens,
                    feats=features,
                    feats_lengths=features_lens,
                    speech=audio,
                    speech_lengths=audio_lens,
-                    forward_generator=False,
+                    forward_type="duration_discriminator",
                )

            optimizer_dur.zero_grad()
            dur_loss.backward()

+            # for discriminator
+            with autocast(enabled=params.use_fp16):
+                loss_d, stats_d = model(
+                    text=tokens,
+                    text_lengths=tokens_lens,
+                    feats=features,
+                    feats_lengths=features_lens,
+                    speech=audio,
+                    speech_lengths=audio_lens,
+                    forward_type="discriminator",
+                )
+
            optimizer_d.zero_grad()
            loss_d.backward()
            # for generator
@@ -710,7 +761,7 @@ def scan_pessimistic_batches_for_oom(
                    feats_lengths=features_lens,
                    speech=audio,
                    speech_lengths=audio_lens,
-                    forward_generator=True,
+                    forward_type="generator",
                )
            optimizer_g.zero_grad()
            loss_g.backward()
@@ -918,8 +969,10 @@ def run(rank, world_size, args):
            model=model,
            optimizer_g=optimizer_g,
            optimizer_d=optimizer_d,
+            optimizer_dur=optimizer_dur,
            scheduler_g=scheduler_g,
            scheduler_d=scheduler_d,
+            scheduler_dur=scheduler_dur,
            sampler=train_dl.sampler,
            scaler=scaler,
            rank=rank,
@@ -203,8 +203,10 @@ def save_checkpoint(
    params: Optional[Dict[str, Any]] = None,
    optimizer_g: Optional[Optimizer] = None,
    optimizer_d: Optional[Optimizer] = None,
+    optimizer_dur: Optional[Optimizer] = None,
    scheduler_g: Optional[LRSchedulerType] = None,
    scheduler_d: Optional[LRSchedulerType] = None,
+    scheduler_dur: Optional[LRSchedulerType] = None,
    scaler: Optional[GradScaler] = None,
    sampler: Optional[CutSampler] = None,
    rank: int = 0,
@@ -251,7 +253,13 @@ def save_checkpoint(
        "model": model.state_dict(),
        "optimizer_g": optimizer_g.state_dict() if optimizer_g is not None else None,
        "optimizer_d": optimizer_d.state_dict() if optimizer_d is not None else None,
+        "optimizer_dur": optimizer_dur.state_dict()
+        if optimizer_dur is not None
+        else None,
        "scheduler_g": scheduler_g.state_dict() if scheduler_g is not None else None,
+        "scheduler_dur": scheduler_dur.state_dict()
+        if scheduler_dur is not None
+        else None,
        "scheduler_d": scheduler_d.state_dict() if scheduler_d is not None else None,
        "grad_scaler": scaler.state_dict() if scaler is not None else None,
        "sampler": sampler.state_dict() if sampler is not None else None,
@@ -20,8 +20,6 @@ from hifigan import (
)
from loss import (
    DiscriminatorAdversarialLoss,
-    DurationDiscLoss,
-    DurationGenLoss,
    FeatureMatchLoss,
    GeneratorAdversarialLoss,
    KLDivergenceLoss,
@@ -236,10 +234,6 @@ class VITS(nn.Module):
        )
        self.kl_loss = KLDivergenceLoss()
-
-        # Vits2 duration disc
-        self.dur_disc_loss = DurationDiscLoss()
-        self.dur_gen_loss = DurationGenLoss()

        # coefficients
        self.lambda_adv = lambda_adv
        self.lambda_mel = lambda_mel
@@ -273,7 +267,7 @@ class VITS(nn.Module):
        sids: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
-        forward_generator: bool = True,
+        forward_type: str = "generator",
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """Perform generator forward.

@@ -287,13 +281,13 @@ class VITS(nn.Module):
            sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
            spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
            lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
-            forward_generator (bool): Whether to forward generator.
+            forward_type (str): Which forward pass to perform ("generator", "discriminator", or "duration_discriminator").

        Returns:
            - loss (Tensor): Loss scalar tensor.
            - stats (Dict[str, float]): Statistics to be monitored.
        """
-        if forward_generator:
+        if forward_type == "generator":
            return self._forward_generator(
                text=text,
                text_lengths=text_lengths,
@@ -306,7 +300,7 @@ class VITS(nn.Module):
                spembs=spembs,
                lids=lids,
            )
-        else:
+        elif forward_type == "discriminator":
            return self._forward_discrminator(
                text=text,
                text_lengths=text_lengths,
@@ -318,6 +312,20 @@ class VITS(nn.Module):
                spembs=spembs,
                lids=lids,
            )
+        elif forward_type == "duration_discriminator":
+            return self._forward_discrminator_duration(
+                text=text,
+                text_lengths=text_lengths,
+                feats=feats,
+                feats_lengths=feats_lengths,
+                speech=speech,
+                speech_lengths=speech_lengths,
+                sids=sids,
+                spembs=spembs,
+                lids=lids,
+            )
+        else:
+            raise Exception(f"Forward type {forward_type} does not exist")

    def _forward_generator(
        self,
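
Note: a boolean `forward_generator` could only select between two passes; the string `forward_type` accommodates the third (duration discriminator) pass and rejects unknown values. A small standalone sketch of the accepted values and the failure mode:

# the three values accepted by VITS.forward; anything else raises
FORWARD_TYPES = ("generator", "discriminator", "duration_discriminator")

def check_forward_type(forward_type: str) -> None:
    if forward_type not in FORWARD_TYPES:
        raise Exception(f"Forward type {forward_type} does not exist")
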
@@ -374,8 +382,6 @@ class VITS(nn.Module):
            self._cache = outs

        # parse outputs
-        # speech_hat_, dur_nll, _, start_idxs, _, z_mask, outs_ = outs
-        # _, z_p, m_p, logs_p, _, logs_q = outs_
        (
            speech_hat_,
            dur_nll,
@@ -412,7 +418,7 @@ class VITS(nn.Module):
            feat_match_loss = self.feat_match_loss(p_hat, p)

            y_dur_hat_r, y_dur_hat_g = self.dur_disc(hidden_x, x_mask, logw_, logw)
-            dur_gen_loss = self.dur_gen_loss(y_dur_hat_g)
+            dur_gen_loss = self.generator_adv_loss(y_dur_hat_g)

        mel_loss = mel_loss * self.lambda_mel
        kl_loss = kl_loss * self.lambda_kl
@@ -514,8 +520,8 @@ class VITS(nn.Module):
            start_idxs,
            x_mask,
            y_mask,
-            (z, z_p, m_p, logs_p, m_q, logs_q),
-            (hidden_x, logw, logw_),
+            _,
+            _,
        ) = outs

        speech_ = get_segments(
@@ -533,14 +539,6 @@ class VITS(nn.Module):
            real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p)
            loss = real_loss + fake_loss

-            # Duration Discriminator
-            y_dur_hat_r, y_dur_hat_g = self.dur_disc(
-                hidden_x.detach(), x_mask.detach(), logw_.detach(), logw.detach()
-            )
-
-            with autocast(enabled=False):
-                dur_loss = self.dur_disc_loss(y_dur_hat_r, y_dur_hat_g)
-
        stats = dict(
            discriminator_loss=loss.item(),
            discriminator_real_loss=real_loss.item(),
@@ -551,7 +549,92 @@ class VITS(nn.Module):
        if reuse_cache or not self.training:
            self._cache = None

-        return loss, dur_loss, stats
+        return loss, stats
+
+    def _forward_discrminator_duration(
+        self,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        feats: torch.Tensor,
+        feats_lengths: torch.Tensor,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        sids: Optional[torch.Tensor] = None,
+        spembs: Optional[torch.Tensor] = None,
+        lids: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
+        """Perform duration discriminator forward.
+
+        Args:
+            text (Tensor): Text index tensor (B, T_text).
+            text_lengths (Tensor): Text length tensor (B,).
+            feats (Tensor): Feature tensor (B, T_feats, aux_channels).
+            feats_lengths (Tensor): Feature length tensor (B,).
+            speech (Tensor): Speech waveform tensor (B, T_wav).
+            speech_lengths (Tensor): Speech length tensor (B,).
+            sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
+            spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
+            lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
+
+        Returns:
+            * loss (Tensor): Loss scalar tensor.
+            * stats (Dict[str, float]): Statistics to be monitored.
+        """
+        # setup
+        feats = feats.transpose(1, 2)
+        speech = speech.unsqueeze(1)
+
+        # calculate generator outputs
+        reuse_cache = True
+        if not self.cache_generator_outputs or self._cache is None:
+            reuse_cache = False
+            outs = self.generator(
+                text=text,
+                text_lengths=text_lengths,
+                feats=feats,
+                feats_lengths=feats_lengths,
+                sids=sids,
+                spembs=spembs,
+                lids=lids,
+            )
+        else:
+            outs = self._cache
+
+        # store cache
+        if self.cache_generator_outputs and not reuse_cache:
+            self._cache = outs
+
+        (
+            speech_hat_,
+            dur_nll,
+            attn,
+            start_idxs,
+            x_mask,
+            y_mask,
+            (z, z_p, m_p, logs_p, m_q, logs_q),
+            (hidden_x, logw, logw_),
+        ) = outs
+
+        # Duration Discriminator
+        y_dur_hat_r, y_dur_hat_g = self.dur_disc(
+            hidden_x.detach(), x_mask.detach(), logw_.detach(), logw.detach()
+        )
+
+        with autocast(enabled=False):
+            real_dur_loss, fake_dur_loss = self.discriminator_adv_loss(
+                y_dur_hat_g, y_dur_hat_r
+            )
+            dur_loss = real_dur_loss + fake_dur_loss
+
+        stats = dict(
+            discriminator_dur_loss=dur_loss.item(),
+        )
+
+        # reset cache
+        if reuse_cache or not self.training:
+            self._cache = None
+
+        return dur_loss, stats

    def inference(
        self,
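
Note: in the new duration pass every generator output fed to `self.dur_disc` is detached, so this loss updates only the duration discriminator; the generator is trained against it separately through `generator_adv_loss` in `_forward_generator`. A runnable illustration of why the detach matters:

import torch

gen_out = torch.randn(2, 3, requires_grad=True)  # stand-in for a generator output
disc_w = torch.randn(3, requires_grad=True)      # stand-in for a discriminator weight
loss = (gen_out.detach() * disc_w).sum()         # discriminator-side computation
loss.backward()
print(gen_out.grad, disc_w.grad is not None)     # None True: only the discriminator learns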