#!/usr/bin/env python3
# Copyright 2024 The Chinese University of HK (Author: Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import random
from typing import List, Optional

import numpy as np
import torch
from loss import (
    DiscriminatorAdversarialLoss,
    FeatureLoss,
    GeneratorAdversarialLoss,
    MelSpectrogramReconstructionLoss,
    WavReconstructionLoss,
)
from torch import nn
from torch.cuda.amp import autocast


class Encodec(nn.Module):
    def __init__(
        self,
        sampling_rate: int,
        target_bandwidths: List[float],
        params: dict,
        encoder: nn.Module,
        quantizer: nn.Module,
        decoder: nn.Module,
        multi_scale_discriminator: nn.Module,
        multi_period_discriminator: Optional[nn.Module] = None,
        multi_scale_stft_discriminator: Optional[nn.Module] = None,
        cache_generator_outputs: bool = False,
    ):
        super(Encodec, self).__init__()

        self.params = params

        # setup the generator
        self.sampling_rate = sampling_rate
        self.encoder = encoder
        self.quantizer = quantizer
        self.decoder = decoder
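
        # Each encoder stride downsamples the waveform, so the hop length is the
        # product of all strides and the frame rate is the number of quantized
        # frames per second (for example, 24 kHz audio with ratios (8, 5, 4, 2)
        # gives a hop of 320 samples and 75 frames per second).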
        self.ratios = encoder.ratios
        self.hop_length = np.prod(self.ratios)
        self.frame_rate = math.ceil(self.sampling_rate / np.prod(self.ratios))
        self.target_bandwidths = target_bandwidths

        # discriminators
        self.multi_scale_discriminator = multi_scale_discriminator
        self.multi_period_discriminator = multi_period_discriminator
        self.multi_scale_stft_discriminator = multi_scale_stft_discriminator

        # cache
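        # (optionally reuse the generator output produced in one training turn
        # during the following turn, saving one generator forward pass)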
        self.cache_generator_outputs = cache_generator_outputs
        self._cache = None

        # construct loss functions
        self.generator_adversarial_loss = GeneratorAdversarialLoss(
            average_by_discriminators=True, loss_type="hinge"
        )
        self.discriminator_adversarial_loss = DiscriminatorAdversarialLoss(
            average_by_discriminators=True, loss_type="hinge"
        )
        self.feature_match_loss = FeatureLoss()
        self.wav_reconstruction_loss = WavReconstructionLoss()
        self.mel_reconstruction_loss = MelSpectrogramReconstructionLoss(
            sampling_rate=self.sampling_rate
        )

    def _forward_generator(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        return_sample: bool = False,
    ):
        """Perform generator forward.

        Args:
            speech (Tensor): Speech waveform tensor (B, T_wav).
            speech_lengths (Tensor): Speech length tensor (B,).
            return_sample (bool): Return the generator output.

        Returns:
            * loss (Tensor): Loss scalar tensor.
            * stats (Dict[str, float]): Statistics to be monitored.
        """
        # setup
        speech = speech.unsqueeze(1)

        # calculate generator outputs
        reuse_cache = True
        if not self.cache_generator_outputs or self._cache is None:
            reuse_cache = False
            e = self.encoder(speech)
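            # sample a target bandwidth at random for this step and broadcast it
            # from rank 0 so that all DDP workers quantize with the same
            # target bandwidth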
            index = torch.tensor(
                random.randint(0, len(self.target_bandwidths) - 1),
                device=speech.device,
            )
            if torch.distributed.is_initialized():
                torch.distributed.broadcast(index, src=0)
            bw = self.target_bandwidths[index.item()]
            quantized, codes, bandwidth, commit_loss = self.quantizer(
                e, self.frame_rate, bw
            )
            speech_hat = self.decoder(quantized)
        else:
            speech_hat = self._cache

        # store cache
        if self.training and self.cache_generator_outputs and not reuse_cache:
            self._cache = speech_hat

        # calculate discriminator outputs
        y_hat, fmap_hat = self.multi_scale_stft_discriminator(speech_hat.contiguous())
        with torch.no_grad():
            # do not store discriminator gradient in generator turn
            y, fmap = self.multi_scale_stft_discriminator(speech.contiguous())

        gen_period_adv_loss = torch.tensor(0.0)
        feature_period_loss = torch.tensor(0.0)
        if self.multi_period_discriminator is not None:
            y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator(
                speech.contiguous(),
                speech_hat.contiguous(),
            )

        gen_scale_adv_loss = torch.tensor(0.0)
        feature_scale_loss = torch.tensor(0.0)
        if self.multi_scale_discriminator is not None:
            y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator(
                speech.contiguous(),
                speech_hat.contiguous(),
            )

        # calculate losses
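        # (losses are computed with autocast disabled, i.e. in full precision,
        # for numerical stability)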
        with autocast(enabled=False):
            gen_stft_adv_loss = self.generator_adversarial_loss(outputs=y_hat)

            if self.multi_period_discriminator is not None:
                gen_period_adv_loss = self.generator_adversarial_loss(outputs=y_p_hat)
            if self.multi_scale_discriminator is not None:
                gen_scale_adv_loss = self.generator_adversarial_loss(outputs=y_s_hat)

            feature_stft_loss = self.feature_match_loss(feats=fmap, feats_hat=fmap_hat)

            if self.multi_period_discriminator is not None:
                feature_period_loss = self.feature_match_loss(
                    feats=fmap_p, feats_hat=fmap_p_hat
                )
            if self.multi_scale_discriminator is not None:
                feature_scale_loss = self.feature_match_loss(
                    feats=fmap_s, feats_hat=fmap_s_hat
                )

            wav_reconstruction_loss = self.wav_reconstruction_loss(
                x=speech, x_hat=speech_hat
            )
            mel_reconstruction_loss = self.mel_reconstruction_loss(
                x=speech, x_hat=speech_hat
            )

        stats = dict(
            generator_wav_reconstruction_loss=wav_reconstruction_loss.item(),
            generator_mel_reconstruction_loss=mel_reconstruction_loss.item(),
            generator_feature_stft_loss=feature_stft_loss.item(),
            generator_feature_period_loss=feature_period_loss.item(),
            generator_feature_scale_loss=feature_scale_loss.item(),
            generator_stft_adv_loss=gen_stft_adv_loss.item(),
            generator_period_adv_loss=gen_period_adv_loss.item(),
            generator_scale_adv_loss=gen_scale_adv_loss.item(),
            generator_commit_loss=commit_loss.item(),
        )

        if return_sample:
            stats["returned_sample"] = (
                speech_hat.cpu(),
                speech.cpu(),
                fmap_hat[0][0].data.cpu(),
                fmap[0][0].data.cpu(),
            )

        # reset cache
        if reuse_cache or not self.training:
            self._cache = None
        return (
            commit_loss,
            gen_stft_adv_loss,
            gen_period_adv_loss,
            gen_scale_adv_loss,
            feature_stft_loss,
            feature_period_loss,
            feature_scale_loss,
            wav_reconstruction_loss,
            mel_reconstruction_loss,
            stats,
        )

    def _forward_discriminator(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ):
"""
|
|
Args:
|
|
speech (Tensor): Speech waveform tensor (B, T_wav).
|
|
speech_lengths (Tensor): Speech length tensor (B,).
|
|
|
|
Returns:
|
|
* loss (Tensor): Loss scalar tensor.
|
|
* stats (Dict[str, float]): Statistics to be monitored.
|
|
"""
        # setup
        speech = speech.unsqueeze(1)

        # calculate generator outputs
        reuse_cache = True
        if not self.cache_generator_outputs or self._cache is None:
            reuse_cache = False
            e = self.encoder(speech)
            index = torch.tensor(
                random.randint(0, len(self.target_bandwidths) - 1),
                device=speech.device,
            )
            if torch.distributed.is_initialized():
                torch.distributed.broadcast(index, src=0)
            bw = self.target_bandwidths[index.item()]
            quantized, codes, bandwidth, commit_loss = self.quantizer(
                e, self.frame_rate, bw
            )
            speech_hat = self.decoder(quantized)
        else:
            speech_hat = self._cache

        # store cache
        if self.training and self.cache_generator_outputs and not reuse_cache:
            self._cache = speech_hat

        # calculate discriminator outputs
        y, fmap = self.multi_scale_stft_discriminator(speech.contiguous())
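        # detach the generator output so that this turn only updates the
        # discriminators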
        y_hat, fmap_hat = self.multi_scale_stft_discriminator(
            speech_hat.contiguous().detach()
        )

        disc_period_real_adv_loss = torch.tensor(0.0)
        disc_period_fake_adv_loss = torch.tensor(0.0)
        if self.multi_period_discriminator is not None:
            y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator(
                speech.contiguous(),
                speech_hat.contiguous().detach(),
            )

        disc_scale_real_adv_loss = torch.tensor(0.0)
        disc_scale_fake_adv_loss = torch.tensor(0.0)
        if self.multi_scale_discriminator is not None:
            y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator(
                speech.contiguous(),
                speech_hat.contiguous().detach(),
            )
        # calculate losses
        with autocast(enabled=False):
            (
                disc_stft_real_adv_loss,
                disc_stft_fake_adv_loss,
            ) = self.discriminator_adversarial_loss(outputs=y, outputs_hat=y_hat)
            if self.multi_period_discriminator is not None:
                (
                    disc_period_real_adv_loss,
                    disc_period_fake_adv_loss,
                ) = self.discriminator_adversarial_loss(
                    outputs=y_p, outputs_hat=y_p_hat
                )
            if self.multi_scale_discriminator is not None:
                (
                    disc_scale_real_adv_loss,
                    disc_scale_fake_adv_loss,
                ) = self.discriminator_adversarial_loss(
                    outputs=y_s, outputs_hat=y_s_hat
                )

        stats = dict(
            discriminator_stft_real_adv_loss=disc_stft_real_adv_loss.item(),
            discriminator_period_real_adv_loss=disc_period_real_adv_loss.item(),
            discriminator_scale_real_adv_loss=disc_scale_real_adv_loss.item(),
            discriminator_stft_fake_adv_loss=disc_stft_fake_adv_loss.item(),
            discriminator_period_fake_adv_loss=disc_period_fake_adv_loss.item(),
            discriminator_scale_fake_adv_loss=disc_scale_fake_adv_loss.item(),
        )

        # reset cache
        if reuse_cache or not self.training:
            self._cache = None

        return (
            disc_stft_real_adv_loss,
            disc_stft_fake_adv_loss,
            disc_period_real_adv_loss,
            disc_period_fake_adv_loss,
            disc_scale_real_adv_loss,
            disc_scale_fake_adv_loss,
            stats,
        )

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        return_sample: bool,
        forward_generator: bool,
    ):
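        """Dispatch to the generator or the discriminator forward pass.

        GAN training alternates between the two turns; ``forward_generator``
        selects which set of losses is computed for the current step.
        """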
        if forward_generator:
            return self._forward_generator(
                speech=speech,
                speech_lengths=speech_lengths,
                return_sample=return_sample,
            )
        else:
            return self._forward_discriminator(
                speech=speech,
                speech_lengths=speech_lengths,
            )
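
    # Inference-time helpers: encode a waveform to codebook indices at a chosen
    # target bandwidth (defaulting to the largest configured one) and decode the
    # indices back to a waveform.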
    def encode(self, x, target_bw=None, st=None):
        e = self.encoder(x)
        if target_bw is None:
            bw = self.target_bandwidths[-1]
        else:
            bw = target_bw
        if st is None:
            st = 0
        codes = self.quantizer.encode(e, self.frame_rate, bw, st)
        return codes

    def decode(self, codes):
        quantized = self.quantizer.decode(codes)
        x_hat = self.decoder(quantized)
        return x_hat

    def inference(self, x, target_bw=None, st=None):
        # setup
        x = x.unsqueeze(1)

        codes = self.encode(x, target_bw, st)
        x_hat = self.decode(codes)
        return codes, x_hat