diff --git a/egs/ljspeech/TTS/vits2/generator.py b/egs/ljspeech/TTS/vits2/generator.py
index d437d21fd..17fd513a2 100644
--- a/egs/ljspeech/TTS/vits2/generator.py
+++ b/egs/ljspeech/TTS/vits2/generator.py
@@ -77,6 +77,7 @@ class VITSGenerator(torch.nn.Module):
         noise_initial_mas: float = 0.01,
         noise_scale_mas: float = 2e-6,
         use_transformer_in_flows: bool = True,
+        use_spk_conditioned_txt_encoder: bool = False,
     ):
         """Initialize VITS generator module.
 
@@ -138,6 +139,13 @@ class VITSGenerator(torch.nn.Module):
         """
         super().__init__()
         self.segment_size = segment_size
+        self.use_spk_conditioned_txt_encoder = use_spk_conditioned_txt_encoder
+
+        if self.use_spk_conditioned_txt_encoder and global_channels > 0:
+            self.text_encoder_global_channels = global_channels
+        else:
+            self.text_encoder_global_channels = 0
+
         self.text_encoder = TextEncoder(
             vocabs=vocabs,
             d_model=hidden_channels,
@@ -146,6 +154,7 @@ class VITSGenerator(torch.nn.Module):
             cnn_module_kernel=text_encoder_cnn_module_kernel,
             num_layers=text_encoder_blocks,
             dropout=text_encoder_dropout_rate,
+            global_channels=self.text_encoder_global_channels,
         )
         self.decoder = HiFiGANGenerator(
             in_channels=hidden_channels,
@@ -286,9 +295,6 @@ class VITSGenerator(torch.nn.Module):
                 - Tensor: Posterior encoder projected scale (B, H, T_feats).
 
         """
-        # forward text encoder
-        x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths)
-
         # calculate global conditioning
         g = None
         if self.spks is not None:
@@ -309,6 +315,9 @@ class VITSGenerator(torch.nn.Module):
             else:
                 g = g + g_
 
+        # forward text encoder
+        x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths, g)
+
         # forward posterior encoder
         z, m_q, logs_q, y_mask = self.posterior_encoder(feats, feats_lengths, g=g)
 
@@ -376,7 +385,7 @@ class VITSGenerator(torch.nn.Module):
                 logw_ = torch.log(w + 1e-6) * x_mask
         else:
             logw_ = torch.log(w + 1e-6) * x_mask
-        logw = self.dp(x, x_mask, g=g)
+        logw = self.duration_predictor(x, x_mask, g=g)
         dur_nll = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(x_mask)
 
         # expand the length to match with the feature sequence
diff --git a/egs/ljspeech/TTS/vits2/loss.py b/egs/ljspeech/TTS/vits2/loss.py
index 63e779a9a..653e06c0f 100644
--- a/egs/ljspeech/TTS/vits2/loss.py
+++ b/egs/ljspeech/TTS/vits2/loss.py
@@ -193,7 +193,7 @@ class FeatureMatchLoss(torch.nn.Module):
                 from generator's outputs.
             feats (Union[List[List[Tensor]], List[Tensor]]): List of list of
-                discriminator outputs or list of discriminator outputs calcuated
-                from groundtruth..
+                discriminator outputs or list of discriminator outputs calculated
+                from groundtruth.
 
         Returns:
             Tensor: Feature matching loss value.
@@ -333,30 +333,3 @@ class KLDivergenceLossWithoutFlow(torch.nn.Module):
         prior_norm = D.Normal(m_p, torch.exp(logs_p))
         loss = D.kl_divergence(posterior_norm, prior_norm).mean()
         return loss
-
-
-class DurationDiscLoss(torch.nn.Module):
-    def forward(
-        self,
-        disc_real_outputs: List[torch.Tensor],
-        disc_generated_outputs: List[torch.Tensor],
-    ):
-        loss = 0
-        for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
-            dr = dr.float()
-            dg = dg.float()
-            r_loss = torch.mean((1 - dr) ** 2)
-            g_loss = torch.mean(dg**2)
-            loss += r_loss + g_loss
-
-        return loss
-
-
-class DurationGenLoss(torch.nn.Module):
-    def forward(self, disc_outputs: List[torch.Tensor]):
-        loss = 0
-        for dg in disc_outputs:
-            dg = dg.float()
-            loss += torch.mean((1 - dg) ** 2)
-
-        return loss
diff --git a/egs/ljspeech/TTS/vits2/text_encoder.py b/egs/ljspeech/TTS/vits2/text_encoder.py
index fcbae7103..3f0469b6e 100644
--- a/egs/ljspeech/TTS/vits2/text_encoder.py
+++ b/egs/ljspeech/TTS/vits2/text_encoder.py
@@ -49,6 +49,8 @@ class TextEncoder(torch.nn.Module):
         cnn_module_kernel: int = 5,
         num_layers: int = 6,
         dropout: float = 0.1,
+        global_channels: int = 0,
+        cond_layer_idx: int = 2,
     ):
         """Initialize TextEncoder module.
 
@@ -63,6 +65,8 @@ class TextEncoder(torch.nn.Module):
         """
         super().__init__()
         self.d_model = d_model
+        self.global_channels = global_channels
+        self.cond_layer_idx = cond_layer_idx
 
         # define modules
         self.emb = torch.nn.Embedding(vocabs, d_model)
@@ -76,14 +80,14 @@ class TextEncoder(torch.nn.Module):
             cnn_module_kernel=cnn_module_kernel,
             num_layers=num_layers,
             dropout=dropout,
+            global_channels=global_channels,
+            cond_layer_idx=cond_layer_idx,
         )
 
         self.proj = torch.nn.Conv1d(d_model, d_model * 2, 1)
 
     def forward(
-        self,
-        x: torch.Tensor,
-        x_lengths: torch.Tensor,
+        self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """Calculate forward propagation.
 
@@ -107,7 +111,7 @@ class TextEncoder(torch.nn.Module):
         pad_mask = make_pad_mask(x_lengths)
 
         # encoder assume the channel last (B, T_text, embed_dim)
-        x = self.encoder(x, key_padding_mask=pad_mask)
+        x = self.encoder(x, key_padding_mask=pad_mask, g=g)
 
         # convert the channel first (B, embed_dim, T_text)
         x = x.transpose(1, 2)
@@ -137,11 +141,18 @@ class Transformer(nn.Module):
         cnn_module_kernel: int = 5,
         num_layers: int = 6,
         dropout: float = 0.1,
+        global_channels: int = 0,
+        cond_layer_idx: int = 2,
     ) -> None:
         super().__init__()
         self.num_layers = num_layers
         self.d_model = d_model
+        self.global_channels = global_channels
+
+        speaker_embedder = None
+        if self.global_channels != 0:
+            speaker_embedder = nn.Linear(self.global_channels, self.d_model)
 
         self.encoder_pos = RelPositionalEncoding(d_model, dropout)
 
@@ -151,12 +162,13 @@ class Transformer(nn.Module):
             dim_feedforward=dim_feedforward,
             cnn_module_kernel=cnn_module_kernel,
             dropout=dropout,
+            speaker_embedder=speaker_embedder,
         )
-        self.encoder = TransformerEncoder(encoder_layer, num_layers)
+        self.encoder = TransformerEncoder(encoder_layer, num_layers, cond_layer_idx)
         self.after_norm = nn.LayerNorm(d_model)
 
     def forward(
-        self, x: Tensor, key_padding_mask: Tensor
+        self, x: Tensor, key_padding_mask: Tensor, g: Optional[Tensor] = None
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -169,7 +181,9 @@ class Transformer(nn.Module):
         x, pos_emb = self.encoder_pos(x)
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
 
-        x = self.encoder(x, pos_emb, key_padding_mask=key_padding_mask)  # (T, N, C)
+        x = self.encoder(
+            x, pos_emb, key_padding_mask=key_padding_mask, g=g
+        )  # (T, N, C)
 
         x = self.after_norm(x)
 
@@ -195,6 +209,7 @@ class TransformerEncoderLayer(nn.Module):
         dim_feedforward: int,
         cnn_module_kernel: int,
         dropout: float = 0.1,
+        speaker_embedder: Optional[nn.Module] = None,
     ) -> None:
         super(TransformerEncoderLayer, self).__init__()
 
@@ -227,11 +242,15 @@ class TransformerEncoderLayer(nn.Module):
         self.ff_scale = 0.5
 
         self.dropout = nn.Dropout(dropout)
 
+        self.speaker_embedder = speaker_embedder
+
     def forward(
         self,
         src: Tensor,
         pos_emb: Tensor,
         key_padding_mask: Optional[Tensor] = None,
+        g: Optional[Tensor] = None,
+        apply_speaker_embedding: bool = False,
     ) -> Tensor:
         """
         Pass the input through the transformer encoder layer.
@@ -241,6 +260,11 @@ class TransformerEncoderLayer(nn.Module):
             pos_emb: Positional embedding tensor, of shape (1, seq_len*2-1, pos_dim).
             key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len)
         """
+        if g is not None and apply_speaker_embedding:
+            g = self.speaker_embedder(g.transpose(1, 2))  # (B, 1, d_model)
+            g = g.transpose(0, 1)  # (1, B, d_model), broadcastable with src (T, N, C)
+            src = src + g  # * src_mask
+
         # macaron style feed-forward module
         src = src + self.ff_scale * self.dropout(
             self.feed_forward_macaron(self.norm_ff_macaron(src))
@@ -273,19 +297,23 @@ class TransformerEncoder(nn.Module):
         num_layers: the number of sub-encoder-layers in the encoder.
     """
 
-    def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
+    def __init__(
+        self, encoder_layer: nn.Module, num_layers: int, cond_layer_idx: int = 0
+    ) -> None:
         super().__init__()
         self.layers = nn.ModuleList(
             [copy.deepcopy(encoder_layer) for i in range(num_layers)]
        )
         self.num_layers = num_layers
+        self.cond_layer_idx = cond_layer_idx
 
     def forward(
         self,
         src: Tensor,
         pos_emb: Tensor,
         key_padding_mask: Optional[Tensor] = None,
+        g: Optional[Tensor] = None,
     ) -> Tensor:
         r"""Pass the input through the encoder layers in turn.
 
@@ -295,12 +323,17 @@ class TransformerEncoder(nn.Module):
             key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len)
         """
         output = src
-
         for layer_index, mod in enumerate(self.layers):
+            apply_speaker_embedding = False
+            if layer_index == self.cond_layer_idx:
+                apply_speaker_embedding = True
+
             output = mod(
                 output,
                 pos_emb,
                 key_padding_mask=key_padding_mask,
+                g=g,
+                apply_speaker_embedding=apply_speaker_embedding,
             )
 
         return output
diff --git a/egs/ljspeech/TTS/vits2/train.py b/egs/ljspeech/TTS/vits2/train.py
index cb5c4f952..743720084 100755
--- a/egs/ljspeech/TTS/vits2/train.py
+++ b/egs/ljspeech/TTS/vits2/train.py
@@ -375,8 +375,10 @@ def train_one_epoch(
                 params=params,
                 optimizer_g=optimizer_g,
                 optimizer_d=optimizer_d,
+                optimizer_dur=optimizer_dur,
                 scheduler_g=scheduler_g,
                 scheduler_d=scheduler_d,
+                scheduler_dur=scheduler_dur,
                 sampler=train_dl.sampler,
                 scaler=scaler,
                 rank=0,
@@ -393,15 +395,42 @@ def train_one_epoch(
         loss_info = MetricsTracker()
         loss_info["samples"] = batch_size
 
-        if model.module.generator.use_noised_mas:
-            # MAS with Gaussian Noise
-            model.module.generator.noise_current_mas = max(
-                model.module.generator.noise_initial_mas
-                - model.module.generator.noise_scale_mas * params.batch_idx_train,
-                0.0,
-            )
+        if isinstance(model, DDP):
+            if model.module.generator.use_noised_mas:
+                # MAS with Gaussian Noise
+                model.module.generator.noise_current_mas = max(
+                    model.module.generator.noise_initial_mas
+                    - model.module.generator.noise_scale_mas * params.batch_idx_train,
+                    0.0,
+                )
+        else:
+            if model.generator.use_noised_mas:
+                # MAS with Gaussian Noise
+                model.generator.noise_current_mas = max(
+                    model.generator.noise_initial_mas
+                    - model.generator.noise_scale_mas * params.batch_idx_train,
+                    0.0,
+                )
 
         try:
+            with autocast(enabled=params.use_fp16):
+                # forward duration discriminator
+                dur_loss, stats_dur = model(
+                    text=tokens,
+                    text_lengths=tokens_lens,
+                    feats=features,
+                    feats_lengths=features_lens,
+                    speech=audio,
+                    speech_lengths=audio_lens,
+                    forward_type="duration_discriminator",
+                )
+            for k, v in stats_dur.items():
+                loss_info[k] = v * batch_size
+
+            optimizer_dur.zero_grad()
+            scaler.scale(dur_loss).backward()
+            scaler.step(optimizer_dur)
+
             with autocast(enabled=params.use_fp16):
                 # forward discriminator
                 loss_d, dur_loss, stats_d = model(
@@ -411,13 +440,9 @@ def train_one_epoch(
                     feats_lengths=features_lens,
                     speech=audio,
                     speech_lengths=audio_lens,
-                    forward_generator=False,
+                    forward_type="discriminator",
                 )
 
-            optimizer_dur.zero_grad()
-            scaler.scale(dur_loss).backward()
-            scaler.step(optimizer_dur)
-
             for k, v in stats_d.items():
                 loss_info[k] = v * batch_size
             # update discriminator
@@ -434,7 +459,7 @@ def train_one_epoch(
                     feats_lengths=features_lens,
                     speech=audio,
                     speech_lengths=audio_lens,
-                    forward_generator=True,
+                    forward_type="generator",
                     return_sample=params.batch_idx_train % params.log_interval == 0,
                 )
             for k, v in stats_g.items():
@@ -603,15 +628,29 @@ def compute_validation_loss(
             loss_info = MetricsTracker()
             loss_info["samples"] = batch_size
 
-            # forward discriminator
-            loss_d, dur_loss, stats_d = model(
+            # forward duration discriminator
+            loss_dur, stats_dur = model(
                 text=tokens,
                 text_lengths=tokens_lens,
                 feats=features,
                 feats_lengths=features_lens,
                 speech=audio,
                 speech_lengths=audio_lens,
-                forward_generator=False,
+                forward_type="duration_discriminator",
+            )
+            assert loss_dur.requires_grad is False
+            for k, v in stats_dur.items():
+                loss_info[k] = v * batch_size
+
+            # forward discriminator
+            loss_d, stats_d = model(
+                text=tokens,
+                text_lengths=tokens_lens,
+                feats=features,
+                feats_lengths=features_lens,
+                speech=audio,
+                speech_lengths=audio_lens,
+                forward_type="discriminator",
             )
             assert loss_d.requires_grad is False
             for k, v in stats_d.items():
@@ -625,7 +664,7 @@ def compute_validation_loss(
                 feats_lengths=features_lens,
                 speech=audio,
                 speech_lengths=audio_lens,
-                forward_generator=True,
+                forward_type="generator",
             )
             assert loss_g.requires_grad is False
             for k, v in stats_g.items():
@@ -684,21 +723,33 @@ def scan_pessimistic_batches_for_oom(
             batch, tokenizer, device
         )
        try:
-            # for discriminator
+            # for duration discriminator
            with autocast(enabled=params.use_fp16):
-                loss_d, dur_loss, stats_d = model(
+                dur_loss, stats_dur = model(
                    text=tokens,
                    text_lengths=tokens_lens,
                    feats=features,
                    feats_lengths=features_lens,
                    speech=audio,
                    speech_lengths=audio_lens,
-                    forward_generator=False,
+                    forward_type="duration_discriminator",
                )
            optimizer_dur.zero_grad()
            dur_loss.backward()
 
+            # for discriminator
+            with autocast(enabled=params.use_fp16):
+                loss_d, stats_d = model(
+                    text=tokens,
+                    text_lengths=tokens_lens,
+                    feats=features,
+                    feats_lengths=features_lens,
+                    speech=audio,
+                    speech_lengths=audio_lens,
+                    forward_type="discriminator",
+                )
+
            optimizer_d.zero_grad()
            loss_d.backward()
 
            # for generator
@@ -710,7 +761,7 @@ def scan_pessimistic_batches_for_oom(
                    feats_lengths=features_lens,
                    speech=audio,
                    speech_lengths=audio_lens,
-                    forward_generator=True,
+                    forward_type="generator",
                )
            optimizer_g.zero_grad()
            loss_g.backward()
@@ -918,8 +969,10 @@ def run(rank, world_size, args):
             model=model,
             optimizer_g=optimizer_g,
             optimizer_d=optimizer_d,
+            optimizer_dur=optimizer_dur,
             scheduler_g=scheduler_g,
             scheduler_d=scheduler_d,
+            scheduler_dur=scheduler_dur,
             sampler=train_dl.sampler,
             scaler=scaler,
             rank=rank,
diff --git a/egs/ljspeech/TTS/vits2/utils.py b/egs/ljspeech/TTS/vits2/utils.py
index 6a067f596..eae3a47ad 100644
--- a/egs/ljspeech/TTS/vits2/utils.py
+++ b/egs/ljspeech/TTS/vits2/utils.py
@@ -203,8 +203,10 @@ def save_checkpoint(
     params: Optional[Dict[str, Any]] = None,
     optimizer_g: Optional[Optimizer] = None,
     optimizer_d: Optional[Optimizer] = None,
+    optimizer_dur: Optional[Optimizer] = None,
     scheduler_g: Optional[LRSchedulerType] = None,
     scheduler_d: Optional[LRSchedulerType] = None,
+    scheduler_dur: Optional[LRSchedulerType] = None,
     scaler: Optional[GradScaler] = None,
     sampler: Optional[CutSampler] = None,
     rank: int = 0,
@@ -251,7 +253,13 @@ def save_checkpoint(
         "model": model.state_dict(),
         "optimizer_g": optimizer_g.state_dict() if optimizer_g is not None else None,
         "optimizer_d": optimizer_d.state_dict() if optimizer_d is not None else None,
+        "optimizer_dur": optimizer_dur.state_dict()
+        if optimizer_dur is not None
+        else None,
         "scheduler_g": scheduler_g.state_dict() if scheduler_g is not None else None,
+        "scheduler_dur": scheduler_dur.state_dict()
+        if scheduler_dur is not None
+        else None,
         "scheduler_d": scheduler_d.state_dict() if scheduler_d is not None else None,
         "grad_scaler": scaler.state_dict() if scaler is not None else None,
         "sampler": sampler.state_dict() if sampler is not None else None,
diff --git a/egs/ljspeech/TTS/vits2/vits.py b/egs/ljspeech/TTS/vits2/vits.py
index fa7b529c8..14b45ffba 100644
--- a/egs/ljspeech/TTS/vits2/vits.py
+++ b/egs/ljspeech/TTS/vits2/vits.py
@@ -20,8 +20,6 @@ from hifigan import (
 )
 from loss import (
     DiscriminatorAdversarialLoss,
-    DurationDiscLoss,
-    DurationGenLoss,
     FeatureMatchLoss,
     GeneratorAdversarialLoss,
     KLDivergenceLoss,
@@ -236,10 +234,6 @@ class VITS(nn.Module):
         )
         self.kl_loss = KLDivergenceLoss()
 
-        # Vits2 duration disc
-        self.dur_disc_loss = DurationDiscLoss()
-        self.dur_gen_loss = DurationGenLoss()
-
         # coefficients
         self.lambda_adv = lambda_adv
         self.lambda_mel = lambda_mel
@@ -273,7 +267,7 @@ class VITS(nn.Module):
         sids: Optional[torch.Tensor] = None,
         spembs: Optional[torch.Tensor] = None,
         lids: Optional[torch.Tensor] = None,
-        forward_generator: bool = True,
+        forward_type: str = "generator",
     ) -> Tuple[torch.Tensor, Dict[str, Any]]:
         """Perform generator forward.
 
@@ -287,13 +281,13 @@ class VITS(nn.Module):
             sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
             spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
             lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
-            forward_generator (bool): Whether to forward generator.
+            forward_type (str): Which forward pass to run; one of "generator", "discriminator", or "duration_discriminator".
 
         Returns:
             - loss (Tensor): Loss scalar tensor.
             - stats (Dict[str, float]): Statistics to be monitored.
         """
-        if forward_generator:
+        if forward_type == "generator":
             return self._forward_generator(
                 text=text,
                 text_lengths=text_lengths,
@@ -306,7 +300,7 @@ class VITS(nn.Module):
                 spembs=spembs,
                 lids=lids,
             )
-        else:
+        elif forward_type == "discriminator":
             return self._forward_discrminator(
                 text=text,
                 text_lengths=text_lengths,
@@ -318,6 +312,20 @@ class VITS(nn.Module):
                 spembs=spembs,
                 lids=lids,
             )
+        elif forward_type == "duration_discriminator":
+            return self._forward_discrminator_duration(
+                text=text,
+                text_lengths=text_lengths,
+                feats=feats,
+                feats_lengths=feats_lengths,
+                speech=speech,
+                speech_lengths=speech_lengths,
+                sids=sids,
+                spembs=spembs,
+                lids=lids,
+            )
+        else:
+            raise ValueError(f"Forward type {forward_type} does not exist")
 
     def _forward_generator(
         self,
@@ -374,8 +382,6 @@ class VITS(nn.Module):
             self._cache = outs
 
         # parse outputs
-        # speech_hat_, dur_nll, _, start_idxs, _, z_mask, outs_ = outs
-        # _, z_p, m_p, logs_p, _, logs_q = outs_
         (
             speech_hat_,
             dur_nll,
@@ -412,7 +418,7 @@ class VITS(nn.Module):
             feat_match_loss = self.feat_match_loss(p_hat, p)
 
             y_dur_hat_r, y_dur_hat_g = self.dur_disc(hidden_x, x_mask, logw_, logw)
-            dur_gen_loss = self.dur_gen_loss(y_dur_hat_g)
+            dur_gen_loss = self.generator_adv_loss(y_dur_hat_g)
 
             mel_loss = mel_loss * self.lambda_mel
             kl_loss = kl_loss * self.lambda_kl
@@ -514,8 +520,8 @@ class VITS(nn.Module):
             start_idxs,
             x_mask,
             y_mask,
-            (z, z_p, m_p, logs_p, m_q, logs_q),
-            (hidden_x, logw, logw_),
+            _,
+            _,
         ) = outs
 
         speech_ = get_segments(
@@ -533,14 +539,6 @@ class VITS(nn.Module):
             real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p)
             loss = real_loss + fake_loss
 
-        # Duration Discriminator
-        y_dur_hat_r, y_dur_hat_g = self.dur_disc(
-            hidden_x.detach(), x_mask.detach(), logw_.detach(), logw.detach()
-        )
-
-        with autocast(enabled=False):
-            dur_loss = self.dur_disc_loss(y_dur_hat_r, y_dur_hat_g)
-
         stats = dict(
             discriminator_loss=loss.item(),
             discriminator_real_loss=real_loss.item(),
@@ -551,7 +549,92 @@ class VITS(nn.Module):
         if reuse_cache or not self.training:
             self._cache = None
 
-        return loss, dur_loss, stats
+        return loss, stats
+
+    def _forward_discrminator_duration(
+        self,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        feats: torch.Tensor,
+        feats_lengths: torch.Tensor,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        sids: Optional[torch.Tensor] = None,
+        spembs: Optional[torch.Tensor] = None,
+        lids: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
+        """Perform duration discriminator forward.
+
+        Args:
+            text (Tensor): Text index tensor (B, T_text).
+            text_lengths (Tensor): Text length tensor (B,).
+            feats (Tensor): Feature tensor (B, T_feats, aux_channels).
+            feats_lengths (Tensor): Feature length tensor (B,).
+            speech (Tensor): Speech waveform tensor (B, T_wav).
+            speech_lengths (Tensor): Speech length tensor (B,).
+            sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
+            spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
+            lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
+
+        Returns:
+            * loss (Tensor): Loss scalar tensor.
+            * stats (Dict[str, float]): Statistics to be monitored.
+
+        """
+        # setup
+        feats = feats.transpose(1, 2)
+        speech = speech.unsqueeze(1)
+
+        # calculate generator outputs
+        reuse_cache = True
+        if not self.cache_generator_outputs or self._cache is None:
+            reuse_cache = False
+            outs = self.generator(
+                text=text,
+                text_lengths=text_lengths,
+                feats=feats,
+                feats_lengths=feats_lengths,
+                sids=sids,
+                spembs=spembs,
+                lids=lids,
+            )
+        else:
+            outs = self._cache
+
+        # store cache
+        if self.cache_generator_outputs and not reuse_cache:
+            self._cache = outs
+
+        (
+            speech_hat_,
+            dur_nll,
+            attn,
+            start_idxs,
+            x_mask,
+            y_mask,
+            (z, z_p, m_p, logs_p, m_q, logs_q),
+            (hidden_x, logw, logw_),
+        ) = outs
+
+        # Duration Discriminator
+        y_dur_hat_r, y_dur_hat_g = self.dur_disc(
+            hidden_x.detach(), x_mask.detach(), logw_.detach(), logw.detach()
+        )
+
+        with autocast(enabled=False):
+            real_dur_loss, fake_dur_loss = self.discriminator_adv_loss(
+                y_dur_hat_g, y_dur_hat_r
+            )
+            dur_loss = real_dur_loss + fake_dur_loss
+
+        stats = dict(
+            discriminator_dur_loss=dur_loss.item(),
+        )
+
+        # reset cache
+        if reuse_cache or not self.training:
+            self._cache = None
+
+        return dur_loss, stats
 
     def inference(
         self,
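For reviewers, here is a minimal sketch of the per-batch update order this patch moves to: the duration discriminator is stepped first, then the waveform discriminator, then the generator. The forward_type values mirror the new dispatch in vits.py; the helper name train_step, the batch keys, and the placement of scaler.update() are illustrative assumptions rather than code taken verbatim from train.py.

from torch.cuda.amp import autocast


def train_step(
    model, batch, optimizer_dur, optimizer_d, optimizer_g, scaler, use_fp16=True
):
    # Batch keys are illustrative; train.py prepares these tensors from the dataloader.
    inputs = dict(
        text=batch["tokens"],
        text_lengths=batch["tokens_lens"],
        feats=batch["features"],
        feats_lengths=batch["features_lens"],
        speech=batch["audio"],
        speech_lengths=batch["audio_lens"],
    )

    # 1) duration discriminator pass (new in this patch)
    with autocast(enabled=use_fp16):
        dur_loss, stats_dur = model(**inputs, forward_type="duration_discriminator")
    optimizer_dur.zero_grad()
    scaler.scale(dur_loss).backward()
    scaler.step(optimizer_dur)

    # 2) waveform discriminator pass
    with autocast(enabled=use_fp16):
        loss_d, stats_d = model(**inputs, forward_type="discriminator")
    optimizer_d.zero_grad()
    scaler.scale(loss_d).backward()
    scaler.step(optimizer_d)

    # 3) generator pass
    with autocast(enabled=use_fp16):
        loss_g, stats_g = model(**inputs, forward_type="generator")
    optimizer_g.zero_grad()
    scaler.scale(loss_g).backward()
    scaler.step(optimizer_g)

    # One scaler.update() per iteration, after all optimizer steps (assumed placement).
    scaler.update()
    return {**stats_dur, **stats_d, **stats_g}

compute_validation_loss runs the same three forward_type passes, only without gradients or optimizer steps.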