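"""Encodec model wrapper for icefall's neural codec recipe.

This module ties an Encodec-style codec (encoder, quantizer, decoder) to its
adversarial discriminators (multi-scale, multi-period, and multi-scale STFT)
and exposes separate generator and discriminator forward passes for
alternating GAN-style training.
"""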
import math
import random
from typing import List

import numpy as np
import torch
from loss import loss_dis, loss_g
from torch import nn
from torch.cuda.amp import autocast

class Encodec(nn.Module):
    def __init__(
        self,
        sampling_rate: int,
        target_bandwidths: List[float],
        params: dict,
        encoder: nn.Module,
        quantizer: nn.Module,
        decoder: nn.Module,
        multi_scale_discriminator: nn.Module,
        multi_period_discriminator: nn.Module,
        multi_scale_stft_discriminator: nn.Module,
        cache_generator_outputs: bool = False,
    ):
        super(Encodec, self).__init__()

        self.params = params

        # setup the generator
        self.sampling_rate = sampling_rate
        self.encoder = encoder
        self.quantizer = quantizer
        self.decoder = decoder

        self.ratios = encoder.ratios
        self.hop_length = np.prod(self.ratios)
        self.frame_rate = math.ceil(self.sampling_rate / np.prod(self.ratios))
        self.target_bandwidths = target_bandwidths

        # discriminators
        self.multi_scale_discriminator = multi_scale_discriminator
        self.multi_period_discriminator = multi_period_discriminator
        self.multi_scale_stft_discriminator = multi_scale_stft_discriminator

        # cache
        self.cache_generator_outputs = cache_generator_outputs
        self._cache = None

    def _forward_generator(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        global_step: int,
        return_sample: bool = False,
    ):
        """Perform generator forward.

        Args:
            speech (Tensor): Speech waveform tensor (B, T_wav).
            speech_lengths (Tensor): Speech length tensor (B,).
            global_step (int): Global step.
            return_sample (bool): Whether to return the generator output sample.

        Returns:
            * loss (Tensor): Loss scalar tensor.
            * stats (Dict[str, float]): Statistics to be monitored.
        """
        # setup
        speech = speech.unsqueeze(1)

        # calculate generator outputs
        reuse_cache = True
        if not self.cache_generator_outputs or self._cache is None:
            reuse_cache = False
            e = self.encoder(speech)
            bw = random.choice(self.target_bandwidths)
            quantized, codes, bandwidth, commit_loss = self.quantizer(
                e, self.frame_rate, bw
            )
            speech_hat = self.decoder(quantized)
        else:
            # reuse cached generator outputs; commit_loss is cached together
            # with the waveform because loss_g below needs it
            speech_hat, commit_loss = self._cache

        # store cache
        if self.training and self.cache_generator_outputs and not reuse_cache:
            self._cache = (speech_hat, commit_loss)

        # calculate discriminator outputs
        y_hat, fmap_hat = self.multi_scale_stft_discriminator(speech_hat.contiguous())
        with torch.no_grad():
            # do not store discriminator gradient in generator turn
            y, fmap = self.multi_scale_stft_discriminator(speech.contiguous())
        y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator(
            speech.contiguous(),
            speech_hat.contiguous(),
        )
        y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator(
            speech.contiguous(),
            speech_hat.contiguous(),
        )

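        # note: the losses below run under autocast(enabled=False), i.e. in
        # full precision even when AMP is active, presumably because fp16
        # GAN loss terms are prone to overflow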
        # calculate losses
        with autocast(enabled=False):
            loss, rec_loss, adv_loss, feat_loss, d_weight = loss_g(
                commit_loss,
                speech,
                speech_hat,
                fmap,
                fmap_hat,
                y,
                y_hat,
                global_step,
                y_p,
                y_p_hat,
                y_s,
                y_s_hat,
                fmap_p,
                fmap_p_hat,
                fmap_s,
                fmap_s_hat,
                args=self.params,
            )

        stats = dict(
            generator_loss=loss.item(),
            generator_reconstruction_loss=rec_loss.item(),
            generator_feature_loss=feat_loss.item(),
            generator_adv_loss=adv_loss.item(),
            generator_commit_loss=commit_loss.item(),
            d_weight=d_weight.item(),
        )

        if return_sample:
            stats["returned_sample"] = (
                speech_hat.cpu(),
                speech.cpu(),
                fmap_hat[0][0].data.cpu(),
                fmap[0][0].data.cpu(),
            )

        # reset cache
        if reuse_cache or not self.training:
            self._cache = None

        return loss, stats

    def _forward_discriminator(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        global_step: int,
    ):
        """Perform discriminator forward.

        Args:
            speech (Tensor): Speech waveform tensor (B, T_wav).
            speech_lengths (Tensor): Speech length tensor (B,).
            global_step (int): Global step.

        Returns:
            * loss (Tensor): Loss scalar tensor.
            * stats (Dict[str, float]): Statistics to be monitored.
        """
        # setup
        speech = speech.unsqueeze(1)

        # calculate generator outputs
        reuse_cache = True
        if not self.cache_generator_outputs or self._cache is None:
            reuse_cache = False
            e = self.encoder(speech)
            bw = random.choice(self.target_bandwidths)
            quantized, codes, bandwidth, commit_loss = self.quantizer(
                e, self.frame_rate, bw
            )
            speech_hat = self.decoder(quantized)
        else:
            speech_hat, commit_loss = self._cache

        # store cache (commit_loss is kept alongside the waveform because
        # the generator turn needs it)
        if self.training and self.cache_generator_outputs and not reuse_cache:
            self._cache = (speech_hat, commit_loss)

        # calculate discriminator outputs; speech_hat is detached so that
        # this turn updates only the discriminators, not the generator
        y, fmap = self.multi_scale_stft_discriminator(speech.contiguous())
        y_hat, fmap_hat = self.multi_scale_stft_discriminator(
            speech_hat.contiguous().detach()
        )
        y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator(
            speech.contiguous(),
            speech_hat.contiguous().detach(),
        )
        y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator(
            speech.contiguous(),
            speech_hat.contiguous().detach(),
        )

        # calculate losses
        with autocast(enabled=False):
            loss = loss_dis(
                y,
                y_hat,
                fmap,
                fmap_hat,
                y_p,
                y_p_hat,
                fmap_p,
                fmap_p_hat,
                y_s,
                y_s_hat,
                fmap_s,
                fmap_s_hat,
                global_step,
                args=self.params,
            )

        stats = dict(
            discriminator_loss=loss.item(),
        )

        # reset cache
        if reuse_cache or not self.training:
            self._cache = None

        return loss, stats

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        global_step: int,
        return_sample: bool,
        forward_generator: bool,
    ):
        """Dispatch to the generator or discriminator forward pass.

        Args:
            speech (Tensor): Speech waveform tensor (B, T_wav).
            speech_lengths (Tensor): Speech length tensor (B,).
            global_step (int): Global step.
            return_sample (bool): Whether to return the generator output sample.
            forward_generator (bool): If True, run the generator turn;
                otherwise run the discriminator turn.

        Returns:
            * loss (Tensor): Loss scalar tensor.
            * stats (Dict[str, float]): Statistics to be monitored.
        """
        if forward_generator:
            return self._forward_generator(
                speech=speech,
                speech_lengths=speech_lengths,
                global_step=global_step,
                return_sample=return_sample,
            )
        else:
            return self._forward_discriminator(
                speech=speech,
                speech_lengths=speech_lengths,
                global_step=global_step,
            )

    def encode(self, x, target_bw=None, st=None):
        # encode a waveform to discrete codes; target_bw selects the
        # bandwidth (defaulting to the largest configured one) and st
        # appears to select the first quantizer stage to use (0 by default)
        e = self.encoder(x)
        if target_bw is None:
            bw = self.target_bandwidths[-1]
        else:
            bw = target_bw
        if st is None:
            st = 0
        codes = self.quantizer.encode(e, self.frame_rate, bw, st)
        return codes

    def decode(self, codes):
        # decode discrete codes back to a waveform
        quantized = self.quantizer.decode(codes)
        o = self.decoder(quantized)
        return o

    def inference(self, x, target_bw=None, st=None):
        # setup
        x = x.unsqueeze(1)

        codes = self.encode(x, target_bw, st)
        o = self.decode(codes)
        return o
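
# A minimal training-loop sketch (illustrative only). It assumes `encodec` is
# a fully constructed Encodec instance and that `optim_g` / `optim_d` are
# optimizers over the generator and discriminator parameters; the actual
# icefall training script may differ. With cache_generator_outputs=True, the
# discriminator turn caches the generated waveform (and commitment loss) so
# the generator turn can reuse it.
#
#     for global_step, (speech, speech_lengths) in enumerate(dataloader):
#         # discriminator turn
#         d_loss, d_stats = encodec(
#             speech, speech_lengths, global_step,
#             return_sample=False, forward_generator=False,
#         )
#         optim_d.zero_grad()
#         d_loss.backward()
#         optim_d.step()
#
#         # generator turn
#         g_loss, g_stats = encodec(
#             speech, speech_lengths, global_step,
#             return_sample=False, forward_generator=True,
#         )
#         optim_g.zero_grad()
#         g_loss.backward()
#         optim_g.step()
#
# For inference, reconstruct a (B, T_wav) batch of waveforms at the default
# (largest) bandwidth:
#
#     wav_hat = encodec.inference(speech)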