diff --git a/egs/ljspeech/TTS/vits2/duration_predictor.py b/egs/ljspeech/TTS/vits2/duration_predictor.py
index 1a8190014..14a03b7f8 100644
--- a/egs/ljspeech/TTS/vits2/duration_predictor.py
+++ b/egs/ljspeech/TTS/vits2/duration_predictor.py
@@ -20,6 +20,7 @@ from flow import (
     ElementwiseAffineFlow,
     FlipFlow,
     LogFlow,
+    Transpose,
 )


@@ -191,3 +192,68 @@ class StochasticDurationPredictor(torch.nn.Module):
         z0, z1 = z.split(1, 1)
         logw = z0
         return logw
+
+
+class DurationPredictor(torch.nn.Module):
+    def __init__(
+        self,
+        input_channels: int = 192,
+        output_channels: int = 192,
+        kernel_size: int = 3,
+        dropout_rate: float = 0.5,
+        global_channels: int = -1,
+        eps: float = 1e-5,
+    ):
+        super().__init__()
+
+        self.input_channels = input_channels
+        self.output_channels = output_channels
+        self.kernel_size = kernel_size
+        self.dropout_rate = dropout_rate
+        self.gin_channels = global_channels
+
+        self.dropout = torch.nn.Dropout(dropout_rate)
+        self.conv_1 = torch.nn.Conv1d(
+            input_channels, output_channels, kernel_size, padding=kernel_size // 2
+        )
+        self.norm_1 = torch.nn.Sequential(
+            Transpose(1, 2),
+            torch.nn.LayerNorm(
+                output_channels,
+                eps=eps,
+                elementwise_affine=True,
+            ),
+            Transpose(1, 2),
+        )
+        self.conv_2 = torch.nn.Conv1d(
+            output_channels, output_channels, kernel_size, padding=kernel_size // 2
+        )
+        self.norm_2 = torch.nn.Sequential(
+            Transpose(1, 2),
+            torch.nn.LayerNorm(
+                output_channels,
+                eps=eps,
+                elementwise_affine=True,
+            ),
+            Transpose(1, 2),
+        )
+        self.proj = torch.nn.Conv1d(output_channels, 1, 1)
+
+        if global_channels > 0:
+            self.cond = torch.nn.Conv1d(global_channels, input_channels, 1)
+
+    def forward(self, x, x_mask, g=None):
+        x = torch.detach(x)
+        if g is not None:
+            g = torch.detach(g)
+            x = x + self.cond(g)
+        x = self.conv_1(x * x_mask)
+        x = torch.relu(x)
+        x = self.norm_1(x)
+        x = self.dropout(x)
+        x = self.conv_2(x * x_mask)
+        x = torch.relu(x)
+        x = self.norm_2(x)
+        x = self.dropout(x)
+        x = self.proj(x * x_mask)
+        return x * x_mask
diff --git a/egs/ljspeech/TTS/vits2/generator.py b/egs/ljspeech/TTS/vits2/generator.py
index 15f5f5187..d437d21fd 100644
--- a/egs/ljspeech/TTS/vits2/generator.py
+++ b/egs/ljspeech/TTS/vits2/generator.py
@@ -16,7 +16,7 @@ from typing import List, Optional, Tuple
 import numpy as np
 import torch
 import torch.nn.functional as F
-from duration_predictor import StochasticDurationPredictor
+from duration_predictor import DurationPredictor, StochasticDurationPredictor
 from hifigan import HiFiGANGenerator
 from posterior_encoder import PosteriorEncoder
 from residual_coupling import ResidualAffineCouplingBlock
@@ -71,6 +71,8 @@ class VITSGenerator(torch.nn.Module):
         stochastic_duration_predictor_dropout_rate: float = 0.5,
         stochastic_duration_predictor_flows: int = 4,
         stochastic_duration_predictor_dds_conv_layers: int = 3,
+        duration_predictor_output_channels: int = 256,
+        use_stochastic_duration_predictor: bool = True,
         use_noised_mas: bool = True,
         noise_initial_mas: float = 0.01,
         noise_scale_mas: float = 2e-6,
@@ -184,14 +186,22 @@ class VITSGenerator(torch.nn.Module):
             use_transformer_in_flows=use_transformer_in_flows,
         )
-        # TODO(kan-bayashi): Add deterministic version as an option
-        self.duration_predictor = StochasticDurationPredictor(
-            channels=hidden_channels,
-            kernel_size=stochastic_duration_predictor_kernel_size,
-            dropout_rate=stochastic_duration_predictor_dropout_rate,
-            flows=stochastic_duration_predictor_flows,
-            dds_conv_layers=stochastic_duration_predictor_dds_conv_layers,
-            global_channels=global_channels,
-        )
+        if use_stochastic_duration_predictor:
+            self.duration_predictor = StochasticDurationPredictor(
+                channels=hidden_channels,
+                kernel_size=stochastic_duration_predictor_kernel_size,
+                dropout_rate=stochastic_duration_predictor_dropout_rate,
+                flows=stochastic_duration_predictor_flows,
+                dds_conv_layers=stochastic_duration_predictor_dds_conv_layers,
+                global_channels=global_channels,
+            )
+        else:
+            self.duration_predictor = DurationPredictor(
+                input_channels=hidden_channels,
+                output_channels=duration_predictor_output_channels,
+                kernel_size=stochastic_duration_predictor_kernel_size,
+                dropout_rate=stochastic_duration_predictor_dropout_rate,
+                global_channels=global_channels,
+            )

         self.upsample_factor = int(np.prod(decoder_upsample_scales))

@@ -200,6 +210,7 @@ class VITSGenerator(torch.nn.Module):
         self.noise_current_mas = noise_initial_mas
         self.noise_scale_mas = noise_scale_mas
         self.noise_initial_mas = noise_initial_mas
+        self.use_stochastic_duration_predictor = use_stochastic_duration_predictor

         self.spks = None
         if spks is not None and spks > 1:
@@ -354,8 +365,18 @@ class VITSGenerator(torch.nn.Module):

         # forward duration predictor
         w = attn.sum(2)  # (B, 1, T_text)
-        dur_nll = self.duration_predictor(x, x_mask, w=w, g=g)
-        dur_nll = dur_nll / torch.sum(x_mask)
+
+        if self.use_stochastic_duration_predictor:
+            dur_nll = self.duration_predictor(x, x_mask, w=w, g=g)
+            dur_nll = dur_nll / torch.sum(x_mask)
+            logw = self.duration_predictor(
+                x, x_mask, g=g, inverse=True, noise_scale=1.0
+            )
+            logw_ = torch.log(w + 1e-6) * x_mask
+        else:
+            logw_ = torch.log(w + 1e-6) * x_mask
+            logw = self.duration_predictor(x, x_mask, g=g)
+            dur_nll = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(x_mask)

         # expand the length to match with the feature sequence
         # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
@@ -381,6 +402,7 @@ class VITSGenerator(torch.nn.Module):
             x_mask,
             y_mask,
             (z, z_p, m_p, logs_p, m_q, logs_q),
+            (x, logw, logw_),
         )

     def inference(
diff --git a/egs/ljspeech/TTS/vits2/hifigan.py b/egs/ljspeech/TTS/vits2/hifigan.py
index 589ac30f6..cb02a1494 100644
--- a/egs/ljspeech/TTS/vits2/hifigan.py
+++ b/egs/ljspeech/TTS/vits2/hifigan.py
@@ -16,6 +16,7 @@ from typing import Any, Dict, List, Optional
 import numpy as np
 import torch
 import torch.nn.functional as F
+from flow import Transpose


 class HiFiGANGenerator(torch.nn.Module):
@@ -931,3 +932,136 @@ class HiFiGANMultiScaleMultiPeriodDiscriminator(torch.nn.Module):
         msd_outs = self.msd(x)
         mpd_outs = self.mpd(x)
         return msd_outs + mpd_outs
+
+
+class DurationDiscriminator(torch.nn.Module):  # vits2
+    def __init__(
+        self,
+        channels: int = 192,
+        hidden_channels: int = 192,
+        kernel_size: int = 3,
+        dropout_rate: float = 0.5,
+        eps: float = 1e-5,
+        global_channels: int = -1,
+    ):
+        super().__init__()
+
+        self.channels = channels
+        self.hidden_channels = hidden_channels
+        self.kernel_size = kernel_size
+        self.dropout_rate = dropout_rate
+        self.global_channels = global_channels
+
+        self.dropout = torch.nn.Dropout(dropout_rate)
+        self.conv_1 = torch.nn.Conv1d(
+            channels, hidden_channels, kernel_size, padding=kernel_size // 2
+        )
+        self.norm_1 = torch.nn.Sequential(
+            Transpose(1, 2),
+            torch.nn.LayerNorm(
+                hidden_channels,
+                eps=eps,
+                elementwise_affine=True,
+            ),
+            Transpose(1, 2),
+        )
+
+        self.conv_2 = torch.nn.Conv1d(
+            hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2
+        )
+
+        self.norm_2 = torch.nn.Sequential(
+            Transpose(1, 2),
+            torch.nn.LayerNorm(
+                hidden_channels,
+                eps=eps,
+                elementwise_affine=True,
+            ),
+            Transpose(1, 2),
+        )
+
+        self.dur_proj = torch.nn.Conv1d(1, hidden_channels, 1)
+
+        self.pre_out_conv_1 = torch.nn.Conv1d(
+            2 * hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2
+        )
+
+        self.pre_out_norm_1 = torch.nn.Sequential(
+            Transpose(1, 2),
+            torch.nn.LayerNorm(
+                hidden_channels,
+                eps=eps,
+                elementwise_affine=True,
+            ),
+            Transpose(1, 2),
+        )
+
+        self.pre_out_conv_2 = torch.nn.Conv1d(
+            hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2
+        )
+
+        self.pre_out_norm_2 = torch.nn.Sequential(
+            Transpose(1, 2),
+            torch.nn.LayerNorm(
+                hidden_channels,
+                eps=eps,
+                elementwise_affine=True,
+            ),
+            Transpose(1, 2),
+        )
+
+        if global_channels > 0:
+            self.cond_layer = torch.nn.Conv1d(global_channels, channels, 1)
+
+        self.output_layer = torch.nn.Sequential(
+            torch.nn.Linear(hidden_channels, 1), torch.nn.Sigmoid()
+        )
+
+    def forward_probability(
+        self, x: torch.Tensor, x_mask: torch.Tensor, dur: torch.Tensor
+    ):
+        dur = self.dur_proj(dur)
+        x = torch.cat([x, dur], dim=1)
+        x = self.pre_out_conv_1(x * x_mask)
+        x = torch.relu(x)
+        x = self.pre_out_norm_1(x)
+        x = self.dropout(x)
+        x = self.pre_out_conv_2(x * x_mask)
+        x = torch.relu(x)
+        x = self.pre_out_norm_2(x)
+        x = self.dropout(x)
+        x = x * x_mask
+        x = x.transpose(1, 2)
+        output_prob = self.output_layer(x)
+        return output_prob
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        x_mask: torch.Tensor,
+        dur_r: torch.Tensor,
+        dur_hat: torch.Tensor,
+        g: Optional[torch.Tensor] = None,
+    ):
+
+        x = torch.detach(x)
+        if g is not None:
+            g = torch.detach(g)
+            x = x + self.cond_layer(g)
+
+        x = self.conv_1(x * x_mask)
+        x = torch.relu(x)
+        x = self.norm_1(x)
+        x = self.dropout(x)
+
+        x = self.conv_2(x * x_mask)
+        x = torch.relu(x)
+        x = self.norm_2(x)
+        x = self.dropout(x)
+
+        output_probs = []
+        for dur in [dur_r, dur_hat]:
+            output_prob = self.forward_probability(x, x_mask, dur)
+            output_probs.append(output_prob)
+
+        return output_probs
diff --git a/egs/ljspeech/TTS/vits2/loss.py b/egs/ljspeech/TTS/vits2/loss.py
index 2f4dc9bc0..63e779a9a 100644
--- a/egs/ljspeech/TTS/vits2/loss.py
+++ b/egs/ljspeech/TTS/vits2/loss.py
@@ -333,3 +333,30 @@ class KLDivergenceLossWithoutFlow(torch.nn.Module):
         prior_norm = D.Normal(m_p, torch.exp(logs_p))
         loss = D.kl_divergence(posterior_norm, prior_norm).mean()
         return loss
+
+
+class DurationDiscLoss(torch.nn.Module):
+    def forward(
+        self,
+        disc_real_outputs: List[torch.Tensor],
+        disc_generated_outputs: List[torch.Tensor],
+    ):
+        loss = 0
+        for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+            dr = dr.float()
+            dg = dg.float()
+            r_loss = torch.mean((1 - dr) ** 2)
+            g_loss = torch.mean(dg**2)
+            loss += r_loss + g_loss
+
+        return loss
+
+
+class DurationGenLoss(torch.nn.Module):
+    def forward(self, disc_outputs: List[torch.Tensor]):
+        loss = 0
+        for dg in disc_outputs:
+            dg = dg.float()
+            loss += torch.mean((1 - dg) ** 2)
+
+        return loss
diff --git a/egs/ljspeech/TTS/vits2/train.py b/egs/ljspeech/TTS/vits2/train.py
index 8cdbc4623..cb5c4f952 100755
--- a/egs/ljspeech/TTS/vits2/train.py
+++ b/egs/ljspeech/TTS/vits2/train.py
@@ -314,8 +314,10 @@ def train_one_epoch(
     tokenizer: Tokenizer,
     optimizer_g: Optimizer,
     optimizer_d: Optimizer,
+    optimizer_dur: Optimizer,
     scheduler_g: LRSchedulerType,
     scheduler_d: LRSchedulerType,
+    scheduler_dur: LRSchedulerType,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
     scaler: GradScaler,
@@ -402,7 +404,7 @@ def train_one_epoch(
         try:
             with autocast(enabled=params.use_fp16):
                 # forward discriminator
-                loss_d, stats_d = model(
+                loss_d, dur_loss, stats_d = model(
                     text=tokens,
                     text_lengths=tokens_lens,
                     feats=features,
@@ -411,6 +413,11 @@ def train_one_epoch(
                     speech_lengths=audio_lens,
                     forward_generator=False,
                 )
+
+            optimizer_dur.zero_grad()
+            scaler.scale(dur_loss).backward()
+            scaler.step(optimizer_dur)
+
             for k, v in stats_d.items():
                 loss_info[k] = v * batch_size
             # update discriminator
@@ -597,7 +604,7 @@ def compute_validation_loss(
             loss_info["samples"] = batch_size

             # forward discriminator
-            loss_d, stats_d = model(
+            loss_d, dur_loss, stats_d = model(
                 text=tokens,
                 text_lengths=tokens_lens,
                 feats=features,
@@ -661,6 +668,7 @@ def scan_pessimistic_batches_for_oom(
     tokenizer: Tokenizer,
    optimizer_g: torch.optim.Optimizer,
     optimizer_d: torch.optim.Optimizer,
+    optimizer_dur: torch.optim.Optimizer,
     params: AttributeDict,
 ):
     from lhotse.dataset import find_pessimistic_batches
@@ -678,7 +686,7 @@ def scan_pessimistic_batches_for_oom(
         try:
             # for discriminator
             with autocast(enabled=params.use_fp16):
-                loss_d, stats_d = model(
+                loss_d, dur_loss, stats_d = model(
                     text=tokens,
                     text_lengths=tokens_lens,
                     feats=features,
@@ -687,6 +695,10 @@ def scan_pessimistic_batches_for_oom(
                     speech_lengths=audio_lens,
                     forward_generator=False,
                 )
+
+            optimizer_dur.zero_grad()
+            dur_loss.backward()
+
             optimizer_d.zero_grad()
             loss_d.backward()
             # for generator
@@ -760,12 +772,17 @@ def run(rank, world_size, args):
     model = get_model(params)
     generator = model.generator
     discriminator = model.discriminator
+    dur_disc = model.dur_disc

     num_param_g = sum([p.numel() for p in generator.parameters()])
     logging.info(f"Number of parameters in generator: {num_param_g}")
     num_param_d = sum([p.numel() for p in discriminator.parameters()])
     logging.info(f"Number of parameters in discriminator: {num_param_d}")
-    logging.info(f"Total number of parameters: {num_param_g + num_param_d}")
+    num_param_dur = sum([p.numel() for p in dur_disc.parameters()])
+    logging.info(f"Number of parameters in duration discriminator: {num_param_dur}")
+    logging.info(
+        f"Total number of parameters: {num_param_g + num_param_d + num_param_dur}"
+    )

     assert params.start_epoch > 0, params.start_epoch
     checkpoints = load_checkpoint_if_available(params=params, model=model)
@@ -781,9 +798,15 @@ def run(rank, world_size, args):
     optimizer_d = torch.optim.AdamW(
         discriminator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9
     )
+    optimizer_dur = torch.optim.AdamW(
+        dur_disc.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9
+    )

     scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875)
     scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999875)
+    scheduler_dur = torch.optim.lr_scheduler.ExponentialLR(
+        optimizer_dur, gamma=0.999875
+    )

     if checkpoints is not None:
         # load state_dict for optimizers
@@ -793,6 +816,9 @@ def run(rank, world_size, args):
         if "optimizer_d" in checkpoints:
             logging.info("Loading optimizer_d state dict")
             optimizer_d.load_state_dict(checkpoints["optimizer_d"])
+        if "optimizer_dur" in checkpoints:
+            logging.info("Loading optimizer_dur state dict")
+            optimizer_dur.load_state_dict(checkpoints["optimizer_dur"])

         # load state_dict for schedulers
         if "scheduler_g" in checkpoints:
@@ -801,6 +827,9 @@ def run(rank, world_size, args):
         if "scheduler_d" in checkpoints:
             logging.info("Loading scheduler_d state dict")
             scheduler_d.load_state_dict(checkpoints["scheduler_d"])
+        if "scheduler_dur" in checkpoints:
+            logging.info("Loading scheduler_dur state dict")
+            scheduler_dur.load_state_dict(checkpoints["scheduler_dur"])

     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
@@ -812,7 +841,6 @@ def run(rank, world_size, args):
         register_inf_check_hooks(model)

     ljspeech = LJSpeechTtsDataModule(args)
-
     train_cuts = ljspeech.train_cuts()

     def remove_short_and_long_utt(c: Cut):
@@ -840,6 +868,7 @@ def run(rank, world_size, args):
             tokenizer=tokenizer,
             optimizer_g=optimizer_g,
             optimizer_d=optimizer_d,
+            optimizer_dur=optimizer_dur,
             params=params,
         )

@@ -865,8 +894,10 @@ def run(rank, world_size, args):
             tokenizer=tokenizer,
             optimizer_g=optimizer_g,
             optimizer_d=optimizer_d,
+            optimizer_dur=optimizer_dur,
             scheduler_g=scheduler_g,
             scheduler_d=scheduler_d,
+            scheduler_dur=scheduler_dur,
             train_dl=train_dl,
             valid_dl=valid_dl,
             scaler=scaler,
@@ -905,6 +936,7 @@ def run(rank, world_size, args):
         # step per epoch
         scheduler_g.step()
         scheduler_d.step()
+        scheduler_dur.step()

     logging.info("Done!")
diff --git a/egs/ljspeech/TTS/vits2/vits.py b/egs/ljspeech/TTS/vits2/vits.py
index 02087d07c..fa7b529c8 100644
--- a/egs/ljspeech/TTS/vits2/vits.py
+++ b/egs/ljspeech/TTS/vits2/vits.py
@@ -11,6 +11,7 @@ import torch
 import torch.nn as nn
 from generator import VITSGenerator
 from hifigan import (
+    DurationDiscriminator,
     HiFiGANMultiPeriodDiscriminator,
     HiFiGANMultiScaleDiscriminator,
     HiFiGANMultiScaleMultiPeriodDiscriminator,
@@ -19,6 +20,8 @@ from hifigan import (
 )
 from loss import (
     DiscriminatorAdversarialLoss,
+    DurationDiscLoss,
+    DurationGenLoss,
     FeatureMatchLoss,
     GeneratorAdversarialLoss,
     KLDivergenceLoss,
@@ -87,6 +90,8 @@ class VITS(nn.Module):
            "stochastic_duration_predictor_dropout_rate": 0.5,
            "stochastic_duration_predictor_flows": 4,
            "stochastic_duration_predictor_dds_conv_layers": 3,
+           "duration_predictor_output_channels": 256,
+           "use_stochastic_duration_predictor": True,
            "use_noised_mas": True,
            "noise_initial_mas": 0.01,
            "noise_scale_mas": 2e-06,
@@ -130,6 +135,13 @@ class VITS(nn.Module):
                "use_weight_norm": True,
                "use_spectral_norm": False,
            },
+           "duration_discriminator_params": {
+               "channels": 192,
+               "hidden_channels": 192,
+               "kernel_size": 3,
+               "dropout_rate": 0.1,
+               "global_channels": -1,
+           },
        },
        # loss related
        generator_adv_loss_params: Dict[str, Any] = {
@@ -155,6 +167,7 @@ class VITS(nn.Module):
         lambda_feat_match: float = 2.0,
         lambda_dur: float = 1.0,
         lambda_kl: float = 1.0,
+        lambda_dur_gen: float = 1.0,
         cache_generator_outputs: bool = True,
     ):
         """Initialize VITS module.
@@ -194,6 +207,13 @@ class VITS(nn.Module):
         # where idim represents #vocabularies and odim represents
         # the input acoustic feature dimension.
         generator_params.update(vocabs=vocab_size, aux_channels=feature_dim)
+
+        # pop on a copy so the mutable default dict is not modified in place
+        discriminator_params = dict(discriminator_params)
+        self.dur_disc = DurationDiscriminator(
+            **discriminator_params.pop("duration_discriminator_params")
+        )
+
         self.generator = generator_class(
             **generator_params,
         )
@@ -216,12 +236,17 @@ class VITS(nn.Module):
         )
         self.kl_loss = KLDivergenceLoss()

+        # VITS2 duration discriminator losses
+        self.dur_disc_loss = DurationDiscLoss()
+        self.dur_gen_loss = DurationGenLoss()
+
         # coefficients
         self.lambda_adv = lambda_adv
         self.lambda_mel = lambda_mel
         self.lambda_kl = lambda_kl
         self.lambda_feat_match = lambda_feat_match
         self.lambda_dur = lambda_dur
+        self.lambda_dur_gen = lambda_dur_gen

         # cache
         self.cache_generator_outputs = cache_generator_outputs
@@ -349,8 +374,16 @@ class VITS(nn.Module):
             self._cache = outs

         # parse outputs
-        speech_hat_, dur_nll, _, start_idxs, _, z_mask, outs_ = outs
-        _, z_p, m_p, logs_p, _, logs_q = outs_
+        (
+            speech_hat_,
+            dur_nll,
+            attn,
+            start_idxs,
+            x_mask,
+            y_mask,
+            (z, z_p, m_p, logs_p, m_q, logs_q),
+            (hidden_x, logw, logw_),
+        ) = outs
         speech_ = get_segments(
             x=speech,
             start_idxs=start_idxs * self.generator.upsample_factor,
@@ -371,17 +404,29 @@ class VITS(nn.Module):
             mel_loss, (mel_hat_, mel_) = self.mel_loss(
                 speech_hat_, speech_, return_mel=True
             )
-            kl_loss = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask)
+            kl_loss = self.kl_loss(z_p, logs_q, m_p, logs_p, y_mask)
             dur_loss = torch.sum(dur_nll.float())
             adv_loss = self.generator_adv_loss(p_hat)
             feat_match_loss = self.feat_match_loss(p_hat, p)

+            y_dur_hat_r, y_dur_hat_g = self.dur_disc(hidden_x, x_mask, logw_, logw)
+            dur_gen_loss = self.dur_gen_loss(y_dur_hat_g)
+
             mel_loss = mel_loss * self.lambda_mel
             kl_loss = kl_loss * self.lambda_kl
             dur_loss = dur_loss * self.lambda_dur
             adv_loss = adv_loss * self.lambda_adv
             feat_match_loss = feat_match_loss * self.lambda_feat_match
-            loss = mel_loss + kl_loss + dur_loss + adv_loss + feat_match_loss
+            dur_gen_loss = dur_gen_loss * self.lambda_dur_gen
+
+            loss = (
+                mel_loss
+                + kl_loss
+                + dur_loss
+                + adv_loss
+                + feat_match_loss
+                + dur_gen_loss
+            )

         stats = dict(
             generator_loss=loss.item(),
@@ -390,6 +435,7 @@ class VITS(nn.Module):
             generator_dur_loss=dur_loss.item(),
             generator_adv_loss=adv_loss.item(),
             generator_feat_match_loss=feat_match_loss.item(),
+            generator_dur_gen_loss=dur_gen_loss.item(),
         )

         if return_sample:
@@ -459,8 +505,17 @@ class VITS(nn.Module):
         if self.cache_generator_outputs and not reuse_cache:
             self._cache = outs

-        # parse outputs
-        speech_hat_, _, _, start_idxs, *_ = outs
+        (
+            speech_hat_,
+            dur_nll,
+            attn,
+            start_idxs,
+            x_mask,
+            y_mask,
+            (z, z_p, m_p, logs_p, m_q, logs_q),
+            (hidden_x, logw, logw_),
+        ) = outs
+
         speech_ = get_segments(
             x=speech,
             start_idxs=start_idxs * self.generator.upsample_factor,
@@ -476,6 +531,14 @@ class VITS(nn.Module):
             real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p)
             loss = real_loss + fake_loss

+        # Duration Discriminator
+        y_dur_hat_r, y_dur_hat_g = self.dur_disc(
+            hidden_x.detach(), x_mask.detach(), logw_.detach(), logw.detach()
+        )
+
+        with autocast(enabled=False):
+            dur_loss = self.dur_disc_loss(y_dur_hat_r, y_dur_hat_g)
+
         stats = dict(
             discriminator_loss=loss.item(),
             discriminator_real_loss=real_loss.item(),
@@ -486,7 +549,7 @@ class VITS(nn.Module):
         if reuse_cache or not self.training:
             self._cache = None

-        return loss, stats
+        return loss, dur_loss, stats

     def inference(
         self,
diff --git a/egs/ljspeech/TTS/vits2/wavenet.py b/egs/ljspeech/TTS/vits2/wavenet.py
index 5db461d5c..98fd775f5 100644
--- a/egs/ljspeech/TTS/vits2/wavenet.py
+++ b/egs/ljspeech/TTS/vits2/wavenet.py
@@ -103,9 +103,9 @@ class WaveNet(torch.nn.Module):
         # define output layers
         if self.use_last_conv:
             self.last_conv = torch.nn.Sequential(
-                torch.nn.ReLU(inplace=True),
+                torch.nn.ReLU(inplace=False),
                 Conv1d1x1(skip_channels, skip_channels, bias=True),
-                torch.nn.ReLU(inplace=True),
+                torch.nn.ReLU(inplace=False),
                 Conv1d1x1(skip_channels, out_channels, bias=True),
             )
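
Reviewer note, not part of the patch: the new duration objectives are least-squares GAN losses on log-durations, plus an MSE duration loss for the deterministic predictor. Below is a minimal self-contained sketch of the arithmetic; the p_real/p_fake tensors are stand-ins for DurationDiscriminator outputs, and the shapes follow the diff (x_mask and logw/logw_ are (B, 1, T_text)):

    import torch

    B, T = 2, 17
    x_mask = torch.ones(B, 1, T)
    w = torch.randint(1, 10, (B, 1, T)).float()  # per-phone durations from MAS
    logw_ = torch.log(w + 1e-6) * x_mask         # "real" log-durations
    logw = torch.randn(B, 1, T) * x_mask         # predicted log-durations (stand-in)

    # deterministic duration loss, as in the non-stochastic branch of generator.py
    dur_nll = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(x_mask)

    # LSGAN terms, as in loss.py: DurationDiscLoss pushes scores for real
    # durations toward 1 and generated ones toward 0; DurationGenLoss pushes
    # scores for generated durations toward 1
    p_real = torch.sigmoid(torch.randn(B, T, 1))  # stand-in discriminator scores
    p_fake = torch.sigmoid(torch.randn(B, T, 1))
    disc_loss = torch.mean((1 - p_real) ** 2) + torch.mean(p_fake**2)
    gen_loss = torch.mean((1 - p_fake) ** 2)

In training, disc_loss corresponds to the dur_loss stepped by the separate optimizer_dur (discriminator pass, inputs detached), while gen_loss enters the generator objective scaled by lambda_dur_gen.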