# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Sequence

import torch
import torch.nn as nn
from base_discriminators import DiscriminatorP, DiscriminatorS, DiscriminatorSTFT
from torch.nn import AvgPool1d


class MultiPeriodDiscriminator(nn.Module):
    """Ensemble of ``DiscriminatorP`` sub-discriminators with periods 2, 3, 5, 7 and 11.

    Every sub-discriminator is applied to both the real and the generated
    signal; see ``forward`` for the structure of the result.
    """

    def __init__(self):
        super().__init__()
        self.discriminators = nn.ModuleList(
            [
                DiscriminatorP(2),
                DiscriminatorP(3),
                DiscriminatorP(5),
                DiscriminatorP(7),
                DiscriminatorP(11),
            ]
        )

    def forward(self, y, y_hat):
        """Run every period discriminator on ``y`` and on ``y_hat``.

        Args:
            y: Real signal. Its expected layout is defined by ``DiscriminatorP``
                (not visible here).
            y_hat: Generated signal, same layout as ``y``.

        Returns:
            Tuple ``(y_d_rs, y_d_gs, fmap_rs, fmap_gs)`` of lists with one entry
            per sub-discriminator, in construction order. Each sub-discriminator
            returns a pair. The first element (discriminator output) goes to
            ``y_d_rs`` / ``y_d_gs``. The second element (presumably its feature
            maps) goes to ``fmap_rs`` / ``fmap_gs``. ``*_rs`` entries come from
            ``y`` and ``*_gs`` entries come from ``y_hat``.
        """
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for d in self.discriminators:
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)
        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class MultiScaleDiscriminator(nn.Module):
    """Three ``DiscriminatorS`` sub-discriminators applied at successively coarser scales.

    The first sub-discriminator sees the input unchanged. Each later one sees
    the previous scale's input passed through ``AvgPool1d(4, 2, padding=2)``,
    so the downsampling accumulates from scale to scale.
    """

    def __init__(self):
        super().__init__()
        self.discriminators = nn.ModuleList(
            [
                DiscriminatorS(),
                DiscriminatorS(),
                DiscriminatorS(),
            ]
        )
        # One pooling layer per transition between consecutive scales
        # (len(discriminators) - 1 of them).
        self.meanpools = nn.ModuleList(
            [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)]
        )

    def forward(self, y, y_hat):
        """Run every scale discriminator on progressively pooled ``y`` / ``y_hat``.

        Args:
            y: Real signal. It must be accepted by ``AvgPool1d``, which expects
                ``(N, C, L)`` or ``(C, L)`` input, and by ``DiscriminatorS``.
            y_hat: Generated signal, same layout as ``y``.

        Returns:
            Tuple ``(y_d_rs, y_d_gs, fmap_rs, fmap_gs)`` of lists with one entry
            per scale (finest first). Each list holds the first or second value
            returned by that scale's sub-discriminator, for ``y`` (``*_rs``) or
            for ``y_hat`` (``*_gs``).
        """
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            if i > 0:
                # Reassigning y / y_hat makes pooling cumulative across scales.
                pool = self.meanpools[i - 1]
                y = pool(y)
                y_hat = pool(y_hat)
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)
        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class MultiScaleSTFTDiscriminator(nn.Module):
    """Multi-Scale STFT (MS-STFT) discriminator.

    Builds one ``DiscriminatorSTFT`` per scale. Scale ``i`` is configured with
    ``n_ffts[i]``, ``hop_lengths[i]`` and ``win_lengths[i]``.

    Args:
        n_filters (int): Number of filters in convolutions.
        in_channels (int): Number of input channels. Default: 1
        out_channels (int): Number of output channels. Default: 1
        n_ffts (Sequence[int]): Size of FFT for each scale.
        hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale.
        win_lengths (Sequence[int]): Window size for each scale.
        **kwargs: Additional args forwarded unchanged to every ``DiscriminatorSTFT``.

    Raises:
        ValueError: If ``n_ffts``, ``hop_lengths`` and ``win_lengths`` do not
            all have the same length.
    """

    def __init__(
        self,
        n_filters: int,
        in_channels: int = 1,
        out_channels: int = 1,
        # Immutable tuple defaults: list defaults would be shared across calls.
        n_ffts: Sequence[int] = (1024, 2048, 512, 256, 128),
        hop_lengths: Sequence[int] = (256, 512, 128, 64, 32),
        win_lengths: Sequence[int] = (1024, 2048, 512, 256, 128),
        **kwargs,
    ):
        super().__init__()
        # Explicit check rather than ``assert`` so it is not stripped under ``python -O``.
        if not len(n_ffts) == len(hop_lengths) == len(win_lengths):
            raise ValueError(
                "n_ffts, hop_lengths and win_lengths must have the same length, "
                f"got {len(n_ffts)}, {len(hop_lengths)} and {len(win_lengths)}"
            )
        self.discriminators = nn.ModuleList(
            [
                DiscriminatorSTFT(
                    n_filters,
                    in_channels=in_channels,
                    out_channels=out_channels,
                    n_fft=n_fft,
                    win_length=win_length,
                    hop_length=hop_length,
                    **kwargs,
                )
                for n_fft, hop_length, win_length in zip(
                    n_ffts, hop_lengths, win_lengths
                )
            ]
        )
        self.num_discriminators = len(self.discriminators)

    def forward(self, x: torch.Tensor):
        """Apply every STFT discriminator to the same input ``x``.

        Returns:
            Tuple ``(logits, fmaps)`` of lists with one entry per scale, in the
            order given by ``n_ffts``. These hold, respectively, the first and
            second values returned by each sub-discriminator.
        """
        logits = []
        fmaps = []
        for disc in self.discriminators:
            logit, fmap = disc(x)
            logits.append(logit)
            fmaps.append(fmap)
        return logits, fmaps