2024-12-13 19:39:55 +08:00

49 lines
1.4 KiB
Python

import logging
import torch
from discriminators import MultiPeriodDiscriminator, MultiResolutionDiscriminator
from generator import Generator
from loss import (
DiscriminatorLoss,
GeneratorLoss,
FeatureMatchingLoss,
MelSpecReconstructionLoss,
)
class Vocos(torch.nn.Module):
    """Container module bundling a generator, two discriminators and losses.

    The constructor builds a ``Generator``, a ``MultiPeriodDiscriminator``,
    a ``MultiResolutionDiscriminator`` and four loss modules, and stores each
    of them as an attribute of this module.

    ``forward`` only runs the generator.  Nothing in this class calls the
    discriminators or the loss modules; presumably external training code
    uses them through the attributes -- confirm against the caller.
    """

    def __init__(
        self,
        feature_dim: int = 80,
        dim: int = 512,
        n_fft: int = 1024,
        hop_length: int = 256,
        intermediate_dim: int = 1536,
        num_layers: int = 8,
        padding: str = "none",
        sample_rate: int = 24000,
        max_seconds: int = 60,
    ):
        """Instantiate the generator, the discriminators and the losses.

        Args:
            feature_dim: Forwarded unchanged to ``Generator``.
            dim: Forwarded unchanged to ``Generator``.
            n_fft: Forwarded unchanged to ``Generator``.
            hop_length: Forwarded unchanged to ``Generator``.
            intermediate_dim: Forwarded unchanged to ``Generator``.
            num_layers: Forwarded unchanged to ``Generator``.
            padding: Forwarded unchanged to ``Generator``.
            sample_rate: Passed to ``MelSpecReconstructionLoss`` and used,
                together with ``max_seconds``, to derive the generator's
                ``max_samples``.
            max_seconds: Multiplied by ``sample_rate`` to obtain the
                ``max_samples`` value handed to ``Generator``.
        """
        super().__init__()

        # NOTE: the construction order below is kept deliberately; it fixes
        # the order in which sub-modules are registered on this module.
        self.generator = Generator(
            feature_dim=feature_dim,
            dim=dim,
            n_fft=n_fft,
            hop_length=hop_length,
            num_layers=num_layers,
            intermediate_dim=intermediate_dim,
            padding=padding,
            # Truncated to an int so a fractional product still yields a
            # whole sample count.
            max_samples=int(sample_rate * max_seconds),
        )

        # Discriminators are built with their own default arguments.
        self.mpd = MultiPeriodDiscriminator()
        self.mrd = MultiResolutionDiscriminator()

        # Loss modules; only the mel-spectrogram loss receives a parameter.
        self.disc_loss = DiscriminatorLoss()
        self.gen_loss = GeneratorLoss()
        self.feat_matching_loss = FeatureMatchingLoss()
        self.melspec_loss = MelSpecReconstructionLoss(sample_rate=sample_rate)

    def forward(self, features: torch.Tensor):
        """Run the generator on ``features`` and return its output.

        Args:
            features: Input tensor passed straight to ``self.generator``.
                Its expected layout is defined by ``Generator`` and is not
                visible here -- TODO confirm.

        Returns:
            Whatever ``self.generator(features)`` returns (the original code
            names this value ``audio``).
        """
        audio = self.generator(features)
        return audio