# 2024-09-06 21:21:58 +08:00
#
# 280 lines
# 8.6 KiB
# Python

import math
import random
from typing import List
import numpy as np
import torch
from loss import loss_dis, loss_g
from torch import nn
from torch.cuda.amp import autocast
class Encodec(nn.Module):
    """EnCodec-style audio codec together with its GAN discriminators.

    The generator is ``encoder -> quantizer -> decoder``. Three discriminators
    (multi-scale STFT, multi-period and multi-scale) are trained against it.
    ``forward`` performs either the generator turn or the discriminator turn
    of adversarial training; ``encode`` / ``decode`` / ``inference`` expose the
    codec itself.

    Args:
        sampling_rate: Audio sampling rate; divided by the encoder hop length
            to obtain ``frame_rate``.
        target_bandwidths: Candidate bandwidths. One is drawn at random for
            every training forward; ``encode`` defaults to the last entry.
        params: Options forwarded to ``loss_g`` / ``loss_dis`` as ``args``.
        encoder: Encoder module; must expose ``ratios`` whose product is the
            hop length.
        quantizer: Called as ``quantizer(e, frame_rate, bw)`` during training
            (returning ``quantized, codes, bandwidth, commit_loss``) and via
            ``quantizer.encode`` / ``quantizer.decode`` for the codec API.
        decoder: Maps quantized latents back to a waveform.
        multi_scale_discriminator: Multi-scale discriminator.
        multi_period_discriminator: Multi-period discriminator.
        multi_scale_stft_discriminator: Multi-scale STFT discriminator.
        cache_generator_outputs: If True, the generator output computed in one
            training turn is reused by the next turn instead of recomputed.
    """

    def __init__(
        self,
        sampling_rate: int,
        target_bandwidths: List[float],
        params: dict,
        encoder: nn.Module,
        quantizer: nn.Module,
        decoder: nn.Module,
        multi_scale_discriminator: nn.Module,
        multi_period_discriminator: nn.Module,
        multi_scale_stft_discriminator: nn.Module,
        cache_generator_outputs: bool = False,
    ):
        super().__init__()
        self.params = params

        # Generator.
        self.sampling_rate = sampling_rate
        self.encoder = encoder
        self.quantizer = quantizer
        self.decoder = decoder
        self.ratios = encoder.ratios
        # Plain int rather than a NumPy scalar so it mixes cleanly with torch.
        self.hop_length = int(np.prod(self.ratios))
        self.frame_rate = math.ceil(self.sampling_rate / self.hop_length)
        self.target_bandwidths = target_bandwidths

        # Discriminators.
        self.multi_scale_discriminator = multi_scale_discriminator
        self.multi_period_discriminator = multi_period_discriminator
        self.multi_scale_stft_discriminator = multi_scale_stft_discriminator

        # Generator-output cache: ``(speech_hat, commit_loss)`` or None.
        self.cache_generator_outputs = cache_generator_outputs
        self._cache = None

    def _sample_bandwidth(self, device: torch.device) -> float:
        """Draw a random entry of ``target_bandwidths``, shared by all ranks."""
        index = torch.tensor(
            random.randint(0, len(self.target_bandwidths) - 1),
            device=device,
        )
        # Each rank draws independently; rank 0's draw is broadcast so every
        # replica quantizes with the same bandwidth this step.
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            torch.distributed.broadcast(index, src=0)
        return self.target_bandwidths[index.item()]

    def _generate(self, speech: torch.Tensor):
        """Run the generator on ``speech`` or reuse the cached output.

        Args:
            speech: Waveform with a channel axis already added (B, 1, T_wav).

        Returns:
            ``(speech_hat, commit_loss, reuse_cache)``; ``reuse_cache`` is True
            when both tensors were taken from the cache.
        """
        if self.cache_generator_outputs and self._cache is not None:
            speech_hat, commit_loss = self._cache
            return speech_hat, commit_loss, True

        e = self.encoder(speech)
        bw = self._sample_bandwidth(speech.device)
        quantized, _codes, _bandwidth, commit_loss = self.quantizer(
            e, self.frame_rate, bw
        )
        speech_hat = self.decoder(quantized)

        if self.training and self.cache_generator_outputs:
            # commit_loss is cached with the waveform: a generator turn that
            # reuses the cache needs it for loss_g.
            self._cache = (speech_hat, commit_loss)
        return speech_hat, commit_loss, False

    def _release_cache(self, reuse_cache: bool) -> None:
        """Drop the cache once it has been reused, and always outside training."""
        if reuse_cache or not self.training:
            self._cache = None

    def _forward_generator(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        global_step: int,
        return_sample: bool = False,
    ):
        """Perform generator forward.

        Args:
            speech (Tensor): Speech waveform tensor (B, T_wav).
            speech_lengths (Tensor): Speech length tensor (B,). Unused here.
            global_step (int): Global step, forwarded to ``loss_g``.
            return_sample (bool): If True, ``stats["returned_sample"]`` holds
                CPU copies of ``(speech_hat, speech, fmap_hat[0][0], fmap[0][0])``.

        Returns:
            * loss (Tensor): Loss scalar tensor.
            * stats (Dict[str, float]): Statistics to be monitored.
        """
        speech = speech.unsqueeze(1)
        speech_hat, commit_loss, reuse_cache = self._generate(speech)

        real = speech.contiguous()
        fake = speech_hat.contiguous()

        # Discriminator outputs; gradients flow back into the generator.
        y_hat, fmap_hat = self.multi_scale_stft_discriminator(fake)
        with torch.no_grad():
            # Real-audio STFT outputs are only targets in the generator turn.
            y, fmap = self.multi_scale_stft_discriminator(real)
        y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator(real, fake)
        y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator(real, fake)

        # Losses are computed with CUDA autocast disabled.
        with autocast(enabled=False):
            loss, rec_loss, adv_loss, feat_loss, d_weight = loss_g(
                commit_loss,
                speech,
                speech_hat,
                fmap,
                fmap_hat,
                y,
                y_hat,
                global_step,
                y_p,
                y_p_hat,
                y_s,
                y_s_hat,
                fmap_p,
                fmap_p_hat,
                fmap_s,
                fmap_s_hat,
                args=self.params,
            )

        stats = dict(
            generator_loss=loss.item(),
            generator_reconstruction_loss=rec_loss.item(),
            generator_feature_loss=feat_loss.item(),
            generator_adv_loss=adv_loss.item(),
            generator_commit_loss=commit_loss.item(),
            d_weight=d_weight.item(),
        )
        if return_sample:
            stats["returned_sample"] = (
                speech_hat.cpu(),
                speech.cpu(),
                fmap_hat[0][0].data.cpu(),
                fmap[0][0].data.cpu(),
            )

        self._release_cache(reuse_cache)
        return loss, stats

    def _forward_discriminator(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        global_step: int,
    ):
        """Perform discriminator forward.

        Args:
            speech (Tensor): Speech waveform tensor (B, T_wav).
            speech_lengths (Tensor): Speech length tensor (B,). Unused here.
            global_step (int): Global step, forwarded to ``loss_dis``.

        Returns:
            * loss (Tensor): Loss scalar tensor.
            * stats (Dict[str, float]): Statistics to be monitored.
        """
        speech = speech.unsqueeze(1)
        speech_hat, _commit_loss, reuse_cache = self._generate(speech)

        real = speech.contiguous()
        # Detached so the discriminator loss does not backpropagate into the
        # generator.
        fake = speech_hat.contiguous().detach()

        y, fmap = self.multi_scale_stft_discriminator(real)
        y_hat, fmap_hat = self.multi_scale_stft_discriminator(fake)
        y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator(real, fake)
        y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator(real, fake)

        # Losses are computed with CUDA autocast disabled.
        with autocast(enabled=False):
            loss = loss_dis(
                y,
                y_hat,
                fmap,
                fmap_hat,
                y_p,
                y_p_hat,
                fmap_p,
                fmap_p_hat,
                y_s,
                y_s_hat,
                fmap_s,
                fmap_s_hat,
                global_step,
                args=self.params,
            )

        stats = dict(
            discriminator_loss=loss.item(),
        )

        self._release_cache(reuse_cache)
        return loss, stats

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        global_step: int,
        return_sample: bool,
        forward_generator: bool,
    ):
        """Dispatch to the generator turn or the discriminator turn.

        Returns ``(loss, stats)`` from ``_forward_generator`` when
        ``forward_generator`` is True, otherwise from ``_forward_discriminator``
        (``return_sample`` is ignored in that case).
        """
        if forward_generator:
            return self._forward_generator(
                speech=speech,
                speech_lengths=speech_lengths,
                global_step=global_step,
                return_sample=return_sample,
            )
        return self._forward_discriminator(
            speech=speech,
            speech_lengths=speech_lengths,
            global_step=global_step,
        )

    def encode(self, x, target_bw=None, st=None):
        """Encode a waveform into quantizer codes.

        Args:
            x: Input passed directly to the encoder (``inference`` supplies a
                ``(B, 1, T_wav)`` tensor).
            target_bw: Bandwidth for quantization; defaults to the last entry
                of ``target_bandwidths``.
            st: Forwarded to ``quantizer.encode``; defaults to 0 (presumably a
                starting codebook index — confirm against the quantizer).

        Returns:
            Whatever ``quantizer.encode`` returns.
        """
        e = self.encoder(x)
        bw = self.target_bandwidths[-1] if target_bw is None else target_bw
        start = 0 if st is None else st
        return self.quantizer.encode(e, self.frame_rate, bw, start)

    def decode(self, codes):
        """Decode quantizer codes back to a waveform via quantizer and decoder."""
        quantized = self.quantizer.decode(codes)
        return self.decoder(quantized)

    def inference(self, x, target_bw=None, st=None):
        """Round-trip ``x`` (B, T_wav) through ``encode`` and ``decode``.

        A channel axis is inserted first; ``target_bw`` and ``st`` are
        forwarded to ``encode``.
        """
        x = x.unsqueeze(1)
        codes = self.encode(x, target_bw, st)
        return self.decode(codes)