From ee718f1da13957a2406522e254c78bd50d631a54 Mon Sep 17 00:00:00 2001
From: jinzr
Date: Fri, 1 Dec 2023 00:12:35 +0800
Subject: [PATCH] removed redundant files

---
 egs/vctk/TTS/vits/monotonic_align/__init__.py |  81 ---
 egs/vctk/TTS/vits/monotonic_align/core.pyx    |  51 --
 egs/vctk/TTS/vits/monotonic_align/setup.py    |  31 -
 egs/vctk/TTS/vits/transform.py                | 218 -------
 egs/vctk/TTS/vits/utils.py                    | 265 --------
 egs/vctk/TTS/vits/vits.py                     | 609 ------------------
 egs/vctk/TTS/vits/wavenet.py                  | 349 ----------
 7 files changed, 1604 deletions(-)
 delete mode 100644 egs/vctk/TTS/vits/monotonic_align/__init__.py
 delete mode 100644 egs/vctk/TTS/vits/monotonic_align/core.pyx
 delete mode 100644 egs/vctk/TTS/vits/monotonic_align/setup.py
 delete mode 100644 egs/vctk/TTS/vits/transform.py
 delete mode 100644 egs/vctk/TTS/vits/utils.py
 delete mode 100644 egs/vctk/TTS/vits/vits.py
 delete mode 100644 egs/vctk/TTS/vits/wavenet.py

diff --git a/egs/vctk/TTS/vits/monotonic_align/__init__.py b/egs/vctk/TTS/vits/monotonic_align/__init__.py
deleted file mode 100644
index 2b35654f5..000000000
--- a/egs/vctk/TTS/vits/monotonic_align/__init__.py
+++ /dev/null
@@ -1,81 +0,0 @@
-# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/__init__.py
-
-"""Maximum path calculation module.
-
-This code is based on https://github.com/jaywalnut310/vits.
-
-"""
-
-import warnings
-
-import numpy as np
-import torch
-from numba import njit, prange
-
-try:
-    from .core import maximum_path_c
-
-    is_cython_available = True
-except ImportError:
-    is_cython_available = False
-    warnings.warn(
-        "Cython version is not available. Falling back to the 'EXPERIMENTAL' numba version. "
-        "If you want to use the Cython version, please build it as follows: "
-        "`cd egs/vctk/TTS/vits/monotonic_align; python setup.py build_ext --inplace`"
-    )
-
-
-def maximum_path(neg_x_ent: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
-    """Calculate the maximum path (monotonic alignment).
-
-    Args:
-        neg_x_ent (Tensor): Negative cross-entropy tensor (B, T_feats, T_text).
-        attn_mask (Tensor): Attention mask (B, T_feats, T_text).
-
-    Returns:
-        Tensor: Maximum path tensor (B, T_feats, T_text).
-
-    """
-    device, dtype = neg_x_ent.device, neg_x_ent.dtype
-    neg_x_ent = neg_x_ent.cpu().numpy().astype(np.float32)
-    path = np.zeros(neg_x_ent.shape, dtype=np.int32)
-    t_t_max = attn_mask.sum(1)[:, 0].cpu().numpy().astype(np.int32)
-    t_s_max = attn_mask.sum(2)[:, 0].cpu().numpy().astype(np.int32)
-    if is_cython_available:
-        maximum_path_c(path, neg_x_ent, t_t_max, t_s_max)
-    else:
-        maximum_path_numba(path, neg_x_ent, t_t_max, t_s_max)
-
-    return torch.from_numpy(path).to(device=device, dtype=dtype)
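# Editorial usage sketch (not part of the deleted file): how `maximum_path`
# is typically called. The import path `monotonic_align` and the toy shapes
# are assumptions for illustration only.
import torch
from monotonic_align import maximum_path

B, T_feats, T_text = 1, 6, 4
neg_x_ent = torch.randn(B, T_feats, T_text)  # soft alignment scores
attn_mask = torch.ones(B, T_feats, T_text)   # all positions valid
path = maximum_path(neg_x_ent, attn_mask)    # hard alignment (B, T_feats, T_text)
assert path.sum() == T_feats  # exactly one text position per feature frame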
-
-
-@njit
-def maximum_path_each_numba(path, value, t_y, t_x, max_neg_val=-np.inf):
-    """Calculate a single maximum path with numba."""
-    index = t_x - 1
-    for y in range(t_y):
-        for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
-            if x == y:
-                v_cur = max_neg_val
-            else:
-                v_cur = value[y - 1, x]
-            if x == 0:
-                if y == 0:
-                    v_prev = 0.0
-                else:
-                    v_prev = max_neg_val
-            else:
-                v_prev = value[y - 1, x - 1]
-            value[y, x] += max(v_prev, v_cur)
-
-    for y in range(t_y - 1, -1, -1):
-        path[y, index] = 1
-        if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]):
-            index = index - 1
-
-
-@njit(parallel=True)
-def maximum_path_numba(paths, values, t_ys, t_xs):
-    """Calculate batch maximum path with numba."""
-    for i in prange(paths.shape[0]):
-        maximum_path_each_numba(paths[i], values[i], t_ys[i], t_xs[i])
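# Editorial sketch (not part of the deleted file): a quick property check for
# the numba implementation above, run in the same session where it is defined.
# The toy sizes are assumptions for illustration.
import numpy as np

path = np.zeros((6, 4), dtype=np.int32)
value = np.random.randn(6, 4).astype(np.float32)
maximum_path_each_numba(path, value, 6, 4)
cols = path.argmax(axis=1)
assert (path.sum(axis=1) == 1).all()  # one text position per feature frame
assert (np.diff(cols) >= 0).all()     # the alignment is monotonic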
diff --git a/egs/vctk/TTS/vits/monotonic_align/core.pyx b/egs/vctk/TTS/vits/monotonic_align/core.pyx
deleted file mode 100644
index c02c2d02e..000000000
--- a/egs/vctk/TTS/vits/monotonic_align/core.pyx
+++ /dev/null
@@ -1,51 +0,0 @@
-# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/core.pyx
-
-"""Maximum path calculation module with cython optimization.
-
-This code is copied from https://github.com/jaywalnut310/vits, with the code format modified.
-
-"""
-
-cimport cython
-
-from cython.parallel import prange
-
-
-@cython.boundscheck(False)
-@cython.wraparound(False)
-cdef void maximum_path_each(int[:, ::1] path, float[:, ::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil:
-    cdef int x
-    cdef int y
-    cdef float v_prev
-    cdef float v_cur
-    cdef float tmp
-    cdef int index = t_x - 1
-
-    for y in range(t_y):
-        for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
-            if x == y:
-                v_cur = max_neg_val
-            else:
-                v_cur = value[y - 1, x]
-            if x == 0:
-                if y == 0:
-                    v_prev = 0.0
-                else:
-                    v_prev = max_neg_val
-            else:
-                v_prev = value[y - 1, x - 1]
-            value[y, x] += max(v_prev, v_cur)
-
-    for y in range(t_y - 1, -1, -1):
-        path[y, index] = 1
-        if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]):
-            index = index - 1
-
-
-@cython.boundscheck(False)
-@cython.wraparound(False)
-cpdef void maximum_path_c(int[:, :, ::1] paths, float[:, :, ::1] values, int[::1] t_ys, int[::1] t_xs) nogil:
-    cdef int b = paths.shape[0]
-    cdef int i
-    for i in prange(b, nogil=True):
-        maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i])
diff --git a/egs/vctk/TTS/vits/monotonic_align/setup.py b/egs/vctk/TTS/vits/monotonic_align/setup.py
deleted file mode 100644
index 33d75e176..000000000
--- a/egs/vctk/TTS/vits/monotonic_align/setup.py
+++ /dev/null
@@ -1,31 +0,0 @@
-# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/setup.py
-"""Setup cython code."""
-
-from Cython.Build import cythonize
-from setuptools import Extension, setup
-from setuptools.command.build_ext import build_ext as _build_ext
-
-
-class build_ext(_build_ext):
-    """Override build_ext."""
-
-    def finalize_options(self):
-        """Prevent numpy from thinking it is still in its setup process."""
-        _build_ext.finalize_options(self)
-        __builtins__.__NUMPY_SETUP__ = False
-        import numpy
-
-        self.include_dirs.append(numpy.get_include())
-
-
-exts = [
-    Extension(
-        name="core",
-        sources=["core.pyx"],
-    )
-]
-setup(
-    name="monotonic_align",
-    ext_modules=cythonize(exts, language_level=3),
-    cmdclass={"build_ext": build_ext},
-)
diff --git a/egs/vctk/TTS/vits/transform.py b/egs/vctk/TTS/vits/transform.py
deleted file mode 100644
index c20d13130..000000000
--- a/egs/vctk/TTS/vits/transform.py
+++ /dev/null
@@ -1,218 +0,0 @@
-# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/transform.py
-
-"""Flow-related transformation.
-
-This code is derived from https://github.com/bayesiains/nflows.
- -""" - -import numpy as np -import torch -from torch.nn import functional as F - -DEFAULT_MIN_BIN_WIDTH = 1e-3 -DEFAULT_MIN_BIN_HEIGHT = 1e-3 -DEFAULT_MIN_DERIVATIVE = 1e-3 - - -# TODO(kan-bayashi): Documentation and type hint -def piecewise_rational_quadratic_transform( - inputs, - unnormalized_widths, - unnormalized_heights, - unnormalized_derivatives, - inverse=False, - tails=None, - tail_bound=1.0, - min_bin_width=DEFAULT_MIN_BIN_WIDTH, - min_bin_height=DEFAULT_MIN_BIN_HEIGHT, - min_derivative=DEFAULT_MIN_DERIVATIVE, -): - if tails is None: - spline_fn = rational_quadratic_spline - spline_kwargs = {} - else: - spline_fn = unconstrained_rational_quadratic_spline - spline_kwargs = {"tails": tails, "tail_bound": tail_bound} - - outputs, logabsdet = spline_fn( - inputs=inputs, - unnormalized_widths=unnormalized_widths, - unnormalized_heights=unnormalized_heights, - unnormalized_derivatives=unnormalized_derivatives, - inverse=inverse, - min_bin_width=min_bin_width, - min_bin_height=min_bin_height, - min_derivative=min_derivative, - **spline_kwargs - ) - return outputs, logabsdet - - -# TODO(kan-bayashi): Documentation and type hint -def unconstrained_rational_quadratic_spline( - inputs, - unnormalized_widths, - unnormalized_heights, - unnormalized_derivatives, - inverse=False, - tails="linear", - tail_bound=1.0, - min_bin_width=DEFAULT_MIN_BIN_WIDTH, - min_bin_height=DEFAULT_MIN_BIN_HEIGHT, - min_derivative=DEFAULT_MIN_DERIVATIVE, -): - inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) - outside_interval_mask = ~inside_interval_mask - - outputs = torch.zeros_like(inputs) - logabsdet = torch.zeros_like(inputs) - - if tails == "linear": - unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) - constant = np.log(np.exp(1 - min_derivative) - 1) - unnormalized_derivatives[..., 0] = constant - unnormalized_derivatives[..., -1] = constant - - outputs[outside_interval_mask] = inputs[outside_interval_mask] - logabsdet[outside_interval_mask] = 0 - else: - raise RuntimeError("{} tails are not implemented.".format(tails)) - - ( - outputs[inside_interval_mask], - logabsdet[inside_interval_mask], - ) = rational_quadratic_spline( - inputs=inputs[inside_interval_mask], - unnormalized_widths=unnormalized_widths[inside_interval_mask, :], - unnormalized_heights=unnormalized_heights[inside_interval_mask, :], - unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], - inverse=inverse, - left=-tail_bound, - right=tail_bound, - bottom=-tail_bound, - top=tail_bound, - min_bin_width=min_bin_width, - min_bin_height=min_bin_height, - min_derivative=min_derivative, - ) - - return outputs, logabsdet - - -# TODO(kan-bayashi): Documentation and type hint -def rational_quadratic_spline( - inputs, - unnormalized_widths, - unnormalized_heights, - unnormalized_derivatives, - inverse=False, - left=0.0, - right=1.0, - bottom=0.0, - top=1.0, - min_bin_width=DEFAULT_MIN_BIN_WIDTH, - min_bin_height=DEFAULT_MIN_BIN_HEIGHT, - min_derivative=DEFAULT_MIN_DERIVATIVE, -): - if torch.min(inputs) < left or torch.max(inputs) > right: - raise ValueError("Input to a transform is not within its domain") - - num_bins = unnormalized_widths.shape[-1] - - if min_bin_width * num_bins > 1.0: - raise ValueError("Minimal bin width too large for the number of bins") - if min_bin_height * num_bins > 1.0: - raise ValueError("Minimal bin height too large for the number of bins") - - widths = F.softmax(unnormalized_widths, dim=-1) - widths = min_bin_width + (1 - min_bin_width * num_bins) * 
widths - cumwidths = torch.cumsum(widths, dim=-1) - cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0) - cumwidths = (right - left) * cumwidths + left - cumwidths[..., 0] = left - cumwidths[..., -1] = right - widths = cumwidths[..., 1:] - cumwidths[..., :-1] - - derivatives = min_derivative + F.softplus(unnormalized_derivatives) - - heights = F.softmax(unnormalized_heights, dim=-1) - heights = min_bin_height + (1 - min_bin_height * num_bins) * heights - cumheights = torch.cumsum(heights, dim=-1) - cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0) - cumheights = (top - bottom) * cumheights + bottom - cumheights[..., 0] = bottom - cumheights[..., -1] = top - heights = cumheights[..., 1:] - cumheights[..., :-1] - - if inverse: - bin_idx = _searchsorted(cumheights, inputs)[..., None] - else: - bin_idx = _searchsorted(cumwidths, inputs)[..., None] - - input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] - input_bin_widths = widths.gather(-1, bin_idx)[..., 0] - - input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] - delta = heights / widths - input_delta = delta.gather(-1, bin_idx)[..., 0] - - input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] - input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] - - input_heights = heights.gather(-1, bin_idx)[..., 0] - - if inverse: - a = (inputs - input_cumheights) * ( - input_derivatives + input_derivatives_plus_one - 2 * input_delta - ) + input_heights * (input_delta - input_derivatives) - b = input_heights * input_derivatives - (inputs - input_cumheights) * ( - input_derivatives + input_derivatives_plus_one - 2 * input_delta - ) - c = -input_delta * (inputs - input_cumheights) - - discriminant = b.pow(2) - 4 * a * c - assert (discriminant >= 0).all() - - root = (2 * c) / (-b - torch.sqrt(discriminant)) - outputs = root * input_bin_widths + input_cumwidths - - theta_one_minus_theta = root * (1 - root) - denominator = input_delta + ( - (input_derivatives + input_derivatives_plus_one - 2 * input_delta) - * theta_one_minus_theta - ) - derivative_numerator = input_delta.pow(2) * ( - input_derivatives_plus_one * root.pow(2) - + 2 * input_delta * theta_one_minus_theta - + input_derivatives * (1 - root).pow(2) - ) - logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) - - return outputs, -logabsdet - else: - theta = (inputs - input_cumwidths) / input_bin_widths - theta_one_minus_theta = theta * (1 - theta) - - numerator = input_heights * ( - input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta - ) - denominator = input_delta + ( - (input_derivatives + input_derivatives_plus_one - 2 * input_delta) - * theta_one_minus_theta - ) - outputs = input_cumheights + numerator / denominator - - derivative_numerator = input_delta.pow(2) * ( - input_derivatives_plus_one * theta.pow(2) - + 2 * input_delta * theta_one_minus_theta - + input_derivatives * (1 - theta).pow(2) - ) - logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) - - return outputs, logabsdet - - -def _searchsorted(bin_locations, inputs, eps=1e-6): - bin_locations[..., -1] += eps - return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1 diff --git a/egs/vctk/TTS/vits/utils.py b/egs/vctk/TTS/vits/utils.py deleted file mode 100644 index 12b2d6b81..000000000 --- a/egs/vctk/TTS/vits/utils.py +++ /dev/null @@ -1,265 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. 
(authors: Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Any, Dict, List, Optional, Tuple, Union -import collections -import logging - -import torch -import torch.nn as nn -import torch.distributed as dist -from lhotse.dataset.sampling.base import CutSampler -from pathlib import Path -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.optim import Optimizer -from torch.utils.tensorboard import SummaryWriter - - -# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/utils/get_random_segments.py -def get_random_segments( - x: torch.Tensor, - x_lengths: torch.Tensor, - segment_size: int, -) -> Tuple[torch.Tensor, torch.Tensor]: - """Get random segments. - - Args: - x (Tensor): Input tensor (B, C, T). - x_lengths (Tensor): Length tensor (B,). - segment_size (int): Segment size. - - Returns: - Tensor: Segmented tensor (B, C, segment_size). - Tensor: Start index tensor (B,). - - """ - b, c, t = x.size() - max_start_idx = x_lengths - segment_size - max_start_idx[max_start_idx < 0] = 0 - start_idxs = (torch.rand([b]).to(x.device) * max_start_idx).to( - dtype=torch.long, - ) - segments = get_segments(x, start_idxs, segment_size) - - return segments, start_idxs - - -# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/utils/get_random_segments.py -def get_segments( - x: torch.Tensor, - start_idxs: torch.Tensor, - segment_size: int, -) -> torch.Tensor: - """Get segments. - - Args: - x (Tensor): Input tensor (B, C, T). - start_idxs (Tensor): Start index tensor (B,). - segment_size (int): Segment size. - - Returns: - Tensor: Segmented tensor (B, C, segment_size). 
-
-    """
-    b, c, t = x.size()
-    segments = x.new_zeros(b, c, segment_size)
-    for i, start_idx in enumerate(start_idxs):
-        segments[i] = x[i, :, start_idx : start_idx + segment_size]
-    return segments
-
-
-# from https://github.com/jaywalnut310/vits/blob/main/commons.py
-def intersperse(sequence, item=0):
-    result = [item] * (len(sequence) * 2 + 1)
-    result[1::2] = sequence
-    return result
-
-
-# from https://github.com/jaywalnut310/vits/blob/main/utils.py
-MATPLOTLIB_FLAG = False
-
-
-def plot_feature(spectrogram):
-    global MATPLOTLIB_FLAG
-    if not MATPLOTLIB_FLAG:
-        import matplotlib
-
-        matplotlib.use("Agg")
-        MATPLOTLIB_FLAG = True
-        mpl_logger = logging.getLogger("matplotlib")
-        mpl_logger.setLevel(logging.WARNING)
-    import matplotlib.pylab as plt
-    import numpy as np
-
-    fig, ax = plt.subplots(figsize=(10, 2))
-    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
-    plt.colorbar(im, ax=ax)
-    plt.xlabel("Frames")
-    plt.ylabel("Channels")
-    plt.tight_layout()
-
-    fig.canvas.draw()
-    # np.fromstring is deprecated; read the rendered RGB buffer directly.
-    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
-    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
-    plt.close()
-    return data
-
-
-class MetricsTracker(collections.defaultdict):
-    def __init__(self):
-        # Passing the type 'int' to the base-class constructor
-        # makes undefined items default to int() which is zero.
-        # This class plays the role of a metrics tracker.
-        # It can record many metrics, including but not limited to loss.
-        super(MetricsTracker, self).__init__(int)
-
-    def __add__(self, other: "MetricsTracker") -> "MetricsTracker":
-        ans = MetricsTracker()
-        for k, v in self.items():
-            ans[k] = v
-        for k, v in other.items():
-            ans[k] = ans[k] + v
-        return ans
-
-    def __mul__(self, alpha: float) -> "MetricsTracker":
-        ans = MetricsTracker()
-        for k, v in self.items():
-            ans[k] = v * alpha
-        return ans
-
-    def __str__(self) -> str:
-        ans = ""
-        for k, v in self.norm_items():
-            norm_value = "%.4g" % v
-            ans += str(k) + "=" + str(norm_value) + ", "
-        samples = "%.2f" % self["samples"]
-        ans += "over " + str(samples) + " samples."
-        return ans
-
-    def norm_items(self) -> List[Tuple[str, float]]:
-        """
-        Returns a list of pairs, like:
-          [('loss_1', 0.1), ('loss_2', 0.07)]
-        """
-        samples = self["samples"] if "samples" in self else 1
-        ans = []
-        for k, v in self.items():
-            if k == "samples":
-                continue
-            norm_value = float(v) / samples
-            ans.append((k, norm_value))
-        return ans
-
-    def reduce(self, device):
-        """
-        Reduce using torch.distributed, so that all processes get the total.
-        """
-        keys = sorted(self.keys())
-        s = torch.tensor([float(self[k]) for k in keys], device=device)
-        dist.all_reduce(s, op=dist.ReduceOp.SUM)
-        for k, v in zip(keys, s.cpu().tolist()):
-            self[k] = v
-
-    def write_summary(
-        self,
-        tb_writer: SummaryWriter,
-        prefix: str,
-        batch_idx: int,
-    ) -> None:
-        """Add logging information to a TensorBoard writer.
-
-        Args:
-            tb_writer: a TensorBoard writer
-            prefix: a prefix for the name of the loss, e.g. "train/valid_",
-                or "train/current_"
-            batch_idx: The current batch index, used as the x-axis of the plot.
- """ - for k, v in self.norm_items(): - tb_writer.add_scalar(prefix + k, v, batch_idx) - - -# checkpoint saving and loading -LRSchedulerType = torch.optim.lr_scheduler._LRScheduler - - -def save_checkpoint( - filename: Path, - model: Union[nn.Module, DDP], - params: Optional[Dict[str, Any]] = None, - optimizer_g: Optional[Optimizer] = None, - optimizer_d: Optional[Optimizer] = None, - scheduler_g: Optional[LRSchedulerType] = None, - scheduler_d: Optional[LRSchedulerType] = None, - scaler: Optional[GradScaler] = None, - sampler: Optional[CutSampler] = None, - rank: int = 0, -) -> None: - """Save training information to a file. - - Args: - filename: - The checkpoint filename. - model: - The model to be saved. We only save its `state_dict()`. - model_avg: - The stored model averaged from the start of training. - params: - User defined parameters, e.g., epoch, loss. - optimizer_g: - The optimizer for generator used in the training. - Its `state_dict` will be saved. - optimizer_d: - The optimizer for discriminator used in the training. - Its `state_dict` will be saved. - scheduler_g: - The learning rate scheduler for generator used in the training. - Its `state_dict` will be saved. - scheduler_d: - The learning rate scheduler for discriminator used in the training. - Its `state_dict` will be saved. - scalar: - The GradScaler to be saved. We only save its `state_dict()`. - rank: - Used in DDP. We save checkpoint only for the node whose rank is 0. - Returns: - Return None. - """ - if rank != 0: - return - - logging.info(f"Saving checkpoint to {filename}") - - if isinstance(model, DDP): - model = model.module - - checkpoint = { - "model": model.state_dict(), - "optimizer_g": optimizer_g.state_dict() if optimizer_g is not None else None, - "optimizer_d": optimizer_d.state_dict() if optimizer_d is not None else None, - "scheduler_g": scheduler_g.state_dict() if scheduler_g is not None else None, - "scheduler_d": scheduler_d.state_dict() if scheduler_d is not None else None, - "grad_scaler": scaler.state_dict() if scaler is not None else None, - "sampler": sampler.state_dict() if sampler is not None else None, - } - - if params: - for k, v in params.items(): - assert k not in checkpoint - checkpoint[k] = v - - torch.save(checkpoint, filename) diff --git a/egs/vctk/TTS/vits/vits.py b/egs/vctk/TTS/vits/vits.py deleted file mode 100644 index 6db1cdee1..000000000 --- a/egs/vctk/TTS/vits/vits.py +++ /dev/null @@ -1,609 +0,0 @@ -# based on https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/vits.py - -# Copyright 2021 Tomoki Hayashi -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - -"""VITS module for GAN-TTS task.""" - -from typing import Any, Dict, Optional, Tuple - -import torch -import torch.nn as nn -from torch.cuda.amp import autocast - -from hifigan import ( - HiFiGANMultiPeriodDiscriminator, - HiFiGANMultiScaleDiscriminator, - HiFiGANMultiScaleMultiPeriodDiscriminator, - HiFiGANPeriodDiscriminator, - HiFiGANScaleDiscriminator, -) -from loss import ( - DiscriminatorAdversarialLoss, - FeatureMatchLoss, - GeneratorAdversarialLoss, - KLDivergenceLoss, - MelSpectrogramLoss, -) -from utils import get_segments -from generator import VITSGenerator - - -AVAILABLE_GENERATERS = { - "vits_generator": VITSGenerator, -} -AVAILABLE_DISCRIMINATORS = { - "hifigan_period_discriminator": HiFiGANPeriodDiscriminator, - "hifigan_scale_discriminator": HiFiGANScaleDiscriminator, - "hifigan_multi_period_discriminator": HiFiGANMultiPeriodDiscriminator, - "hifigan_multi_scale_discriminator": 
HiFiGANMultiScaleDiscriminator, - "hifigan_multi_scale_multi_period_discriminator": HiFiGANMultiScaleMultiPeriodDiscriminator, # NOQA -} - - -class VITS(nn.Module): - """Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`""" - - def __init__( - self, - # generator related - vocab_size: int, - feature_dim: int = 513, - sampling_rate: int = 22050, - generator_type: str = "vits_generator", - generator_params: Dict[str, Any] = { - "hidden_channels": 192, - "spks": None, - "langs": None, - "spk_embed_dim": None, - "global_channels": -1, - "segment_size": 32, - "text_encoder_attention_heads": 2, - "text_encoder_ffn_expand": 4, - "text_encoder_cnn_module_kernel": 5, - "text_encoder_blocks": 6, - "text_encoder_dropout_rate": 0.1, - "decoder_kernel_size": 7, - "decoder_channels": 512, - "decoder_upsample_scales": [8, 8, 2, 2], - "decoder_upsample_kernel_sizes": [16, 16, 4, 4], - "decoder_resblock_kernel_sizes": [3, 7, 11], - "decoder_resblock_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], - "use_weight_norm_in_decoder": True, - "posterior_encoder_kernel_size": 5, - "posterior_encoder_layers": 16, - "posterior_encoder_stacks": 1, - "posterior_encoder_base_dilation": 1, - "posterior_encoder_dropout_rate": 0.0, - "use_weight_norm_in_posterior_encoder": True, - "flow_flows": 4, - "flow_kernel_size": 5, - "flow_base_dilation": 1, - "flow_layers": 4, - "flow_dropout_rate": 0.0, - "use_weight_norm_in_flow": True, - "use_only_mean_in_flow": True, - "stochastic_duration_predictor_kernel_size": 3, - "stochastic_duration_predictor_dropout_rate": 0.5, - "stochastic_duration_predictor_flows": 4, - "stochastic_duration_predictor_dds_conv_layers": 3, - }, - # discriminator related - discriminator_type: str = "hifigan_multi_scale_multi_period_discriminator", - discriminator_params: Dict[str, Any] = { - "scales": 1, - "scale_downsample_pooling": "AvgPool1d", - "scale_downsample_pooling_params": { - "kernel_size": 4, - "stride": 2, - "padding": 2, - }, - "scale_discriminator_params": { - "in_channels": 1, - "out_channels": 1, - "kernel_sizes": [15, 41, 5, 3], - "channels": 128, - "max_downsample_channels": 1024, - "max_groups": 16, - "bias": True, - "downsample_scales": [2, 2, 4, 4, 1], - "nonlinear_activation": "LeakyReLU", - "nonlinear_activation_params": {"negative_slope": 0.1}, - "use_weight_norm": True, - "use_spectral_norm": False, - }, - "follow_official_norm": False, - "periods": [2, 3, 5, 7, 11], - "period_discriminator_params": { - "in_channels": 1, - "out_channels": 1, - "kernel_sizes": [5, 3], - "channels": 32, - "downsample_scales": [3, 3, 3, 3, 1], - "max_downsample_channels": 1024, - "bias": True, - "nonlinear_activation": "LeakyReLU", - "nonlinear_activation_params": {"negative_slope": 0.1}, - "use_weight_norm": True, - "use_spectral_norm": False, - }, - }, - # loss related - generator_adv_loss_params: Dict[str, Any] = { - "average_by_discriminators": False, - "loss_type": "mse", - }, - discriminator_adv_loss_params: Dict[str, Any] = { - "average_by_discriminators": False, - "loss_type": "mse", - }, - feat_match_loss_params: Dict[str, Any] = { - "average_by_discriminators": False, - "average_by_layers": False, - "include_final_outputs": True, - }, - mel_loss_params: Dict[str, Any] = { - "frame_shift": 256, - "frame_length": 1024, - "n_mels": 80, - }, - lambda_adv: float = 1.0, - lambda_mel: float = 45.0, - lambda_feat_match: float = 2.0, - lambda_dur: float = 1.0, - lambda_kl: float = 1.0, - cache_generator_outputs: bool = True, - ): - 
"""Initialize VITS module. - - Args: - idim (int): Input vocabrary size. - odim (int): Acoustic feature dimension. The actual output channels will - be 1 since VITS is the end-to-end text-to-wave model but for the - compatibility odim is used to indicate the acoustic feature dimension. - sampling_rate (int): Sampling rate, not used for the training but it will - be referred in saving waveform during the inference. - generator_type (str): Generator type. - generator_params (Dict[str, Any]): Parameter dict for generator. - discriminator_type (str): Discriminator type. - discriminator_params (Dict[str, Any]): Parameter dict for discriminator. - generator_adv_loss_params (Dict[str, Any]): Parameter dict for generator - adversarial loss. - discriminator_adv_loss_params (Dict[str, Any]): Parameter dict for - discriminator adversarial loss. - feat_match_loss_params (Dict[str, Any]): Parameter dict for feat match loss. - mel_loss_params (Dict[str, Any]): Parameter dict for mel loss. - lambda_adv (float): Loss scaling coefficient for adversarial loss. - lambda_mel (float): Loss scaling coefficient for mel spectrogram loss. - lambda_feat_match (float): Loss scaling coefficient for feat match loss. - lambda_dur (float): Loss scaling coefficient for duration loss. - lambda_kl (float): Loss scaling coefficient for KL divergence loss. - cache_generator_outputs (bool): Whether to cache generator outputs. - - """ - super().__init__() - - # define modules - generator_class = AVAILABLE_GENERATERS[generator_type] - if generator_type == "vits_generator": - # NOTE(kan-bayashi): Update parameters for the compatibility. - # The idim and odim is automatically decided from input data, - # where idim represents #vocabularies and odim represents - # the input acoustic feature dimension. - generator_params.update(vocabs=vocab_size, aux_channels=feature_dim) - self.generator = generator_class( - **generator_params, - ) - discriminator_class = AVAILABLE_DISCRIMINATORS[discriminator_type] - self.discriminator = discriminator_class( - **discriminator_params, - ) - self.generator_adv_loss = GeneratorAdversarialLoss( - **generator_adv_loss_params, - ) - self.discriminator_adv_loss = DiscriminatorAdversarialLoss( - **discriminator_adv_loss_params, - ) - self.feat_match_loss = FeatureMatchLoss( - **feat_match_loss_params, - ) - mel_loss_params.update(sampling_rate=sampling_rate) - self.mel_loss = MelSpectrogramLoss( - **mel_loss_params, - ) - self.kl_loss = KLDivergenceLoss() - - # coefficients - self.lambda_adv = lambda_adv - self.lambda_mel = lambda_mel - self.lambda_kl = lambda_kl - self.lambda_feat_match = lambda_feat_match - self.lambda_dur = lambda_dur - - # cache - self.cache_generator_outputs = cache_generator_outputs - self._cache = None - - # store sampling rate for saving wav file - # (not used for the training) - self.sampling_rate = sampling_rate - - # store parameters for test compatibility - self.spks = self.generator.spks - self.langs = self.generator.langs - self.spk_embed_dim = self.generator.spk_embed_dim - - def forward( - self, - text: torch.Tensor, - text_lengths: torch.Tensor, - feats: torch.Tensor, - feats_lengths: torch.Tensor, - speech: torch.Tensor, - speech_lengths: torch.Tensor, - return_sample: bool = False, - sids: Optional[torch.Tensor] = None, - spembs: Optional[torch.Tensor] = None, - lids: Optional[torch.Tensor] = None, - forward_generator: bool = True, - ) -> Tuple[torch.Tensor, Dict[str, Any]]: - """Perform generator forward. - - Args: - text (Tensor): Text index tensor (B, T_text). 
- text_lengths (Tensor): Text length tensor (B,). - feats (Tensor): Feature tensor (B, T_feats, aux_channels). - feats_lengths (Tensor): Feature length tensor (B,). - speech (Tensor): Speech waveform tensor (B, T_wav). - speech_lengths (Tensor): Speech length tensor (B,). - sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). - spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). - lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). - forward_generator (bool): Whether to forward generator. - - Returns: - - loss (Tensor): Loss scalar tensor. - - stats (Dict[str, float]): Statistics to be monitored. - """ - if forward_generator: - return self._forward_generator( - text=text, - text_lengths=text_lengths, - feats=feats, - feats_lengths=feats_lengths, - speech=speech, - speech_lengths=speech_lengths, - return_sample=return_sample, - sids=sids, - spembs=spembs, - lids=lids, - ) - else: - return self._forward_discrminator( - text=text, - text_lengths=text_lengths, - feats=feats, - feats_lengths=feats_lengths, - speech=speech, - speech_lengths=speech_lengths, - sids=sids, - spembs=spembs, - lids=lids, - ) - - def _forward_generator( - self, - text: torch.Tensor, - text_lengths: torch.Tensor, - feats: torch.Tensor, - feats_lengths: torch.Tensor, - speech: torch.Tensor, - speech_lengths: torch.Tensor, - return_sample: bool = False, - sids: Optional[torch.Tensor] = None, - spembs: Optional[torch.Tensor] = None, - lids: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Dict[str, Any]]: - """Perform generator forward. - - Args: - text (Tensor): Text index tensor (B, T_text). - text_lengths (Tensor): Text length tensor (B,). - feats (Tensor): Feature tensor (B, T_feats, aux_channels). - feats_lengths (Tensor): Feature length tensor (B,). - speech (Tensor): Speech waveform tensor (B, T_wav). - speech_lengths (Tensor): Speech length tensor (B,). - sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). - spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). - lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). - - Returns: - * loss (Tensor): Loss scalar tensor. - * stats (Dict[str, float]): Statistics to be monitored. 
- """ - # setup - feats = feats.transpose(1, 2) - speech = speech.unsqueeze(1) - - # calculate generator outputs - reuse_cache = True - if not self.cache_generator_outputs or self._cache is None: - reuse_cache = False - outs = self.generator( - text=text, - text_lengths=text_lengths, - feats=feats, - feats_lengths=feats_lengths, - sids=sids, - spembs=spembs, - lids=lids, - ) - else: - outs = self._cache - - # store cache - if self.training and self.cache_generator_outputs and not reuse_cache: - self._cache = outs - - # parse outputs - speech_hat_, dur_nll, _, start_idxs, _, z_mask, outs_ = outs - _, z_p, m_p, logs_p, _, logs_q = outs_ - speech_ = get_segments( - x=speech, - start_idxs=start_idxs * self.generator.upsample_factor, - segment_size=self.generator.segment_size * self.generator.upsample_factor, - ) - - # calculate discriminator outputs - p_hat = self.discriminator(speech_hat_) - with torch.no_grad(): - # do not store discriminator gradient in generator turn - p = self.discriminator(speech_) - - # calculate losses - with autocast(enabled=False): - if not return_sample: - mel_loss = self.mel_loss(speech_hat_, speech_) - else: - mel_loss, (mel_hat_, mel_) = self.mel_loss( - speech_hat_, speech_, return_mel=True - ) - kl_loss = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask) - dur_loss = torch.sum(dur_nll.float()) - adv_loss = self.generator_adv_loss(p_hat) - feat_match_loss = self.feat_match_loss(p_hat, p) - - mel_loss = mel_loss * self.lambda_mel - kl_loss = kl_loss * self.lambda_kl - dur_loss = dur_loss * self.lambda_dur - adv_loss = adv_loss * self.lambda_adv - feat_match_loss = feat_match_loss * self.lambda_feat_match - loss = mel_loss + kl_loss + dur_loss + adv_loss + feat_match_loss - - stats = dict( - generator_loss=loss.item(), - generator_mel_loss=mel_loss.item(), - generator_kl_loss=kl_loss.item(), - generator_dur_loss=dur_loss.item(), - generator_adv_loss=adv_loss.item(), - generator_feat_match_loss=feat_match_loss.item(), - ) - - if return_sample: - stats["returned_sample"] = ( - speech_hat_[0].data.cpu().numpy(), - speech_[0].data.cpu().numpy(), - mel_hat_[0].data.cpu().numpy(), - mel_[0].data.cpu().numpy(), - ) - - # reset cache - if reuse_cache or not self.training: - self._cache = None - - return loss, stats - - def _forward_discrminator( - self, - text: torch.Tensor, - text_lengths: torch.Tensor, - feats: torch.Tensor, - feats_lengths: torch.Tensor, - speech: torch.Tensor, - speech_lengths: torch.Tensor, - sids: Optional[torch.Tensor] = None, - spembs: Optional[torch.Tensor] = None, - lids: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Dict[str, Any]]: - """Perform discriminator forward. - - Args: - text (Tensor): Text index tensor (B, T_text). - text_lengths (Tensor): Text length tensor (B,). - feats (Tensor): Feature tensor (B, T_feats, aux_channels). - feats_lengths (Tensor): Feature length tensor (B,). - speech (Tensor): Speech waveform tensor (B, T_wav). - speech_lengths (Tensor): Speech length tensor (B,). - sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). - spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). - lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). - - Returns: - * loss (Tensor): Loss scalar tensor. - * stats (Dict[str, float]): Statistics to be monitored. 
- """ - # setup - feats = feats.transpose(1, 2) - speech = speech.unsqueeze(1) - - # calculate generator outputs - reuse_cache = True - if not self.cache_generator_outputs or self._cache is None: - reuse_cache = False - outs = self.generator( - text=text, - text_lengths=text_lengths, - feats=feats, - feats_lengths=feats_lengths, - sids=sids, - spembs=spembs, - lids=lids, - ) - else: - outs = self._cache - - # store cache - if self.cache_generator_outputs and not reuse_cache: - self._cache = outs - - # parse outputs - speech_hat_, _, _, start_idxs, *_ = outs - speech_ = get_segments( - x=speech, - start_idxs=start_idxs * self.generator.upsample_factor, - segment_size=self.generator.segment_size * self.generator.upsample_factor, - ) - - # calculate discriminator outputs - p_hat = self.discriminator(speech_hat_.detach()) - p = self.discriminator(speech_) - - # calculate losses - with autocast(enabled=False): - real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p) - loss = real_loss + fake_loss - - stats = dict( - discriminator_loss=loss.item(), - discriminator_real_loss=real_loss.item(), - discriminator_fake_loss=fake_loss.item(), - ) - - # reset cache - if reuse_cache or not self.training: - self._cache = None - - return loss, stats - - def inference( - self, - text: torch.Tensor, - feats: Optional[torch.Tensor] = None, - sids: Optional[torch.Tensor] = None, - spembs: Optional[torch.Tensor] = None, - lids: Optional[torch.Tensor] = None, - durations: Optional[torch.Tensor] = None, - noise_scale: float = 0.667, - noise_scale_dur: float = 0.8, - alpha: float = 1.0, - max_len: Optional[int] = None, - use_teacher_forcing: bool = False, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Run inference for single sample. - - Args: - text (Tensor): Input text index tensor (T_text,). - feats (Tensor): Feature tensor (T_feats, aux_channels). - sids (Tensor): Speaker index tensor (1,). - spembs (Optional[Tensor]): Speaker embedding tensor (spk_embed_dim,). - lids (Tensor): Language index tensor (1,). - durations (Tensor): Ground-truth duration tensor (T_text,). - noise_scale (float): Noise scale value for flow. - noise_scale_dur (float): Noise scale value for duration predictor. - alpha (float): Alpha parameter to control the speed of generated speech. - max_len (Optional[int]): Maximum length. - use_teacher_forcing (bool): Whether to use teacher forcing. - - Returns: - * wav (Tensor): Generated waveform tensor (T_wav,). - * att_w (Tensor): Monotonic attention weight tensor (T_feats, T_text). - * duration (Tensor): Predicted duration tensor (T_text,). 
- """ - # setup - text = text[None] - text_lengths = torch.tensor( - [text.size(1)], - dtype=torch.long, - device=text.device, - ) - if sids is not None: - sids = sids.view(1) - if lids is not None: - lids = lids.view(1) - if durations is not None: - durations = durations.view(1, 1, -1) - - # inference - if use_teacher_forcing: - assert feats is not None - feats = feats[None].transpose(1, 2) - feats_lengths = torch.tensor( - [feats.size(2)], - dtype=torch.long, - device=feats.device, - ) - wav, att_w, dur = self.generator.inference( - text=text, - text_lengths=text_lengths, - feats=feats, - feats_lengths=feats_lengths, - sids=sids, - spembs=spembs, - lids=lids, - max_len=max_len, - use_teacher_forcing=use_teacher_forcing, - ) - else: - wav, att_w, dur = self.generator.inference( - text=text, - text_lengths=text_lengths, - sids=sids, - spembs=spembs, - lids=lids, - dur=durations, - noise_scale=noise_scale, - noise_scale_dur=noise_scale_dur, - alpha=alpha, - max_len=max_len, - ) - return wav.view(-1), att_w[0], dur[0] - - def inference_batch( - self, - text: torch.Tensor, - text_lengths: torch.Tensor, - sids: Optional[torch.Tensor] = None, - durations: Optional[torch.Tensor] = None, - noise_scale: float = 0.667, - noise_scale_dur: float = 0.8, - alpha: float = 1.0, - max_len: Optional[int] = None, - use_teacher_forcing: bool = False, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Run inference for one batch. - - Args: - text (Tensor): Input text index tensor (B, T_text). - text_lengths (Tensor): Input text index tensor (B,). - sids (Tensor): Speaker index tensor (B,). - noise_scale (float): Noise scale value for flow. - noise_scale_dur (float): Noise scale value for duration predictor. - alpha (float): Alpha parameter to control the speed of generated speech. - max_len (Optional[int]): Maximum length. - - Returns: - * wav (Tensor): Generated waveform tensor (B, T_wav). - * att_w (Tensor): Monotonic attention weight tensor (B, T_feats, T_text). - * duration (Tensor): Predicted duration tensor (B, T_text). - """ - # inference - wav, att_w, dur = self.generator.inference( - text=text, - text_lengths=text_lengths, - sids=sids, - noise_scale=noise_scale, - noise_scale_dur=noise_scale_dur, - alpha=alpha, - max_len=max_len, - ) - return wav, att_w, dur diff --git a/egs/vctk/TTS/vits/wavenet.py b/egs/vctk/TTS/vits/wavenet.py deleted file mode 100644 index fbe1be52b..000000000 --- a/egs/vctk/TTS/vits/wavenet.py +++ /dev/null @@ -1,349 +0,0 @@ -# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/wavenet/wavenet.py - -# Copyright 2021 Tomoki Hayashi -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - -"""WaveNet modules. - -This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN. - -""" - -import math -import logging - -from typing import Optional, Tuple - -import torch -import torch.nn.functional as F - - -class WaveNet(torch.nn.Module): - """WaveNet with global conditioning.""" - - def __init__( - self, - in_channels: int = 1, - out_channels: int = 1, - kernel_size: int = 3, - layers: int = 30, - stacks: int = 3, - base_dilation: int = 2, - residual_channels: int = 64, - aux_channels: int = -1, - gate_channels: int = 128, - skip_channels: int = 64, - global_channels: int = -1, - dropout_rate: float = 0.0, - bias: bool = True, - use_weight_norm: bool = True, - use_first_conv: bool = False, - use_last_conv: bool = False, - scale_residual: bool = False, - scale_skip_connect: bool = False, - ): - """Initialize WaveNet module. 
- - Args: - in_channels (int): Number of input channels. - out_channels (int): Number of output channels. - kernel_size (int): Kernel size of dilated convolution. - layers (int): Number of residual block layers. - stacks (int): Number of stacks i.e., dilation cycles. - base_dilation (int): Base dilation factor. - residual_channels (int): Number of channels in residual conv. - gate_channels (int): Number of channels in gated conv. - skip_channels (int): Number of channels in skip conv. - aux_channels (int): Number of channels for local conditioning feature. - global_channels (int): Number of channels for global conditioning feature. - dropout_rate (float): Dropout rate. 0.0 means no dropout applied. - bias (bool): Whether to use bias parameter in conv layer. - use_weight_norm (bool): Whether to use weight norm. If set to true, it will - be applied to all of the conv layers. - use_first_conv (bool): Whether to use the first conv layers. - use_last_conv (bool): Whether to use the last conv layers. - scale_residual (bool): Whether to scale the residual outputs. - scale_skip_connect (bool): Whether to scale the skip connection outputs. - - """ - super().__init__() - self.layers = layers - self.stacks = stacks - self.kernel_size = kernel_size - self.base_dilation = base_dilation - self.use_first_conv = use_first_conv - self.use_last_conv = use_last_conv - self.scale_skip_connect = scale_skip_connect - - # check the number of layers and stacks - assert layers % stacks == 0 - layers_per_stack = layers // stacks - - # define first convolution - if self.use_first_conv: - self.first_conv = Conv1d1x1(in_channels, residual_channels, bias=True) - - # define residual blocks - self.conv_layers = torch.nn.ModuleList() - for layer in range(layers): - dilation = base_dilation ** (layer % layers_per_stack) - conv = ResidualBlock( - kernel_size=kernel_size, - residual_channels=residual_channels, - gate_channels=gate_channels, - skip_channels=skip_channels, - aux_channels=aux_channels, - global_channels=global_channels, - dilation=dilation, - dropout_rate=dropout_rate, - bias=bias, - scale_residual=scale_residual, - ) - self.conv_layers += [conv] - - # define output layers - if self.use_last_conv: - self.last_conv = torch.nn.Sequential( - torch.nn.ReLU(inplace=True), - Conv1d1x1(skip_channels, skip_channels, bias=True), - torch.nn.ReLU(inplace=True), - Conv1d1x1(skip_channels, out_channels, bias=True), - ) - - # apply weight norm - if use_weight_norm: - self.apply_weight_norm() - - def forward( - self, - x: torch.Tensor, - x_mask: Optional[torch.Tensor] = None, - c: Optional[torch.Tensor] = None, - g: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Calculate forward propagation. - - Args: - x (Tensor): Input noise signal (B, 1, T) if use_first_conv else - (B, residual_channels, T). - x_mask (Optional[Tensor]): Mask tensor (B, 1, T). - c (Optional[Tensor]): Local conditioning features (B, aux_channels, T). - g (Optional[Tensor]): Global conditioning features (B, global_channels, 1). - - Returns: - Tensor: Output tensor (B, out_channels, T) if use_last_conv else - (B, residual_channels, T). 
- - """ - # encode to hidden representation - if self.use_first_conv: - x = self.first_conv(x) - - # residual block - skips = 0.0 - for f in self.conv_layers: - x, h = f(x, x_mask=x_mask, c=c, g=g) - skips = skips + h - x = skips - if self.scale_skip_connect: - x = x * math.sqrt(1.0 / len(self.conv_layers)) - - # apply final layers - if self.use_last_conv: - x = self.last_conv(x) - - return x - - def remove_weight_norm(self): - """Remove weight normalization module from all of the layers.""" - - def _remove_weight_norm(m: torch.nn.Module): - try: - logging.debug(f"Weight norm is removed from {m}.") - torch.nn.utils.remove_weight_norm(m) - except ValueError: # this module didn't have weight norm - return - - self.apply(_remove_weight_norm) - - def apply_weight_norm(self): - """Apply weight normalization module from all of the layers.""" - - def _apply_weight_norm(m: torch.nn.Module): - if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d): - torch.nn.utils.weight_norm(m) - logging.debug(f"Weight norm is applied to {m}.") - - self.apply(_apply_weight_norm) - - @staticmethod - def _get_receptive_field_size( - layers: int, - stacks: int, - kernel_size: int, - base_dilation: int, - ) -> int: - assert layers % stacks == 0 - layers_per_cycle = layers // stacks - dilations = [base_dilation ** (i % layers_per_cycle) for i in range(layers)] - return (kernel_size - 1) * sum(dilations) + 1 - - @property - def receptive_field_size(self) -> int: - """Return receptive field size.""" - return self._get_receptive_field_size( - self.layers, self.stacks, self.kernel_size, self.base_dilation - ) - - -class Conv1d(torch.nn.Conv1d): - """Conv1d module with customized initialization.""" - - def __init__(self, *args, **kwargs): - """Initialize Conv1d module.""" - super().__init__(*args, **kwargs) - - def reset_parameters(self): - """Reset parameters.""" - torch.nn.init.kaiming_normal_(self.weight, nonlinearity="relu") - if self.bias is not None: - torch.nn.init.constant_(self.bias, 0.0) - - -class Conv1d1x1(Conv1d): - """1x1 Conv1d with customized initialization.""" - - def __init__(self, in_channels: int, out_channels: int, bias: bool): - """Initialize 1x1 Conv1d module.""" - super().__init__( - in_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=bias - ) - - -class ResidualBlock(torch.nn.Module): - """Residual block module in WaveNet.""" - - def __init__( - self, - kernel_size: int = 3, - residual_channels: int = 64, - gate_channels: int = 128, - skip_channels: int = 64, - aux_channels: int = 80, - global_channels: int = -1, - dropout_rate: float = 0.0, - dilation: int = 1, - bias: bool = True, - scale_residual: bool = False, - ): - """Initialize ResidualBlock module. - - Args: - kernel_size (int): Kernel size of dilation convolution layer. - residual_channels (int): Number of channels for residual connection. - skip_channels (int): Number of channels for skip connection. - aux_channels (int): Number of local conditioning channels. - dropout (float): Dropout probability. - dilation (int): Dilation factor. - bias (bool): Whether to add bias parameter in convolution layers. - scale_residual (bool): Whether to scale the residual outputs. - - """ - super().__init__() - self.dropout_rate = dropout_rate - self.residual_channels = residual_channels - self.skip_channels = skip_channels - self.scale_residual = scale_residual - - # check - assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size." 
- assert gate_channels % 2 == 0 - - # dilation conv - padding = (kernel_size - 1) // 2 * dilation - self.conv = Conv1d( - residual_channels, - gate_channels, - kernel_size, - padding=padding, - dilation=dilation, - bias=bias, - ) - - # local conditioning - if aux_channels > 0: - self.conv1x1_aux = Conv1d1x1(aux_channels, gate_channels, bias=False) - else: - self.conv1x1_aux = None - - # global conditioning - if global_channels > 0: - self.conv1x1_glo = Conv1d1x1(global_channels, gate_channels, bias=False) - else: - self.conv1x1_glo = None - - # conv output is split into two groups - gate_out_channels = gate_channels // 2 - - # NOTE(kan-bayashi): concat two convs into a single conv for the efficiency - # (integrate res 1x1 + skip 1x1 convs) - self.conv1x1_out = Conv1d1x1( - gate_out_channels, residual_channels + skip_channels, bias=bias - ) - - def forward( - self, - x: torch.Tensor, - x_mask: Optional[torch.Tensor] = None, - c: Optional[torch.Tensor] = None, - g: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Calculate forward propagation. - - Args: - x (Tensor): Input tensor (B, residual_channels, T). - x_mask Optional[torch.Tensor]: Mask tensor (B, 1, T). - c (Optional[Tensor]): Local conditioning tensor (B, aux_channels, T). - g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1). - - Returns: - Tensor: Output tensor for residual connection (B, residual_channels, T). - Tensor: Output tensor for skip connection (B, skip_channels, T). - - """ - residual = x - x = F.dropout(x, p=self.dropout_rate, training=self.training) - x = self.conv(x) - - # split into two part for gated activation - splitdim = 1 - xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim) - - # local conditioning - if c is not None: - c = self.conv1x1_aux(c) - ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim) - xa, xb = xa + ca, xb + cb - - # global conditioning - if g is not None: - g = self.conv1x1_glo(g) - ga, gb = g.split(g.size(splitdim) // 2, dim=splitdim) - xa, xb = xa + ga, xb + gb - - x = torch.tanh(xa) * torch.sigmoid(xb) - - # residual + skip 1x1 conv - x = self.conv1x1_out(x) - if x_mask is not None: - x = x * x_mask - - # split integrated conv results - x, s = x.split([self.residual_channels, self.skip_channels], dim=1) - - # for residual connection - x = x + residual - if self.scale_residual: - x = x * math.sqrt(0.5) - - return x, s
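# Editorial smoke-test sketch (not part of the patch): exercising the WaveNet
# above end to end. Assumes the deleted module is importable as `wavenet`;
# all sizes are illustrative assumptions.
import torch

from wavenet import WaveNet

net = WaveNet(
    in_channels=1,
    out_channels=1,
    layers=6,
    stacks=2,
    residual_channels=16,
    gate_channels=32,
    skip_channels=16,
    aux_channels=80,    # local conditioning, e.g. mel-spectrogram frames
    global_channels=8,  # global conditioning, e.g. a speaker embedding
    use_first_conv=True,
    use_last_conv=True,
)
x = torch.randn(2, 1, 100)   # (B, 1, T) noise input
c = torch.randn(2, 80, 100)  # (B, aux_channels, T)
g = torch.randn(2, 8, 1)     # (B, global_channels, 1)
y = net(x, c=c, g=g)
assert y.shape == (2, 1, 100)  # (B, out_channels, T)
print("receptive field:", net.receptive_field_size)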