Add speaker embedding in text encoder

This commit is contained in:
Erwan 2024-02-12 17:07:30 +01:00
parent e5c04a216c
commit a4e4f8080a
6 changed files with 245 additions and 86 deletions

View File

@ -77,6 +77,7 @@ class VITSGenerator(torch.nn.Module):
noise_initial_mas: float = 0.01,
noise_scale_mas: float = 2e-6,
use_transformer_in_flows: bool = True,
use_spk_conditioned_txt_encoder: bool = False,
):
"""Initialize VITS generator module.
@ -138,6 +139,13 @@ class VITSGenerator(torch.nn.Module):
"""
super().__init__()
self.segment_size = segment_size
self.use_spk_conditioned_txt_encoder = use_spk_conditioned_txt_encoder
if self.use_spk_conditioned_txt_encoder and global_channels > 0:
self.text_encoder_global_channels = global_channels
else:
self.text_encoder_global_channels = 0
self.text_encoder = TextEncoder(
vocabs=vocabs,
d_model=hidden_channels,
@ -146,6 +154,7 @@ class VITSGenerator(torch.nn.Module):
cnn_module_kernel=text_encoder_cnn_module_kernel,
num_layers=text_encoder_blocks,
dropout=text_encoder_dropout_rate,
global_channels=self.text_encoder_global_channels,
)
self.decoder = HiFiGANGenerator(
in_channels=hidden_channels,
@ -286,9 +295,6 @@ class VITSGenerator(torch.nn.Module):
- Tensor: Posterior encoder projected scale (B, H, T_feats).
"""
# forward text encoder
x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths)
# calculate global conditioning
g = None
if self.spks is not None:
@ -309,6 +315,9 @@ class VITSGenerator(torch.nn.Module):
else:
g = g + g_
# forward text encoder
x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths, g)
# forward posterior encoder
z, m_q, logs_q, y_mask = self.posterior_encoder(feats, feats_lengths, g=g)
@ -376,7 +385,7 @@ class VITSGenerator(torch.nn.Module):
logw_ = torch.log(w + 1e-6) * x_mask
else:
logw_ = torch.log(w + 1e-6) * x_mask
logw = self.dp(x, x_mask, g=g)
logw = self.duration_predictor(x, x_mask, g=g)
dur_nll = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(x_mask)
# expand the length to match with the feature sequence

View File

@ -193,7 +193,7 @@ class FeatureMatchLoss(torch.nn.Module):
from generator's outputs.
feats (Union[List[List[Tensor]], List[Tensor]]): List of list of
discriminator outputs or list of discriminator outputs calcuated
from groundtruth..
from groundtruth.
Returns:
Tensor: Feature matching loss value.
@ -333,30 +333,3 @@ class KLDivergenceLossWithoutFlow(torch.nn.Module):
prior_norm = D.Normal(m_p, torch.exp(logs_p))
loss = D.kl_divergence(posterior_norm, prior_norm).mean()
return loss
class DurationDiscLoss(torch.nn.Module):
def forward(
self,
disc_real_outputs: List[torch.Tensor],
disc_generated_outputs: List[torch.Tensor],
):
loss = 0
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
dr = dr.float()
dg = dg.float()
r_loss = torch.mean((1 - dr) ** 2)
g_loss = torch.mean(dg**2)
loss += r_loss + g_loss
return loss
class DurationGenLoss(torch.nn.Module):
def forward(self, disc_outputs: List[torch.Tensor]):
loss = 0
for dg in disc_outputs:
dg = dg.float()
loss += torch.mean((1 - dg) ** 2)
return loss

View File

@ -49,6 +49,8 @@ class TextEncoder(torch.nn.Module):
cnn_module_kernel: int = 5,
num_layers: int = 6,
dropout: float = 0.1,
global_channels: int = 0,
cond_layer_idx: int = 2,
):
"""Initialize TextEncoder module.
@ -63,6 +65,8 @@ class TextEncoder(torch.nn.Module):
"""
super().__init__()
self.d_model = d_model
self.global_channels = global_channels
self.cond_layer_idx = cond_layer_idx
# define modules
self.emb = torch.nn.Embedding(vocabs, d_model)
@ -76,14 +80,14 @@ class TextEncoder(torch.nn.Module):
cnn_module_kernel=cnn_module_kernel,
num_layers=num_layers,
dropout=dropout,
global_channels=global_channels,
cond_layer_idx=cond_layer_idx,
)
self.proj = torch.nn.Conv1d(d_model, d_model * 2, 1)
def forward(
self,
x: torch.Tensor,
x_lengths: torch.Tensor,
self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Calculate forward propagation.
@ -107,7 +111,7 @@ class TextEncoder(torch.nn.Module):
pad_mask = make_pad_mask(x_lengths)
# encoder assume the channel last (B, T_text, embed_dim)
x = self.encoder(x, key_padding_mask=pad_mask)
x = self.encoder(x, key_padding_mask=pad_mask, g=g)
# convert the channel first (B, embed_dim, T_text)
x = x.transpose(1, 2)
@ -137,11 +141,18 @@ class Transformer(nn.Module):
cnn_module_kernel: int = 5,
num_layers: int = 6,
dropout: float = 0.1,
global_channels: int = 0,
cond_layer_idx: int = 2,
) -> None:
super().__init__()
self.num_layers = num_layers
self.d_model = d_model
self.global_channels = global_channels
speaker_embedder = None
if self.global_channels != 0:
speaker_embedder = nn.Linear(self.global_channels, self.d_model)
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
@ -151,12 +162,13 @@ class Transformer(nn.Module):
dim_feedforward=dim_feedforward,
cnn_module_kernel=cnn_module_kernel,
dropout=dropout,
speaker_embedder=speaker_embedder,
)
self.encoder = TransformerEncoder(encoder_layer, num_layers)
self.encoder = TransformerEncoder(encoder_layer, num_layers, cond_layer_idx)
self.after_norm = nn.LayerNorm(d_model)
def forward(
self, x: Tensor, key_padding_mask: Tensor
self, x: Tensor, key_padding_mask: Tensor, g: Optional[Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
@ -169,7 +181,9 @@ class Transformer(nn.Module):
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
x = self.encoder(x, pos_emb, key_padding_mask=key_padding_mask) # (T, N, C)
x = self.encoder(
x, pos_emb, key_padding_mask=key_padding_mask, g=g
) # (T, N, C)
x = self.after_norm(x)
@ -195,6 +209,7 @@ class TransformerEncoderLayer(nn.Module):
dim_feedforward: int,
cnn_module_kernel: int,
dropout: float = 0.1,
speaker_embedder: Optional[nn.Module] = None,
) -> None:
super(TransformerEncoderLayer, self).__init__()
@ -227,11 +242,15 @@ class TransformerEncoderLayer(nn.Module):
self.ff_scale = 0.5
self.dropout = nn.Dropout(dropout)
self.speaker_embedder = speaker_embedder
def forward(
self,
src: Tensor,
pos_emb: Tensor,
key_padding_mask: Optional[Tensor] = None,
g: Optional[Tensor] = None,
apply_speaker_embedding: bool = False,
) -> Tensor:
"""
Pass the input through the transformer encoder layer.
@ -241,6 +260,11 @@ class TransformerEncoderLayer(nn.Module):
pos_emb: Positional embedding tensor, of shape (1, seq_len*2-1, pos_dim).
key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len)
"""
if g is not None and apply_speaker_embedding:
g = self.speaker_embedder(g.transpose(1, 2))
g = g.transpose(1, 2)
src = src + g # * src_mask
# macaron style feed-forward module
src = src + self.ff_scale * self.dropout(
self.feed_forward_macaron(self.norm_ff_macaron(src))
@ -273,19 +297,23 @@ class TransformerEncoder(nn.Module):
num_layers: the number of sub-encoder-layers in the encoder.
"""
def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
def __init__(
self, encoder_layer: nn.Module, num_layers: int, cond_layer_idx: int = 0
) -> None:
super().__init__()
self.layers = nn.ModuleList(
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
)
self.num_layers = num_layers
self.cond_layer_idx = cond_layer_idx
def forward(
self,
src: Tensor,
pos_emb: Tensor,
key_padding_mask: Optional[Tensor] = None,
g: Optional[Tensor] = None,
) -> Tensor:
r"""Pass the input through the encoder layers in turn.
@ -295,12 +323,17 @@ class TransformerEncoder(nn.Module):
key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len)
"""
output = src
for layer_index, mod in enumerate(self.layers):
apply_speaker_embedding = False
if layer_index == self.cond_layer_idx:
apply_speaker_embedding = True
output = mod(
output,
pos_emb,
key_padding_mask=key_padding_mask,
g=g,
apply_speaker_embedding=apply_speaker_embedding,
)
return output

View File

@ -375,8 +375,10 @@ def train_one_epoch(
params=params,
optimizer_g=optimizer_g,
optimizer_d=optimizer_d,
optimizer_dur=optimizer_dur,
scheduler_g=scheduler_g,
scheduler_d=scheduler_d,
scheduler_dur=scheduler_dur,
sampler=train_dl.sampler,
scaler=scaler,
rank=0,
@ -393,15 +395,42 @@ def train_one_epoch(
loss_info = MetricsTracker()
loss_info["samples"] = batch_size
if model.module.generator.use_noised_mas:
# MAS with Gaussian Noise
model.module.generator.noise_current_mas = max(
model.module.generator.noise_initial_mas
- model.module.generator.noise_scale_mas * params.batch_idx_train,
0.0,
)
if isinstance(model, DDP):
if model.module.generator.use_noised_mas:
# MAS with Gaussian Noise
model.module.generator.noise_current_mas = max(
model.module.generator.noise_initial_mas
- model.module.generator.noise_scale_mas * params.batch_idx_train,
0.0,
)
else:
if model.generator.use_noised_mas:
# MAS with Gaussian Noise
model.generator.noise_current_mas = max(
model.generator.noise_initial_mas
- model.generator.noise_scale_mas * params.batch_idx_train,
0.0,
)
try:
with autocast(enabled=params.use_fp16):
# forward duration discriminator
dur_loss, stats_dur = model(
text=tokens,
text_lengths=tokens_lens,
feats=features,
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
forward_type="duration_discriminator",
)
for k, v in stats_dur.items():
loss_info[k] = v * batch_size
optimizer_dur.zero_grad()
scaler.scale(dur_loss).backward()
scaler.step(optimizer_dur)
with autocast(enabled=params.use_fp16):
# forward discriminator
loss_d, dur_loss, stats_d = model(
@ -411,13 +440,9 @@ def train_one_epoch(
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
forward_generator=False,
forward_type="discriminator",
)
optimizer_dur.zero_grad()
scaler.scale(dur_loss).backward()
scaler.step(optimizer_dur)
for k, v in stats_d.items():
loss_info[k] = v * batch_size
# update discriminator
@ -434,7 +459,7 @@ def train_one_epoch(
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
forward_generator=True,
forward_type="generator",
return_sample=params.batch_idx_train % params.log_interval == 0,
)
for k, v in stats_g.items():
@ -603,15 +628,29 @@ def compute_validation_loss(
loss_info = MetricsTracker()
loss_info["samples"] = batch_size
# forward discriminator
loss_d, dur_loss, stats_d = model(
# forward duration discriminator
loss_dur, stats_dur = model(
text=tokens,
text_lengths=tokens_lens,
feats=features,
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
forward_generator=False,
forward_type="duration_discriminator",
)
assert loss_dur.requires_grad is False
for k, v in stats_dur.items():
loss_info[k] = v * batch_size
# forward discriminator
loss_d, stats_d = model(
text=tokens,
text_lengths=tokens_lens,
feats=features,
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
forward_type="discriminator",
)
assert loss_d.requires_grad is False
for k, v in stats_d.items():
@ -625,7 +664,7 @@ def compute_validation_loss(
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
forward_generator=True,
forward_type="generator",
)
assert loss_g.requires_grad is False
for k, v in stats_g.items():
@ -684,21 +723,33 @@ def scan_pessimistic_batches_for_oom(
batch, tokenizer, device
)
try:
# for discriminator
# for duration discriminator
with autocast(enabled=params.use_fp16):
loss_d, dur_loss, stats_d = model(
dur_loss, stats_dur = model(
text=tokens,
text_lengths=tokens_lens,
feats=features,
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
forward_generator=False,
forward_type="duration_discriminator",
)
optimizer_dur.zero_grad()
dur_loss.backward()
# for discriminator
with autocast(enabled=params.use_fp16):
loss_d, stats_d = model(
text=tokens,
text_lengths=tokens_lens,
feats=features,
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
forward_type="discriminator",
)
optimizer_d.zero_grad()
loss_d.backward()
# for generator
@ -710,7 +761,7 @@ def scan_pessimistic_batches_for_oom(
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
forward_generator=True,
forward_type="generator",
)
optimizer_g.zero_grad()
loss_g.backward()
@ -918,8 +969,10 @@ def run(rank, world_size, args):
model=model,
optimizer_g=optimizer_g,
optimizer_d=optimizer_d,
optimizer_dur=optimizer_dur,
scheduler_g=scheduler_g,
scheduler_d=scheduler_d,
scheduler_dur=scheduler_dur,
sampler=train_dl.sampler,
scaler=scaler,
rank=rank,

View File

@ -203,8 +203,10 @@ def save_checkpoint(
params: Optional[Dict[str, Any]] = None,
optimizer_g: Optional[Optimizer] = None,
optimizer_d: Optional[Optimizer] = None,
optimizer_dur: Optional[Optimizer] = None,
scheduler_g: Optional[LRSchedulerType] = None,
scheduler_d: Optional[LRSchedulerType] = None,
scheduler_dur: Optional[LRSchedulerType] = None,
scaler: Optional[GradScaler] = None,
sampler: Optional[CutSampler] = None,
rank: int = 0,
@ -251,7 +253,13 @@ def save_checkpoint(
"model": model.state_dict(),
"optimizer_g": optimizer_g.state_dict() if optimizer_g is not None else None,
"optimizer_d": optimizer_d.state_dict() if optimizer_d is not None else None,
"optimizer_dur": optimizer_dur.state_dict()
if optimizer_d is not None
else None,
"scheduler_g": scheduler_g.state_dict() if scheduler_g is not None else None,
"scheduler_dur": scheduler_dur.state_dict()
if scheduler_d is not None
else None,
"scheduler_d": scheduler_d.state_dict() if scheduler_d is not None else None,
"grad_scaler": scaler.state_dict() if scaler is not None else None,
"sampler": sampler.state_dict() if sampler is not None else None,

View File

@ -20,8 +20,6 @@ from hifigan import (
)
from loss import (
DiscriminatorAdversarialLoss,
DurationDiscLoss,
DurationGenLoss,
FeatureMatchLoss,
GeneratorAdversarialLoss,
KLDivergenceLoss,
@ -236,10 +234,6 @@ class VITS(nn.Module):
)
self.kl_loss = KLDivergenceLoss()
# Vits2 duration disc
self.dur_disc_loss = DurationDiscLoss()
self.dur_gen_loss = DurationGenLoss()
# coefficients
self.lambda_adv = lambda_adv
self.lambda_mel = lambda_mel
@ -273,7 +267,7 @@ class VITS(nn.Module):
sids: Optional[torch.Tensor] = None,
spembs: Optional[torch.Tensor] = None,
lids: Optional[torch.Tensor] = None,
forward_generator: bool = True,
forward_type: str = "generator",
) -> Tuple[torch.Tensor, Dict[str, Any]]:
"""Perform generator forward.
@ -287,13 +281,13 @@ class VITS(nn.Module):
sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
forward_generator (bool): Whether to forward generator.
forward_type (str): Which type of forward to do
Returns:
- loss (Tensor): Loss scalar tensor.
- stats (Dict[str, float]): Statistics to be monitored.
"""
if forward_generator:
if forward_type == "generator":
return self._forward_generator(
text=text,
text_lengths=text_lengths,
@ -306,7 +300,7 @@ class VITS(nn.Module):
spembs=spembs,
lids=lids,
)
else:
elif forward_type == "discriminator":
return self._forward_discrminator(
text=text,
text_lengths=text_lengths,
@ -318,6 +312,20 @@ class VITS(nn.Module):
spembs=spembs,
lids=lids,
)
elif forward_type == "duration_discriminator":
return self._forward_discrminator_duration(
text=text,
text_lengths=text_lengths,
feats=feats,
feats_lengths=feats_lengths,
speech=speech,
speech_lengths=speech_lengths,
sids=sids,
spembs=spembs,
lids=lids,
)
else:
raise Exception(f"Forward type {forward_type} does not exist")
def _forward_generator(
self,
@ -374,8 +382,6 @@ class VITS(nn.Module):
self._cache = outs
# parse outputs
# speech_hat_, dur_nll, _, start_idxs, _, z_mask, outs_ = outs
# _, z_p, m_p, logs_p, _, logs_q = outs_
(
speech_hat_,
dur_nll,
@ -412,7 +418,7 @@ class VITS(nn.Module):
feat_match_loss = self.feat_match_loss(p_hat, p)
y_dur_hat_r, y_dur_hat_g = self.dur_disc(hidden_x, x_mask, logw_, logw)
dur_gen_loss = self.dur_gen_loss(y_dur_hat_g)
dur_gen_loss = self.generator_adv_loss(y_dur_hat_g)
mel_loss = mel_loss * self.lambda_mel
kl_loss = kl_loss * self.lambda_kl
@ -514,8 +520,8 @@ class VITS(nn.Module):
start_idxs,
x_mask,
y_mask,
(z, z_p, m_p, logs_p, m_q, logs_q),
(hidden_x, logw, logw_),
_,
_,
) = outs
speech_ = get_segments(
@ -533,14 +539,6 @@ class VITS(nn.Module):
real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p)
loss = real_loss + fake_loss
# Duration Discriminator
y_dur_hat_r, y_dur_hat_g = self.dur_disc(
hidden_x.detach(), x_mask.detach(), logw_.detach(), logw.detach()
)
with autocast(enabled=False):
dur_loss = self.dur_disc_loss(y_dur_hat_r, y_dur_hat_g)
stats = dict(
discriminator_loss=loss.item(),
discriminator_real_loss=real_loss.item(),
@ -551,7 +549,92 @@ class VITS(nn.Module):
if reuse_cache or not self.training:
self._cache = None
return loss, dur_loss, stats
return loss, stats
def _forward_discrminator_duration(
self,
text: torch.Tensor,
text_lengths: torch.Tensor,
feats: torch.Tensor,
feats_lengths: torch.Tensor,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
sids: Optional[torch.Tensor] = None,
spembs: Optional[torch.Tensor] = None,
lids: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Dict[str, Any]]:
"""Perform discriminator forward.
Args:
text (Tensor): Text index tensor (B, T_text).
text_lengths (Tensor): Text length tensor (B,).
feats (Tensor): Feature tensor (B, T_feats, aux_channels).
feats_lengths (Tensor): Feature length tensor (B,).
speech (Tensor): Speech waveform tensor (B, T_wav).
speech_lengths (Tensor): Speech length tensor (B,).
sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
Returns:
* loss (Tensor): Loss scalar tensor.
* stats (Dict[str, float]): Statistics to be monitored.
"""
# setup
feats = feats.transpose(1, 2)
speech = speech.unsqueeze(1)
# calculate generator outputs
reuse_cache = True
if not self.cache_generator_outputs or self._cache is None:
reuse_cache = False
outs = self.generator(
text=text,
text_lengths=text_lengths,
feats=feats,
feats_lengths=feats_lengths,
sids=sids,
spembs=spembs,
lids=lids,
)
else:
outs = self._cache
# store cache
if self.cache_generator_outputs and not reuse_cache:
self._cache = outs
(
speech_hat_,
dur_nll,
attn,
start_idxs,
x_mask,
y_mask,
(z, z_p, m_p, logs_p, m_q, logs_q),
(hidden_x, logw, logw_),
) = outs
# Duration Discriminator
y_dur_hat_r, y_dur_hat_g = self.dur_disc(
hidden_x.detach(), x_mask.detach(), logw_.detach(), logw.detach()
)
with autocast(enabled=False):
real_dur_loss, fake_dur_loss = self.discriminator_adv_loss(
y_dur_hat_g, y_dur_hat_r
)
dur_loss = real_dur_loss + fake_dur_loss
stats = dict(
discriminator_dur_loss=dur_loss.item(),
)
# reset cache
if reuse_cache or not self.training:
self._cache = None
return dur_loss, stats
def inference(
self,