mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
124 lines
3.8 KiB
Python
124 lines
3.8 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
from typing import List
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from base_discriminators import DiscriminatorP, DiscriminatorS, DiscriminatorSTFT
|
|
from torch.nn import AvgPool1d
|
|
|
|
|
|
class MultiPeriodDiscriminator(nn.Module):
    """Bank of ``DiscriminatorP`` modules, each applied to two inputs.

    One ``DiscriminatorP`` is created per entry of ``periods`` and stored, in
    order, in ``self.discriminators``.

    Args:
        periods: Iterable of values, each passed as the single positional
            argument of one ``DiscriminatorP``. Defaults to
            ``(2, 3, 5, 7, 11)``, the configuration that used to be
            hard-coded.
    """

    def __init__(self, periods=(2, 3, 5, 7, 11)):
        super().__init__()
        self.discriminators = nn.ModuleList(
            [DiscriminatorP(period) for period in periods]
        )

    def forward(self, y, y_hat):
        """Run every sub-discriminator on ``y`` and on ``y_hat``.

        Args:
            y: First input (presumably the real/reference signal -- confirm
                against the caller).
            y_hat: Second input (presumably the generated signal -- confirm
                against the caller).

        Returns:
            A 4-tuple ``(y_d_rs, y_d_gs, fmap_rs, fmap_gs)`` of lists with one
            entry per sub-discriminator, in ``self.discriminators`` order.
            ``y_d_rs`` / ``fmap_rs`` hold the first and second element of the
            pair returned for ``y``; ``y_d_gs`` / ``fmap_gs`` hold the same
            for ``y_hat``.
        """
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for disc in self.discriminators:
            # Keep the y / y_hat calls interleaved per discriminator so the
            # order of module invocations matches the original implementation.
            y_d_r, fmap_r = disc(y)
            y_d_g, fmap_g = disc(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)
        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
|
|
|
|
|
class MultiScaleDiscriminator(nn.Module):
    """Three ``DiscriminatorS`` modules applied at successively pooled scales.

    The first discriminator sees the inputs unchanged. Before each later
    discriminator the inputs are passed through one more
    ``AvgPool1d(4, 2, padding=2)``, so the pooling accumulates from scale to
    scale.
    """

    def __init__(self):
        super().__init__()
        self.discriminators = nn.ModuleList([DiscriminatorS() for _ in range(3)])
        # One pooling layer per transition between consecutive discriminators.
        self.meanpools = nn.ModuleList(
            [AvgPool1d(4, 2, padding=2) for _ in range(2)]
        )

    def forward(self, y, y_hat):
        """Run every sub-discriminator on ``y`` and on ``y_hat``.

        Args:
            y: First input (presumably the real/reference signal -- confirm
                against the caller).
            y_hat: Second input (presumably the generated signal -- confirm
                against the caller).

        Returns:
            A 4-tuple ``(y_d_rs, y_d_gs, fmap_rs, fmap_gs)`` of lists with one
            entry per sub-discriminator, in ``self.discriminators`` order.
            ``y_d_rs`` / ``fmap_rs`` hold the first and second element of the
            pair returned for ``y``; ``y_d_gs`` / ``fmap_gs`` hold the same
            for ``y_hat``.
        """
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for idx, disc in enumerate(self.discriminators):
            if idx > 0:
                # Rebinding y / y_hat is deliberate: discriminator ``idx``
                # receives inputs that have been pooled ``idx`` times.
                pool = self.meanpools[idx - 1]
                y = pool(y)
                y_hat = pool(y_hat)
            y_d_r, fmap_r = disc(y)
            y_d_g, fmap_g = disc(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
|
|
|
|
|
class MultiScaleSTFTDiscriminator(nn.Module):
    """Multi-Scale STFT (MS-STFT) discriminator.

    Builds one ``DiscriminatorSTFT`` per scale, where scale ``k`` uses
    ``n_ffts[k]``, ``hop_lengths[k]`` and ``win_lengths[k]``.

    Args:
        n_filters (int): Number of filters in convolutions
        in_channels (int): Number of input channels. Default: 1
        out_channels (int): Number of output channels. Default: 1
        n_ffts (Sequence[int]): Size of FFT for each scale
        hop_lengths (Sequence[int]): Length of hop between STFT windows for
            each scale
        win_lengths (Sequence[int]): Window size for each scale
        **kwargs: additional args forwarded to every DiscriminatorSTFT

    Raises:
        ValueError: If ``n_ffts``, ``hop_lengths`` and ``win_lengths`` do not
            all have the same length.
    """

    # NOTE: the list defaults below are shared between calls; that is safe
    # only because they are read here and never mutated.
    def __init__(
        self,
        n_filters: int,
        in_channels: int = 1,
        out_channels: int = 1,
        n_ffts: List[int] = [1024, 2048, 512, 256, 128],
        hop_lengths: List[int] = [256, 512, 128, 64, 32],
        win_lengths: List[int] = [1024, 2048, 512, 256, 128],
        **kwargs
    ):
        super().__init__()
        # Explicit raise instead of ``assert``: asserts vanish under
        # ``python -O`` and ``zip`` below would then silently truncate.
        if not (len(n_ffts) == len(hop_lengths) == len(win_lengths)):
            raise ValueError(
                "n_ffts, hop_lengths and win_lengths must have the same length, "
                f"got {len(n_ffts)}, {len(hop_lengths)} and {len(win_lengths)}"
            )
        self.discriminators = nn.ModuleList(
            [
                DiscriminatorSTFT(
                    n_filters,
                    in_channels=in_channels,
                    out_channels=out_channels,
                    n_fft=n_fft,
                    win_length=win_length,
                    hop_length=hop_length,
                    **kwargs
                )
                for n_fft, hop_length, win_length in zip(
                    n_ffts, hop_lengths, win_lengths
                )
            ]
        )
        self.num_discriminators = len(self.discriminators)

    def forward(self, x: torch.Tensor):
        """Apply every per-scale discriminator to ``x``.

        Args:
            x: Input tensor handed unchanged to each sub-discriminator.

        Returns:
            A pair ``(logits, fmaps)`` of lists with one entry per
            sub-discriminator, in ``self.discriminators`` order, holding the
            first and second element of each sub-discriminator's output.
        """
        logits = []
        fmaps = []
        for disc in self.discriminators:
            logit, fmap = disc(x)
            logits.append(logit)
            fmaps.append(fmap)
        return logits, fmaps
|