removed redundant files

This commit is contained in:
jinzr 2023-12-01 00:12:35 +08:00
parent b7efcbf154
commit ee718f1da1
7 changed files with 0 additions and 1604 deletions


@@ -1,81 +0,0 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/__init__.py
"""Maximum path calculation module.
This code is based on https://github.com/jaywalnut310/vits.
"""
import warnings
import numpy as np
import torch
from numba import njit, prange
try:
from .core import maximum_path_c
is_cython_available = True
except ImportError:
is_cython_available = False
warnings.warn(
"Cython version is not available. Fallback to 'EXPERIMETAL' numba version. "
"If you want to use the cython version, please build it as follows: "
"`cd espnet2/gan_tts/vits/monotonic_align; python setup.py build_ext --inplace`"
)
def maximum_path(neg_x_ent: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
"""Calculate maximum path.
Args:
neg_x_ent (Tensor): Negative X entropy tensor (B, T_feats, T_text).
attn_mask (Tensor): Attention mask (B, T_feats, T_text).
Returns:
Tensor: Maximum path tensor (B, T_feats, T_text).
"""
device, dtype = neg_x_ent.device, neg_x_ent.dtype
neg_x_ent = neg_x_ent.cpu().numpy().astype(np.float32)
path = np.zeros(neg_x_ent.shape, dtype=np.int32)
t_t_max = attn_mask.sum(1)[:, 0].cpu().numpy().astype(np.int32)
t_s_max = attn_mask.sum(2)[:, 0].cpu().numpy().astype(np.int32)
if is_cython_available:
maximum_path_c(path, neg_x_ent, t_t_max, t_s_max)
else:
maximum_path_numba(path, neg_x_ent, t_t_max, t_s_max)
return torch.from_numpy(path).to(device=device, dtype=dtype)
@njit
def maximum_path_each_numba(path, value, t_y, t_x, max_neg_val=-np.inf):
"""Calculate a single maximum path with numba."""
index = t_x - 1
for y in range(t_y):
for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
if x == y:
v_cur = max_neg_val
else:
v_cur = value[y - 1, x]
if x == 0:
if y == 0:
v_prev = 0.0
else:
v_prev = max_neg_val
else:
v_prev = value[y - 1, x - 1]
value[y, x] += max(v_prev, v_cur)
for y in range(t_y - 1, -1, -1):
path[y, index] = 1
if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]):
index = index - 1
@njit(parallel=True)
def maximum_path_numba(paths, values, t_ys, t_xs):
"""Calculate batch maximum path with numba."""
for i in prange(paths.shape[0]):
maximum_path_each_numba(paths[i], values[i], t_ys[i], t_xs[i])
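
As a rough usage sketch (shapes below are made up for illustration), maximum_path turns the soft negative cross-entropy grid into a hard monotonic alignment of the same shape:

neg_x_ent = torch.randn(2, 100, 25)   # (B, T_feats, T_text)
attn_mask = torch.ones(2, 100, 25)    # 1 where both the feature frame and the token are valid
path = maximum_path(neg_x_ent, attn_mask)
# path is (2, 100, 25); each feature frame is assigned exactly one text token,
# and the assigned token index never decreases over time.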


@@ -1,51 +0,0 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/core.pyx
"""Maximum path calculation module with cython optimization.
This code is copied from https://github.com/jaywalnut310/vits with the code format modified.
"""
cimport cython
from cython.parallel import prange
@cython.boundscheck(False)
@cython.wraparound(False)
cdef void maximum_path_each(int[:, ::1] path, float[:, ::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil:
cdef int x
cdef int y
cdef float v_prev
cdef float v_cur
cdef float tmp
cdef int index = t_x - 1
for y in range(t_y):
for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
if x == y:
v_cur = max_neg_val
else:
v_cur = value[y - 1, x]
if x == 0:
if y == 0:
v_prev = 0.0
else:
v_prev = max_neg_val
else:
v_prev = value[y - 1, x - 1]
value[y, x] += max(v_prev, v_cur)
for y in range(t_y - 1, -1, -1):
path[y, index] = 1
if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]):
index = index - 1
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef void maximum_path_c(int[:, :, ::1] paths, float[:, :, ::1] values, int[::1] t_ys, int[::1] t_xs) nogil:
cdef int b = paths.shape[0]
cdef int i
for i in prange(b, nogil=True):
maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i])


@@ -1,31 +0,0 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/setup.py
"""Setup cython code."""
from Cython.Build import cythonize
from setuptools import Extension, setup
from setuptools.command.build_ext import build_ext as _build_ext
class build_ext(_build_ext):
"""Overwrite build_ext."""
def finalize_options(self):
"""Prevent numpy from thinking it is still in its setup process."""
_build_ext.finalize_options(self)
__builtins__.__NUMPY_SETUP__ = False
import numpy
self.include_dirs.append(numpy.get_include())
exts = [
Extension(
name="core",
sources=["core.pyx"],
)
]
setup(
name="monotonic_align",
ext_modules=cythonize(exts, language_level=3),
cmdclass={"build_ext": build_ext},
)


@@ -1,218 +0,0 @@
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/transform.py
"""Flow-related transformation.
This code is derived from https://github.com/bayesiains/nflows.
"""
import numpy as np
import torch
from torch.nn import functional as F
DEFAULT_MIN_BIN_WIDTH = 1e-3
DEFAULT_MIN_BIN_HEIGHT = 1e-3
DEFAULT_MIN_DERIVATIVE = 1e-3
# TODO(kan-bayashi): Documentation and type hint
def piecewise_rational_quadratic_transform(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
tails=None,
tail_bound=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE,
):
if tails is None:
spline_fn = rational_quadratic_spline
spline_kwargs = {}
else:
spline_fn = unconstrained_rational_quadratic_spline
spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
outputs, logabsdet = spline_fn(
inputs=inputs,
unnormalized_widths=unnormalized_widths,
unnormalized_heights=unnormalized_heights,
unnormalized_derivatives=unnormalized_derivatives,
inverse=inverse,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative,
**spline_kwargs
)
return outputs, logabsdet
# TODO(kan-bayashi): Documentation and type hint
def unconstrained_rational_quadratic_spline(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
tails="linear",
tail_bound=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE,
):
inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
outside_interval_mask = ~inside_interval_mask
outputs = torch.zeros_like(inputs)
logabsdet = torch.zeros_like(inputs)
if tails == "linear":
unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
constant = np.log(np.exp(1 - min_derivative) - 1)
unnormalized_derivatives[..., 0] = constant
unnormalized_derivatives[..., -1] = constant
outputs[outside_interval_mask] = inputs[outside_interval_mask]
logabsdet[outside_interval_mask] = 0
else:
raise RuntimeError("{} tails are not implemented.".format(tails))
(
outputs[inside_interval_mask],
logabsdet[inside_interval_mask],
) = rational_quadratic_spline(
inputs=inputs[inside_interval_mask],
unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
inverse=inverse,
left=-tail_bound,
right=tail_bound,
bottom=-tail_bound,
top=tail_bound,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative,
)
return outputs, logabsdet
# TODO(kan-bayashi): Documentation and type hint
def rational_quadratic_spline(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
left=0.0,
right=1.0,
bottom=0.0,
top=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE,
):
if torch.min(inputs) < left or torch.max(inputs) > right:
raise ValueError("Input to a transform is not within its domain")
num_bins = unnormalized_widths.shape[-1]
if min_bin_width * num_bins > 1.0:
raise ValueError("Minimal bin width too large for the number of bins")
if min_bin_height * num_bins > 1.0:
raise ValueError("Minimal bin height too large for the number of bins")
widths = F.softmax(unnormalized_widths, dim=-1)
widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
cumwidths = torch.cumsum(widths, dim=-1)
cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
cumwidths = (right - left) * cumwidths + left
cumwidths[..., 0] = left
cumwidths[..., -1] = right
widths = cumwidths[..., 1:] - cumwidths[..., :-1]
derivatives = min_derivative + F.softplus(unnormalized_derivatives)
heights = F.softmax(unnormalized_heights, dim=-1)
heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
cumheights = torch.cumsum(heights, dim=-1)
cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
cumheights = (top - bottom) * cumheights + bottom
cumheights[..., 0] = bottom
cumheights[..., -1] = top
heights = cumheights[..., 1:] - cumheights[..., :-1]
if inverse:
bin_idx = _searchsorted(cumheights, inputs)[..., None]
else:
bin_idx = _searchsorted(cumwidths, inputs)[..., None]
input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
delta = heights / widths
input_delta = delta.gather(-1, bin_idx)[..., 0]
input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
input_heights = heights.gather(-1, bin_idx)[..., 0]
if inverse:
a = (inputs - input_cumheights) * (
input_derivatives + input_derivatives_plus_one - 2 * input_delta
) + input_heights * (input_delta - input_derivatives)
b = input_heights * input_derivatives - (inputs - input_cumheights) * (
input_derivatives + input_derivatives_plus_one - 2 * input_delta
)
c = -input_delta * (inputs - input_cumheights)
discriminant = b.pow(2) - 4 * a * c
assert (discriminant >= 0).all()
root = (2 * c) / (-b - torch.sqrt(discriminant))
outputs = root * input_bin_widths + input_cumwidths
theta_one_minus_theta = root * (1 - root)
denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta)
* theta_one_minus_theta
)
derivative_numerator = input_delta.pow(2) * (
input_derivatives_plus_one * root.pow(2)
+ 2 * input_delta * theta_one_minus_theta
+ input_derivatives * (1 - root).pow(2)
)
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
return outputs, -logabsdet
else:
theta = (inputs - input_cumwidths) / input_bin_widths
theta_one_minus_theta = theta * (1 - theta)
numerator = input_heights * (
input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
)
denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta)
* theta_one_minus_theta
)
outputs = input_cumheights + numerator / denominator
derivative_numerator = input_delta.pow(2) * (
input_derivatives_plus_one * theta.pow(2)
+ 2 * input_delta * theta_one_minus_theta
+ input_derivatives * (1 - theta).pow(2)
)
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
return outputs, logabsdet
def _searchsorted(bin_locations, inputs, eps=1e-6):
bin_locations[..., -1] += eps
return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
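
A minimal sketch of how the transform is typically called (shapes and values here are illustrative, not taken from the recipe). With tails="linear" the derivative parameters carry num_bins - 1 values, since the code pads one constant derivative on each side:

num_bins = 8
inputs = torch.rand(4) * 2 - 1                           # inside (-tail_bound, tail_bound)
unnormalized_widths = torch.randn(4, num_bins)
unnormalized_heights = torch.randn(4, num_bins)
unnormalized_derivatives = torch.randn(4, num_bins - 1)
outputs, logabsdet = piecewise_rational_quadratic_transform(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    tails="linear",
    tail_bound=1.0,
)
# Calling again with inverse=True and the same parameters maps `outputs` back to
# `inputs`, with `logabsdet` negated, as expected for an invertible flow layer.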


@@ -1,265 +0,0 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple, Union
import collections
import logging
import torch
import torch.nn as nn
import torch.distributed as dist
from lhotse.dataset.sampling.base import CutSampler
from pathlib import Path
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/utils/get_random_segments.py
def get_random_segments(
x: torch.Tensor,
x_lengths: torch.Tensor,
segment_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Get random segments.
Args:
x (Tensor): Input tensor (B, C, T).
x_lengths (Tensor): Length tensor (B,).
segment_size (int): Segment size.
Returns:
Tensor: Segmented tensor (B, C, segment_size).
Tensor: Start index tensor (B,).
"""
b, c, t = x.size()
max_start_idx = x_lengths - segment_size
max_start_idx[max_start_idx < 0] = 0
start_idxs = (torch.rand([b]).to(x.device) * max_start_idx).to(
dtype=torch.long,
)
segments = get_segments(x, start_idxs, segment_size)
return segments, start_idxs
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/utils/get_random_segments.py
def get_segments(
x: torch.Tensor,
start_idxs: torch.Tensor,
segment_size: int,
) -> torch.Tensor:
"""Get segments.
Args:
x (Tensor): Input tensor (B, C, T).
start_idxs (Tensor): Start index tensor (B,).
segment_size (int): Segment size.
Returns:
Tensor: Segmented tensor (B, C, segment_size).
"""
b, c, t = x.size()
segments = x.new_zeros(b, c, segment_size)
for i, start_idx in enumerate(start_idxs):
segments[i] = x[i, :, start_idx : start_idx + segment_size]
return segments
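
A small sketch of how these two helpers fit together (all shapes are illustrative): get_random_segments draws a random start index per utterance, while get_segments slices deterministically from given indices.

x = torch.randn(2, 80, 300)             # (B, C, T)
x_lengths = torch.tensor([300, 250])
segments, start_idxs = get_random_segments(x, x_lengths, segment_size=32)
# segments: (2, 80, 32); start_idxs: (2,), each below length - segment_size
same = get_segments(x, start_idxs, segment_size=32)  # identical slices for the same indices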
# from https://github.com/jaywalnut310/vits/blob/main/commons.py
def intersperse(sequence, item=0):
result = [item] * (len(sequence) * 2 + 1)
result[1::2] = sequence
return result
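
For reference (a worked example, not from the original file), intersperse inserts the given item before, between, and after the sequence elements:

assert intersperse([1, 2, 3], item=0) == [0, 1, 0, 2, 0, 3, 0]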
# from https://github.com/jaywalnut310/vits/blob/main/utils.py
MATPLOTLIB_FLAG = False
def plot_feature(spectrogram):
global MATPLOTLIB_FLAG
if not MATPLOTLIB_FLAG:
import matplotlib
matplotlib.use("Agg")
MATPLOTLIB_FLAG = True
mpl_logger = logging.getLogger("matplotlib")
mpl_logger.setLevel(logging.WARNING)
import matplotlib.pylab as plt
import numpy as np
fig, ax = plt.subplots(figsize=(10, 2))
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
plt.colorbar(im, ax=ax)
plt.xlabel("Frames")
plt.ylabel("Channels")
plt.tight_layout()
fig.canvas.draw()
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)  # np.fromstring is deprecated
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close()
return data
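
plot_feature renders a spectrogram-like 2-D array into an RGB image array, which is convenient for TensorBoard logging. A sketch with made-up data:

import numpy as np  # assumed available in the calling code
mel = np.random.rand(80, 200)           # (n_mels, n_frames), made-up values
image = plot_feature(mel)               # uint8 ndarray of shape (height, width, 3)
# e.g. tb_writer.add_image("train/mel", image, batch_idx, dataformats="HWC")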
class MetricsTracker(collections.defaultdict):
def __init__(self):
# Passing the type 'int' to the base-class constructor
# makes undefined items default to int() which is zero.
# This class will play a role as metrics tracker.
# It can record many metrics, including but not limited to loss.
super(MetricsTracker, self).__init__(int)
def __add__(self, other: "MetricsTracker") -> "MetricsTracker":
ans = MetricsTracker()
for k, v in self.items():
ans[k] = v
for k, v in other.items():
ans[k] = ans[k] + v
return ans
def __mul__(self, alpha: float) -> "MetricsTracker":
ans = MetricsTracker()
for k, v in self.items():
ans[k] = v * alpha
return ans
def __str__(self) -> str:
ans = ""
for k, v in self.norm_items():
norm_value = "%.4g" % v
ans += str(k) + "=" + str(norm_value) + ", "
samples = "%.2f" % self["samples"]
ans += "over " + str(samples) + " samples."
return ans
def norm_items(self) -> List[Tuple[str, float]]:
"""
Returns a list of pairs, like:
[('loss_1', 0.1), ('loss_2', 0.07)]
"""
samples = self["samples"] if "samples" in self else 1
ans = []
for k, v in self.items():
if k == "samples":
continue
norm_value = float(v) / samples
ans.append((k, norm_value))
return ans
def reduce(self, device):
"""
Reduce using torch.distributed (all-reduce with SUM), so that
all processes end up with the same totals.
"""
keys = sorted(self.keys())
s = torch.tensor([float(self[k]) for k in keys], device=device)
dist.all_reduce(s, op=dist.ReduceOp.SUM)
for k, v in zip(keys, s.cpu().tolist()):
self[k] = v
def write_summary(
self,
tb_writer: SummaryWriter,
prefix: str,
batch_idx: int,
) -> None:
"""Add logging information to a TensorBoard writer.
Args:
tb_writer: a TensorBoard writer
prefix: a prefix for the name of the loss, e.g. "train/valid_",
or "train/current_"
batch_idx: The current batch index, used as the x-axis of the plot.
"""
for k, v in self.norm_items():
tb_writer.add_scalar(prefix + k, v, batch_idx)
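
As a quick illustration of the tracker's conventions (all numbers are made up): values are stored as sums over samples, and norm_items() / __str__ divide by the "samples" entry.

info = MetricsTracker()
info["samples"] = 16
info["generator_loss"] = 3.2 * 16        # store sums, not averages
info["generator_mel_loss"] = 1.1 * 16
print(info)  # roughly: generator_loss=3.2, generator_mel_loss=1.1, over 16.00 samples.
tot = MetricsTracker()
tot = tot + info                         # __add__ merges per-key sums across batches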
# checkpoint saving and loading
LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
def save_checkpoint(
filename: Path,
model: Union[nn.Module, DDP],
params: Optional[Dict[str, Any]] = None,
optimizer_g: Optional[Optimizer] = None,
optimizer_d: Optional[Optimizer] = None,
scheduler_g: Optional[LRSchedulerType] = None,
scheduler_d: Optional[LRSchedulerType] = None,
scaler: Optional[GradScaler] = None,
sampler: Optional[CutSampler] = None,
rank: int = 0,
) -> None:
"""Save training information to a file.
Args:
filename:
The checkpoint filename.
model:
The model to be saved. We only save its `state_dict()`.
params:
User defined parameters, e.g., epoch, loss.
optimizer_g:
The optimizer for generator used in the training.
Its `state_dict` will be saved.
optimizer_d:
The optimizer for discriminator used in the training.
Its `state_dict` will be saved.
scheduler_g:
The learning rate scheduler for generator used in the training.
Its `state_dict` will be saved.
scheduler_d:
The learning rate scheduler for discriminator used in the training.
Its `state_dict` will be saved.
scaler:
The GradScaler to be saved. We only save its `state_dict()`.
rank:
Used in DDP. We save checkpoint only for the node whose rank is 0.
Returns:
Return None.
"""
if rank != 0:
return
logging.info(f"Saving checkpoint to {filename}")
if isinstance(model, DDP):
model = model.module
checkpoint = {
"model": model.state_dict(),
"optimizer_g": optimizer_g.state_dict() if optimizer_g is not None else None,
"optimizer_d": optimizer_d.state_dict() if optimizer_d is not None else None,
"scheduler_g": scheduler_g.state_dict() if scheduler_g is not None else None,
"scheduler_d": scheduler_d.state_dict() if scheduler_d is not None else None,
"grad_scaler": scaler.state_dict() if scaler is not None else None,
"sampler": sampler.state_dict() if sampler is not None else None,
}
if params:
for k, v in params.items():
assert k not in checkpoint
checkpoint[k] = v
torch.save(checkpoint, filename)
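
A sketch of a typical call at the end of an epoch; the model, optimizers, schedulers, scaler, and rank below are stand-ins for objects built elsewhere in the training script:

save_checkpoint(
    filename=Path("exp/epoch-10.pt"),
    model=model,
    params={"epoch": 10, "best_valid_loss": 1.23},
    optimizer_g=optimizer_g,
    optimizer_d=optimizer_d,
    scheduler_g=scheduler_g,
    scheduler_d=scheduler_d,
    scaler=scaler,
    rank=rank,
)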


@@ -1,609 +0,0 @@
# based on https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/vits.py
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""VITS module for GAN-TTS task."""
from typing import Any, Dict, Optional, Tuple
import torch
import torch.nn as nn
from torch.cuda.amp import autocast
from hifigan import (
HiFiGANMultiPeriodDiscriminator,
HiFiGANMultiScaleDiscriminator,
HiFiGANMultiScaleMultiPeriodDiscriminator,
HiFiGANPeriodDiscriminator,
HiFiGANScaleDiscriminator,
)
from loss import (
DiscriminatorAdversarialLoss,
FeatureMatchLoss,
GeneratorAdversarialLoss,
KLDivergenceLoss,
MelSpectrogramLoss,
)
from utils import get_segments
from generator import VITSGenerator
AVAILABLE_GENERATORS = {
"vits_generator": VITSGenerator,
}
AVAILABLE_DISCRIMINATORS = {
"hifigan_period_discriminator": HiFiGANPeriodDiscriminator,
"hifigan_scale_discriminator": HiFiGANScaleDiscriminator,
"hifigan_multi_period_discriminator": HiFiGANMultiPeriodDiscriminator,
"hifigan_multi_scale_discriminator": HiFiGANMultiScaleDiscriminator,
"hifigan_multi_scale_multi_period_discriminator": HiFiGANMultiScaleMultiPeriodDiscriminator, # NOQA
}
class VITS(nn.Module):
"""Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`"""
def __init__(
self,
# generator related
vocab_size: int,
feature_dim: int = 513,
sampling_rate: int = 22050,
generator_type: str = "vits_generator",
generator_params: Dict[str, Any] = {
"hidden_channels": 192,
"spks": None,
"langs": None,
"spk_embed_dim": None,
"global_channels": -1,
"segment_size": 32,
"text_encoder_attention_heads": 2,
"text_encoder_ffn_expand": 4,
"text_encoder_cnn_module_kernel": 5,
"text_encoder_blocks": 6,
"text_encoder_dropout_rate": 0.1,
"decoder_kernel_size": 7,
"decoder_channels": 512,
"decoder_upsample_scales": [8, 8, 2, 2],
"decoder_upsample_kernel_sizes": [16, 16, 4, 4],
"decoder_resblock_kernel_sizes": [3, 7, 11],
"decoder_resblock_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
"use_weight_norm_in_decoder": True,
"posterior_encoder_kernel_size": 5,
"posterior_encoder_layers": 16,
"posterior_encoder_stacks": 1,
"posterior_encoder_base_dilation": 1,
"posterior_encoder_dropout_rate": 0.0,
"use_weight_norm_in_posterior_encoder": True,
"flow_flows": 4,
"flow_kernel_size": 5,
"flow_base_dilation": 1,
"flow_layers": 4,
"flow_dropout_rate": 0.0,
"use_weight_norm_in_flow": True,
"use_only_mean_in_flow": True,
"stochastic_duration_predictor_kernel_size": 3,
"stochastic_duration_predictor_dropout_rate": 0.5,
"stochastic_duration_predictor_flows": 4,
"stochastic_duration_predictor_dds_conv_layers": 3,
},
# discriminator related
discriminator_type: str = "hifigan_multi_scale_multi_period_discriminator",
discriminator_params: Dict[str, Any] = {
"scales": 1,
"scale_downsample_pooling": "AvgPool1d",
"scale_downsample_pooling_params": {
"kernel_size": 4,
"stride": 2,
"padding": 2,
},
"scale_discriminator_params": {
"in_channels": 1,
"out_channels": 1,
"kernel_sizes": [15, 41, 5, 3],
"channels": 128,
"max_downsample_channels": 1024,
"max_groups": 16,
"bias": True,
"downsample_scales": [2, 2, 4, 4, 1],
"nonlinear_activation": "LeakyReLU",
"nonlinear_activation_params": {"negative_slope": 0.1},
"use_weight_norm": True,
"use_spectral_norm": False,
},
"follow_official_norm": False,
"periods": [2, 3, 5, 7, 11],
"period_discriminator_params": {
"in_channels": 1,
"out_channels": 1,
"kernel_sizes": [5, 3],
"channels": 32,
"downsample_scales": [3, 3, 3, 3, 1],
"max_downsample_channels": 1024,
"bias": True,
"nonlinear_activation": "LeakyReLU",
"nonlinear_activation_params": {"negative_slope": 0.1},
"use_weight_norm": True,
"use_spectral_norm": False,
},
},
# loss related
generator_adv_loss_params: Dict[str, Any] = {
"average_by_discriminators": False,
"loss_type": "mse",
},
discriminator_adv_loss_params: Dict[str, Any] = {
"average_by_discriminators": False,
"loss_type": "mse",
},
feat_match_loss_params: Dict[str, Any] = {
"average_by_discriminators": False,
"average_by_layers": False,
"include_final_outputs": True,
},
mel_loss_params: Dict[str, Any] = {
"frame_shift": 256,
"frame_length": 1024,
"n_mels": 80,
},
lambda_adv: float = 1.0,
lambda_mel: float = 45.0,
lambda_feat_match: float = 2.0,
lambda_dur: float = 1.0,
lambda_kl: float = 1.0,
cache_generator_outputs: bool = True,
):
"""Initialize VITS module.
Args:
vocab_size (int): Input vocabulary size.
feature_dim (int): Acoustic feature dimension. The actual output channels will
be 1 since VITS is an end-to-end text-to-waveform model, but feature_dim is
kept to indicate the acoustic feature dimension for compatibility.
sampling_rate (int): Sampling rate. It is not used for training, but it is
referred to when saving waveforms during inference.
generator_type (str): Generator type.
generator_params (Dict[str, Any]): Parameter dict for generator.
discriminator_type (str): Discriminator type.
discriminator_params (Dict[str, Any]): Parameter dict for discriminator.
generator_adv_loss_params (Dict[str, Any]): Parameter dict for generator
adversarial loss.
discriminator_adv_loss_params (Dict[str, Any]): Parameter dict for
discriminator adversarial loss.
feat_match_loss_params (Dict[str, Any]): Parameter dict for feat match loss.
mel_loss_params (Dict[str, Any]): Parameter dict for mel loss.
lambda_adv (float): Loss scaling coefficient for adversarial loss.
lambda_mel (float): Loss scaling coefficient for mel spectrogram loss.
lambda_feat_match (float): Loss scaling coefficient for feat match loss.
lambda_dur (float): Loss scaling coefficient for duration loss.
lambda_kl (float): Loss scaling coefficient for KL divergence loss.
cache_generator_outputs (bool): Whether to cache generator outputs.
"""
super().__init__()
# define modules
generator_class = AVAILABLE_GENERATORS[generator_type]
if generator_type == "vits_generator":
# NOTE(kan-bayashi): Update parameters for the compatibility.
# The idim and odim is automatically decided from input data,
# where idim represents #vocabularies and odim represents
# the input acoustic feature dimension.
generator_params.update(vocabs=vocab_size, aux_channels=feature_dim)
self.generator = generator_class(
**generator_params,
)
discriminator_class = AVAILABLE_DISCRIMINATORS[discriminator_type]
self.discriminator = discriminator_class(
**discriminator_params,
)
self.generator_adv_loss = GeneratorAdversarialLoss(
**generator_adv_loss_params,
)
self.discriminator_adv_loss = DiscriminatorAdversarialLoss(
**discriminator_adv_loss_params,
)
self.feat_match_loss = FeatureMatchLoss(
**feat_match_loss_params,
)
mel_loss_params.update(sampling_rate=sampling_rate)
self.mel_loss = MelSpectrogramLoss(
**mel_loss_params,
)
self.kl_loss = KLDivergenceLoss()
# coefficients
self.lambda_adv = lambda_adv
self.lambda_mel = lambda_mel
self.lambda_kl = lambda_kl
self.lambda_feat_match = lambda_feat_match
self.lambda_dur = lambda_dur
# cache
self.cache_generator_outputs = cache_generator_outputs
self._cache = None
# store sampling rate for saving wav file
# (not used for the training)
self.sampling_rate = sampling_rate
# store parameters for test compatibility
self.spks = self.generator.spks
self.langs = self.generator.langs
self.spk_embed_dim = self.generator.spk_embed_dim
def forward(
self,
text: torch.Tensor,
text_lengths: torch.Tensor,
feats: torch.Tensor,
feats_lengths: torch.Tensor,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
return_sample: bool = False,
sids: Optional[torch.Tensor] = None,
spembs: Optional[torch.Tensor] = None,
lids: Optional[torch.Tensor] = None,
forward_generator: bool = True,
) -> Tuple[torch.Tensor, Dict[str, Any]]:
"""Perform generator forward.
Args:
text (Tensor): Text index tensor (B, T_text).
text_lengths (Tensor): Text length tensor (B,).
feats (Tensor): Feature tensor (B, T_feats, aux_channels).
feats_lengths (Tensor): Feature length tensor (B,).
speech (Tensor): Speech waveform tensor (B, T_wav).
speech_lengths (Tensor): Speech length tensor (B,).
sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
forward_generator (bool): Whether to forward generator.
Returns:
- loss (Tensor): Loss scalar tensor.
- stats (Dict[str, float]): Statistics to be monitored.
"""
if forward_generator:
return self._forward_generator(
text=text,
text_lengths=text_lengths,
feats=feats,
feats_lengths=feats_lengths,
speech=speech,
speech_lengths=speech_lengths,
return_sample=return_sample,
sids=sids,
spembs=spembs,
lids=lids,
)
else:
return self._forward_discriminator(
text=text,
text_lengths=text_lengths,
feats=feats,
feats_lengths=feats_lengths,
speech=speech,
speech_lengths=speech_lengths,
sids=sids,
spembs=spembs,
lids=lids,
)
def _forward_generator(
self,
text: torch.Tensor,
text_lengths: torch.Tensor,
feats: torch.Tensor,
feats_lengths: torch.Tensor,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
return_sample: bool = False,
sids: Optional[torch.Tensor] = None,
spembs: Optional[torch.Tensor] = None,
lids: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Dict[str, Any]]:
"""Perform generator forward.
Args:
text (Tensor): Text index tensor (B, T_text).
text_lengths (Tensor): Text length tensor (B,).
feats (Tensor): Feature tensor (B, T_feats, aux_channels).
feats_lengths (Tensor): Feature length tensor (B,).
speech (Tensor): Speech waveform tensor (B, T_wav).
speech_lengths (Tensor): Speech length tensor (B,).
sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
Returns:
* loss (Tensor): Loss scalar tensor.
* stats (Dict[str, float]): Statistics to be monitored.
"""
# setup
feats = feats.transpose(1, 2)
speech = speech.unsqueeze(1)
# calculate generator outputs
reuse_cache = True
if not self.cache_generator_outputs or self._cache is None:
reuse_cache = False
outs = self.generator(
text=text,
text_lengths=text_lengths,
feats=feats,
feats_lengths=feats_lengths,
sids=sids,
spembs=spembs,
lids=lids,
)
else:
outs = self._cache
# store cache
if self.training and self.cache_generator_outputs and not reuse_cache:
self._cache = outs
# parse outputs
speech_hat_, dur_nll, _, start_idxs, _, z_mask, outs_ = outs
_, z_p, m_p, logs_p, _, logs_q = outs_
speech_ = get_segments(
x=speech,
start_idxs=start_idxs * self.generator.upsample_factor,
segment_size=self.generator.segment_size * self.generator.upsample_factor,
)
# calculate discriminator outputs
p_hat = self.discriminator(speech_hat_)
with torch.no_grad():
# do not store discriminator gradient in generator turn
p = self.discriminator(speech_)
# calculate losses
with autocast(enabled=False):
if not return_sample:
mel_loss = self.mel_loss(speech_hat_, speech_)
else:
mel_loss, (mel_hat_, mel_) = self.mel_loss(
speech_hat_, speech_, return_mel=True
)
kl_loss = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask)
dur_loss = torch.sum(dur_nll.float())
adv_loss = self.generator_adv_loss(p_hat)
feat_match_loss = self.feat_match_loss(p_hat, p)
mel_loss = mel_loss * self.lambda_mel
kl_loss = kl_loss * self.lambda_kl
dur_loss = dur_loss * self.lambda_dur
adv_loss = adv_loss * self.lambda_adv
feat_match_loss = feat_match_loss * self.lambda_feat_match
loss = mel_loss + kl_loss + dur_loss + adv_loss + feat_match_loss
stats = dict(
generator_loss=loss.item(),
generator_mel_loss=mel_loss.item(),
generator_kl_loss=kl_loss.item(),
generator_dur_loss=dur_loss.item(),
generator_adv_loss=adv_loss.item(),
generator_feat_match_loss=feat_match_loss.item(),
)
if return_sample:
stats["returned_sample"] = (
speech_hat_[0].data.cpu().numpy(),
speech_[0].data.cpu().numpy(),
mel_hat_[0].data.cpu().numpy(),
mel_[0].data.cpu().numpy(),
)
# reset cache
if reuse_cache or not self.training:
self._cache = None
return loss, stats
def _forward_discriminator(
self,
text: torch.Tensor,
text_lengths: torch.Tensor,
feats: torch.Tensor,
feats_lengths: torch.Tensor,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
sids: Optional[torch.Tensor] = None,
spembs: Optional[torch.Tensor] = None,
lids: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Dict[str, Any]]:
"""Perform discriminator forward.
Args:
text (Tensor): Text index tensor (B, T_text).
text_lengths (Tensor): Text length tensor (B,).
feats (Tensor): Feature tensor (B, T_feats, aux_channels).
feats_lengths (Tensor): Feature length tensor (B,).
speech (Tensor): Speech waveform tensor (B, T_wav).
speech_lengths (Tensor): Speech length tensor (B,).
sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
Returns:
* loss (Tensor): Loss scalar tensor.
* stats (Dict[str, float]): Statistics to be monitored.
"""
# setup
feats = feats.transpose(1, 2)
speech = speech.unsqueeze(1)
# calculate generator outputs
reuse_cache = True
if not self.cache_generator_outputs or self._cache is None:
reuse_cache = False
outs = self.generator(
text=text,
text_lengths=text_lengths,
feats=feats,
feats_lengths=feats_lengths,
sids=sids,
spembs=spembs,
lids=lids,
)
else:
outs = self._cache
# store cache
if self.cache_generator_outputs and not reuse_cache:
self._cache = outs
# parse outputs
speech_hat_, _, _, start_idxs, *_ = outs
speech_ = get_segments(
x=speech,
start_idxs=start_idxs * self.generator.upsample_factor,
segment_size=self.generator.segment_size * self.generator.upsample_factor,
)
# calculate discriminator outputs
p_hat = self.discriminator(speech_hat_.detach())
p = self.discriminator(speech_)
# calculate losses
with autocast(enabled=False):
real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p)
loss = real_loss + fake_loss
stats = dict(
discriminator_loss=loss.item(),
discriminator_real_loss=real_loss.item(),
discriminator_fake_loss=fake_loss.item(),
)
# reset cache
if reuse_cache or not self.training:
self._cache = None
return loss, stats
def inference(
self,
text: torch.Tensor,
feats: Optional[torch.Tensor] = None,
sids: Optional[torch.Tensor] = None,
spembs: Optional[torch.Tensor] = None,
lids: Optional[torch.Tensor] = None,
durations: Optional[torch.Tensor] = None,
noise_scale: float = 0.667,
noise_scale_dur: float = 0.8,
alpha: float = 1.0,
max_len: Optional[int] = None,
use_teacher_forcing: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Run inference for single sample.
Args:
text (Tensor): Input text index tensor (T_text,).
feats (Tensor): Feature tensor (T_feats, aux_channels).
sids (Tensor): Speaker index tensor (1,).
spembs (Optional[Tensor]): Speaker embedding tensor (spk_embed_dim,).
lids (Tensor): Language index tensor (1,).
durations (Tensor): Ground-truth duration tensor (T_text,).
noise_scale (float): Noise scale value for flow.
noise_scale_dur (float): Noise scale value for duration predictor.
alpha (float): Alpha parameter to control the speed of generated speech.
max_len (Optional[int]): Maximum length.
use_teacher_forcing (bool): Whether to use teacher forcing.
Returns:
* wav (Tensor): Generated waveform tensor (T_wav,).
* att_w (Tensor): Monotonic attention weight tensor (T_feats, T_text).
* duration (Tensor): Predicted duration tensor (T_text,).
"""
# setup
text = text[None]
text_lengths = torch.tensor(
[text.size(1)],
dtype=torch.long,
device=text.device,
)
if sids is not None:
sids = sids.view(1)
if lids is not None:
lids = lids.view(1)
if durations is not None:
durations = durations.view(1, 1, -1)
# inference
if use_teacher_forcing:
assert feats is not None
feats = feats[None].transpose(1, 2)
feats_lengths = torch.tensor(
[feats.size(2)],
dtype=torch.long,
device=feats.device,
)
wav, att_w, dur = self.generator.inference(
text=text,
text_lengths=text_lengths,
feats=feats,
feats_lengths=feats_lengths,
sids=sids,
spembs=spembs,
lids=lids,
max_len=max_len,
use_teacher_forcing=use_teacher_forcing,
)
else:
wav, att_w, dur = self.generator.inference(
text=text,
text_lengths=text_lengths,
sids=sids,
spembs=spembs,
lids=lids,
dur=durations,
noise_scale=noise_scale,
noise_scale_dur=noise_scale_dur,
alpha=alpha,
max_len=max_len,
)
return wav.view(-1), att_w[0], dur[0]
def inference_batch(
self,
text: torch.Tensor,
text_lengths: torch.Tensor,
sids: Optional[torch.Tensor] = None,
durations: Optional[torch.Tensor] = None,
noise_scale: float = 0.667,
noise_scale_dur: float = 0.8,
alpha: float = 1.0,
max_len: Optional[int] = None,
use_teacher_forcing: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Run inference for one batch.
Args:
text (Tensor): Input text index tensor (B, T_text).
text_lengths (Tensor): Input text index tensor (B,).
sids (Tensor): Speaker index tensor (B,).
noise_scale (float): Noise scale value for flow.
noise_scale_dur (float): Noise scale value for duration predictor.
alpha (float): Alpha parameter to control the speed of generated speech.
max_len (Optional[int]): Maximum length.
Returns:
* wav (Tensor): Generated waveform tensor (B, T_wav).
* att_w (Tensor): Monotonic attention weight tensor (B, T_feats, T_text).
* duration (Tensor): Predicted duration tensor (B, T_text).
"""
# inference
wav, att_w, dur = self.generator.inference(
text=text,
text_lengths=text_lengths,
sids=sids,
noise_scale=noise_scale,
noise_scale_dur=noise_scale_dur,
alpha=alpha,
max_len=max_len,
)
return wav, att_w, dur
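
To make the two forward modes concrete, here is a hedged sketch of one GAN training step; vits, the optimizers, and the batch tensors are stand-ins for objects created elsewhere in the recipe:

# Generator turn: adversarial + mel + KL + duration + feature-matching losses.
g_loss, g_stats = vits(
    text=text, text_lengths=text_lengths,
    feats=feats, feats_lengths=feats_lengths,
    speech=speech, speech_lengths=speech_lengths,
    forward_generator=True,
)
optimizer_g.zero_grad()
g_loss.backward()
optimizer_g.step()

# Discriminator turn: real/fake adversarial loss; reuses the cached generator
# outputs when cache_generator_outputs=True.
d_loss, d_stats = vits(
    text=text, text_lengths=text_lengths,
    feats=feats, feats_lengths=feats_lengths,
    speech=speech, speech_lengths=speech_lengths,
    forward_generator=False,
)
optimizer_d.zero_grad()
d_loss.backward()
optimizer_d.step()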


@@ -1,349 +0,0 @@
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/wavenet/wavenet.py
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""WaveNet modules.
This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.
"""
import math
import logging
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
class WaveNet(torch.nn.Module):
"""WaveNet with global conditioning."""
def __init__(
self,
in_channels: int = 1,
out_channels: int = 1,
kernel_size: int = 3,
layers: int = 30,
stacks: int = 3,
base_dilation: int = 2,
residual_channels: int = 64,
aux_channels: int = -1,
gate_channels: int = 128,
skip_channels: int = 64,
global_channels: int = -1,
dropout_rate: float = 0.0,
bias: bool = True,
use_weight_norm: bool = True,
use_first_conv: bool = False,
use_last_conv: bool = False,
scale_residual: bool = False,
scale_skip_connect: bool = False,
):
"""Initialize WaveNet module.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
kernel_size (int): Kernel size of dilated convolution.
layers (int): Number of residual block layers.
stacks (int): Number of stacks i.e., dilation cycles.
base_dilation (int): Base dilation factor.
residual_channels (int): Number of channels in residual conv.
gate_channels (int): Number of channels in gated conv.
skip_channels (int): Number of channels in skip conv.
aux_channels (int): Number of channels for local conditioning feature.
global_channels (int): Number of channels for global conditioning feature.
dropout_rate (float): Dropout rate. 0.0 means no dropout applied.
bias (bool): Whether to use bias parameter in conv layer.
use_weight_norm (bool): Whether to use weight norm. If set to true, it will
be applied to all of the conv layers.
use_first_conv (bool): Whether to use the first conv layers.
use_last_conv (bool): Whether to use the last conv layers.
scale_residual (bool): Whether to scale the residual outputs.
scale_skip_connect (bool): Whether to scale the skip connection outputs.
"""
super().__init__()
self.layers = layers
self.stacks = stacks
self.kernel_size = kernel_size
self.base_dilation = base_dilation
self.use_first_conv = use_first_conv
self.use_last_conv = use_last_conv
self.scale_skip_connect = scale_skip_connect
# check the number of layers and stacks
assert layers % stacks == 0
layers_per_stack = layers // stacks
# define first convolution
if self.use_first_conv:
self.first_conv = Conv1d1x1(in_channels, residual_channels, bias=True)
# define residual blocks
self.conv_layers = torch.nn.ModuleList()
for layer in range(layers):
dilation = base_dilation ** (layer % layers_per_stack)
conv = ResidualBlock(
kernel_size=kernel_size,
residual_channels=residual_channels,
gate_channels=gate_channels,
skip_channels=skip_channels,
aux_channels=aux_channels,
global_channels=global_channels,
dilation=dilation,
dropout_rate=dropout_rate,
bias=bias,
scale_residual=scale_residual,
)
self.conv_layers += [conv]
# define output layers
if self.use_last_conv:
self.last_conv = torch.nn.Sequential(
torch.nn.ReLU(inplace=True),
Conv1d1x1(skip_channels, skip_channels, bias=True),
torch.nn.ReLU(inplace=True),
Conv1d1x1(skip_channels, out_channels, bias=True),
)
# apply weight norm
if use_weight_norm:
self.apply_weight_norm()
def forward(
self,
x: torch.Tensor,
x_mask: Optional[torch.Tensor] = None,
c: Optional[torch.Tensor] = None,
g: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Calculate forward propagation.
Args:
x (Tensor): Input noise signal (B, 1, T) if use_first_conv else
(B, residual_channels, T).
x_mask (Optional[Tensor]): Mask tensor (B, 1, T).
c (Optional[Tensor]): Local conditioning features (B, aux_channels, T).
g (Optional[Tensor]): Global conditioning features (B, global_channels, 1).
Returns:
Tensor: Output tensor (B, out_channels, T) if use_last_conv else
(B, residual_channels, T).
"""
# encode to hidden representation
if self.use_first_conv:
x = self.first_conv(x)
# residual block
skips = 0.0
for f in self.conv_layers:
x, h = f(x, x_mask=x_mask, c=c, g=g)
skips = skips + h
x = skips
if self.scale_skip_connect:
x = x * math.sqrt(1.0 / len(self.conv_layers))
# apply final layers
if self.use_last_conv:
x = self.last_conv(x)
return x
def remove_weight_norm(self):
"""Remove weight normalization module from all of the layers."""
def _remove_weight_norm(m: torch.nn.Module):
try:
logging.debug(f"Weight norm is removed from {m}.")
torch.nn.utils.remove_weight_norm(m)
except ValueError: # this module didn't have weight norm
return
self.apply(_remove_weight_norm)
def apply_weight_norm(self):
"""Apply weight normalization module from all of the layers."""
def _apply_weight_norm(m: torch.nn.Module):
if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
torch.nn.utils.weight_norm(m)
logging.debug(f"Weight norm is applied to {m}.")
self.apply(_apply_weight_norm)
@staticmethod
def _get_receptive_field_size(
layers: int,
stacks: int,
kernel_size: int,
base_dilation: int,
) -> int:
assert layers % stacks == 0
layers_per_cycle = layers // stacks
dilations = [base_dilation ** (i % layers_per_cycle) for i in range(layers)]
return (kernel_size - 1) * sum(dilations) + 1
@property
def receptive_field_size(self) -> int:
"""Return receptive field size."""
return self._get_receptive_field_size(
self.layers, self.stacks, self.kernel_size, self.base_dilation
)
class Conv1d(torch.nn.Conv1d):
"""Conv1d module with customized initialization."""
def __init__(self, *args, **kwargs):
"""Initialize Conv1d module."""
super().__init__(*args, **kwargs)
def reset_parameters(self):
"""Reset parameters."""
torch.nn.init.kaiming_normal_(self.weight, nonlinearity="relu")
if self.bias is not None:
torch.nn.init.constant_(self.bias, 0.0)
class Conv1d1x1(Conv1d):
"""1x1 Conv1d with customized initialization."""
def __init__(self, in_channels: int, out_channels: int, bias: bool):
"""Initialize 1x1 Conv1d module."""
super().__init__(
in_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=bias
)
class ResidualBlock(torch.nn.Module):
"""Residual block module in WaveNet."""
def __init__(
self,
kernel_size: int = 3,
residual_channels: int = 64,
gate_channels: int = 128,
skip_channels: int = 64,
aux_channels: int = 80,
global_channels: int = -1,
dropout_rate: float = 0.0,
dilation: int = 1,
bias: bool = True,
scale_residual: bool = False,
):
"""Initialize ResidualBlock module.
Args:
kernel_size (int): Kernel size of dilation convolution layer.
residual_channels (int): Number of channels for residual connection.
skip_channels (int): Number of channels for skip connection.
aux_channels (int): Number of local conditioning channels.
dropout_rate (float): Dropout rate.
dilation (int): Dilation factor.
bias (bool): Whether to add bias parameter in convolution layers.
scale_residual (bool): Whether to scale the residual outputs.
"""
super().__init__()
self.dropout_rate = dropout_rate
self.residual_channels = residual_channels
self.skip_channels = skip_channels
self.scale_residual = scale_residual
# check
assert (kernel_size - 1) % 2 == 0, "Even kernel sizes are not supported."
assert gate_channels % 2 == 0
# dilation conv
padding = (kernel_size - 1) // 2 * dilation
self.conv = Conv1d(
residual_channels,
gate_channels,
kernel_size,
padding=padding,
dilation=dilation,
bias=bias,
)
# local conditioning
if aux_channels > 0:
self.conv1x1_aux = Conv1d1x1(aux_channels, gate_channels, bias=False)
else:
self.conv1x1_aux = None
# global conditioning
if global_channels > 0:
self.conv1x1_glo = Conv1d1x1(global_channels, gate_channels, bias=False)
else:
self.conv1x1_glo = None
# conv output is split into two groups
gate_out_channels = gate_channels // 2
# NOTE(kan-bayashi): concat two convs into a single conv for the efficiency
# (integrate res 1x1 + skip 1x1 convs)
self.conv1x1_out = Conv1d1x1(
gate_out_channels, residual_channels + skip_channels, bias=bias
)
def forward(
self,
x: torch.Tensor,
x_mask: Optional[torch.Tensor] = None,
c: Optional[torch.Tensor] = None,
g: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, residual_channels, T).
x_mask (Optional[Tensor]): Mask tensor (B, 1, T).
c (Optional[Tensor]): Local conditioning tensor (B, aux_channels, T).
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
Returns:
Tensor: Output tensor for residual connection (B, residual_channels, T).
Tensor: Output tensor for skip connection (B, skip_channels, T).
"""
residual = x
x = F.dropout(x, p=self.dropout_rate, training=self.training)
x = self.conv(x)
# split into two part for gated activation
splitdim = 1
xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim)
# local conditioning
if c is not None:
c = self.conv1x1_aux(c)
ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim)
xa, xb = xa + ca, xb + cb
# global conditioning
if g is not None:
g = self.conv1x1_glo(g)
ga, gb = g.split(g.size(splitdim) // 2, dim=splitdim)
xa, xb = xa + ga, xb + gb
x = torch.tanh(xa) * torch.sigmoid(xb)
# residual + skip 1x1 conv
x = self.conv1x1_out(x)
if x_mask is not None:
x = x * x_mask
# split integrated conv results
x, s = x.split([self.residual_channels, self.skip_channels], dim=1)
# for residual connection
x = x + residual
if self.scale_residual:
x = x * math.sqrt(0.5)
return x, s
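
A small end-to-end sketch of the WaveNet module with both local and global conditioning; all sizes below are made up for illustration:

wavenet = WaveNet(
    layers=6, stacks=2,
    residual_channels=64, gate_channels=128, skip_channels=64,
    aux_channels=80, global_channels=256,
    use_first_conv=True, use_last_conv=True,
)
x = torch.randn(2, 1, 100)      # (B, in_channels, T), noise-like input
c = torch.randn(2, 80, 100)     # local conditioning, e.g. mel-spectrogram frames
g = torch.randn(2, 256, 1)      # global conditioning, e.g. a speaker embedding
y = wavenet(x, c=c, g=g)        # (B, out_channels, T) because use_last_conv=True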