mirror of https://github.com/k2-fsa/icefall.git
synced 2025-09-04 06:34:20 +00:00

removed redundant files

This commit is contained in:
parent b7efcbf154
commit ee718f1da1
@@ -1,81 +0,0 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/__init__.py

"""Maximum path calculation module.

This code is based on https://github.com/jaywalnut310/vits.

"""

import warnings

import numpy as np
import torch
from numba import njit, prange

try:
    from .core import maximum_path_c

    is_cython_available = True
except ImportError:
    is_cython_available = False
    warnings.warn(
        "Cython version is not available. Falling back to the 'EXPERIMENTAL' numba version. "
        "If you want to use the cython version, please build it as follows: "
        "`cd espnet2/gan_tts/vits/monotonic_align; python setup.py build_ext --inplace`"
    )


def maximum_path(neg_x_ent: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
    """Calculate maximum path.

    Args:
        neg_x_ent (Tensor): Negative X entropy tensor (B, T_feats, T_text).
        attn_mask (Tensor): Attention mask (B, T_feats, T_text).

    Returns:
        Tensor: Maximum path tensor (B, T_feats, T_text).

    """
    device, dtype = neg_x_ent.device, neg_x_ent.dtype
    neg_x_ent = neg_x_ent.cpu().numpy().astype(np.float32)
    path = np.zeros(neg_x_ent.shape, dtype=np.int32)
    t_t_max = attn_mask.sum(1)[:, 0].cpu().numpy().astype(np.int32)
    t_s_max = attn_mask.sum(2)[:, 0].cpu().numpy().astype(np.int32)
    if is_cython_available:
        maximum_path_c(path, neg_x_ent, t_t_max, t_s_max)
    else:
        maximum_path_numba(path, neg_x_ent, t_t_max, t_s_max)

    return torch.from_numpy(path).to(device=device, dtype=dtype)


@njit
def maximum_path_each_numba(path, value, t_y, t_x, max_neg_val=-np.inf):
    """Calculate a single maximum path with numba."""
    index = t_x - 1
    for y in range(t_y):
        for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
            if x == y:
                v_cur = max_neg_val
            else:
                v_cur = value[y - 1, x]
            if x == 0:
                if y == 0:
                    v_prev = 0.0
                else:
                    v_prev = max_neg_val
            else:
                v_prev = value[y - 1, x - 1]
            value[y, x] += max(v_prev, v_cur)

    for y in range(t_y - 1, -1, -1):
        path[y, index] = 1
        if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]):
            index = index - 1


@njit(parallel=True)
def maximum_path_numba(paths, values, t_ys, t_xs):
    """Calculate batch maximum path with numba."""
    for i in prange(paths.shape[0]):
        maximum_path_each_numba(paths[i], values[i], t_ys[i], t_xs[i])
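# Editor's usage sketch (not part of the original file): a minimal call to
# maximum_path with dummy inputs, assuming numba is installed (running this
# file as a script skips the relative cython import and uses the fallback).
# With an all-ones attention mask the full (T_feats, T_text) grid is searched
# and exactly one text position is selected per feature frame.
if __name__ == "__main__":
    B, T_feats, T_text = 2, 6, 4
    neg_x_ent = torch.randn(B, T_feats, T_text)
    attn_mask = torch.ones(B, T_feats, T_text)
    path = maximum_path(neg_x_ent, attn_mask)  # (B, T_feats, T_text), 0/1 entries
    assert path.sum().item() == B * T_feats  # one text index per frame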
@@ -1,51 +0,0 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/core.pyx

"""Maximum path calculation module with cython optimization.

This code is copied from https://github.com/jaywalnut310/vits, with the code format modified.

"""

cimport cython

from cython.parallel import prange


@cython.boundscheck(False)
@cython.wraparound(False)
cdef void maximum_path_each(int[:, ::1] path, float[:, ::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil:
    cdef int x
    cdef int y
    cdef float v_prev
    cdef float v_cur
    cdef float tmp
    cdef int index = t_x - 1

    for y in range(t_y):
        for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
            if x == y:
                v_cur = max_neg_val
            else:
                v_cur = value[y - 1, x]
            if x == 0:
                if y == 0:
                    v_prev = 0.0
                else:
                    v_prev = max_neg_val
            else:
                v_prev = value[y - 1, x - 1]
            value[y, x] += max(v_prev, v_cur)

    for y in range(t_y - 1, -1, -1):
        path[y, index] = 1
        if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]):
            index = index - 1


@cython.boundscheck(False)
@cython.wraparound(False)
cpdef void maximum_path_c(int[:, :, ::1] paths, float[:, :, ::1] values, int[::1] t_ys, int[::1] t_xs) nogil:
    cdef int b = paths.shape[0]
    cdef int i
    for i in prange(b, nogil=True):
        maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i])
@@ -1,31 +0,0 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/setup.py
"""Setup cython code."""

from Cython.Build import cythonize
from setuptools import Extension, setup
from setuptools.command.build_ext import build_ext as _build_ext


class build_ext(_build_ext):
    """Overwrite build_ext."""

    def finalize_options(self):
        """Prevent numpy from thinking it is still in its setup process."""
        _build_ext.finalize_options(self)
        __builtins__.__NUMPY_SETUP__ = False
        import numpy

        self.include_dirs.append(numpy.get_include())


exts = [
    Extension(
        name="core",
        sources=["core.pyx"],
    )
]
setup(
    name="monotonic_align",
    ext_modules=cythonize(exts, language_level=3),
    cmdclass={"build_ext": build_ext},
)
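# Editor's note (not part of the original file): building the extension in
# place, as suggested by the warning in __init__.py, makes
# `from .core import maximum_path_c` succeed:
#
#   cd <directory containing setup.py and core.pyx>
#   python setup.py build_ext --inplace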
@@ -1,218 +0,0 @@
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/transform.py

"""Flow-related transformation.

This code is derived from https://github.com/bayesiains/nflows.

"""

import numpy as np
import torch
from torch.nn import functional as F

DEFAULT_MIN_BIN_WIDTH = 1e-3
DEFAULT_MIN_BIN_HEIGHT = 1e-3
DEFAULT_MIN_DERIVATIVE = 1e-3


# TODO(kan-bayashi): Documentation and type hint
def piecewise_rational_quadratic_transform(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    tails=None,
    tail_bound=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    if tails is None:
        spline_fn = rational_quadratic_spline
        spline_kwargs = {}
    else:
        spline_fn = unconstrained_rational_quadratic_spline
        spline_kwargs = {"tails": tails, "tail_bound": tail_bound}

    outputs, logabsdet = spline_fn(
        inputs=inputs,
        unnormalized_widths=unnormalized_widths,
        unnormalized_heights=unnormalized_heights,
        unnormalized_derivatives=unnormalized_derivatives,
        inverse=inverse,
        min_bin_width=min_bin_width,
        min_bin_height=min_bin_height,
        min_derivative=min_derivative,
        **spline_kwargs
    )
    return outputs, logabsdet


# TODO(kan-bayashi): Documentation and type hint
def unconstrained_rational_quadratic_spline(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    tails="linear",
    tail_bound=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
    outside_interval_mask = ~inside_interval_mask

    outputs = torch.zeros_like(inputs)
    logabsdet = torch.zeros_like(inputs)

    if tails == "linear":
        unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
        constant = np.log(np.exp(1 - min_derivative) - 1)
        unnormalized_derivatives[..., 0] = constant
        unnormalized_derivatives[..., -1] = constant

        outputs[outside_interval_mask] = inputs[outside_interval_mask]
        logabsdet[outside_interval_mask] = 0
    else:
        raise RuntimeError("{} tails are not implemented.".format(tails))

    (
        outputs[inside_interval_mask],
        logabsdet[inside_interval_mask],
    ) = rational_quadratic_spline(
        inputs=inputs[inside_interval_mask],
        unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
        unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
        unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
        inverse=inverse,
        left=-tail_bound,
        right=tail_bound,
        bottom=-tail_bound,
        top=tail_bound,
        min_bin_width=min_bin_width,
        min_bin_height=min_bin_height,
        min_derivative=min_derivative,
    )

    return outputs, logabsdet


# TODO(kan-bayashi): Documentation and type hint
def rational_quadratic_spline(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    left=0.0,
    right=1.0,
    bottom=0.0,
    top=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    if torch.min(inputs) < left or torch.max(inputs) > right:
        raise ValueError("Input to a transform is not within its domain")

    num_bins = unnormalized_widths.shape[-1]

    if min_bin_width * num_bins > 1.0:
        raise ValueError("Minimal bin width too large for the number of bins")
    if min_bin_height * num_bins > 1.0:
        raise ValueError("Minimal bin height too large for the number of bins")

    widths = F.softmax(unnormalized_widths, dim=-1)
    widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
    cumwidths = torch.cumsum(widths, dim=-1)
    cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
    cumwidths = (right - left) * cumwidths + left
    cumwidths[..., 0] = left
    cumwidths[..., -1] = right
    widths = cumwidths[..., 1:] - cumwidths[..., :-1]

    derivatives = min_derivative + F.softplus(unnormalized_derivatives)

    heights = F.softmax(unnormalized_heights, dim=-1)
    heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
    cumheights = torch.cumsum(heights, dim=-1)
    cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
    cumheights = (top - bottom) * cumheights + bottom
    cumheights[..., 0] = bottom
    cumheights[..., -1] = top
    heights = cumheights[..., 1:] - cumheights[..., :-1]

    if inverse:
        bin_idx = _searchsorted(cumheights, inputs)[..., None]
    else:
        bin_idx = _searchsorted(cumwidths, inputs)[..., None]

    input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
    input_bin_widths = widths.gather(-1, bin_idx)[..., 0]

    input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
    delta = heights / widths
    input_delta = delta.gather(-1, bin_idx)[..., 0]

    input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
    input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]

    input_heights = heights.gather(-1, bin_idx)[..., 0]

    if inverse:
        a = (inputs - input_cumheights) * (
            input_derivatives + input_derivatives_plus_one - 2 * input_delta
        ) + input_heights * (input_delta - input_derivatives)
        b = input_heights * input_derivatives - (inputs - input_cumheights) * (
            input_derivatives + input_derivatives_plus_one - 2 * input_delta
        )
        c = -input_delta * (inputs - input_cumheights)

        discriminant = b.pow(2) - 4 * a * c
        assert (discriminant >= 0).all()

        root = (2 * c) / (-b - torch.sqrt(discriminant))
        outputs = root * input_bin_widths + input_cumwidths

        theta_one_minus_theta = root * (1 - root)
        denominator = input_delta + (
            (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
            * theta_one_minus_theta
        )
        derivative_numerator = input_delta.pow(2) * (
            input_derivatives_plus_one * root.pow(2)
            + 2 * input_delta * theta_one_minus_theta
            + input_derivatives * (1 - root).pow(2)
        )
        logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)

        return outputs, -logabsdet
    else:
        theta = (inputs - input_cumwidths) / input_bin_widths
        theta_one_minus_theta = theta * (1 - theta)

        numerator = input_heights * (
            input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
        )
        denominator = input_delta + (
            (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
            * theta_one_minus_theta
        )
        outputs = input_cumheights + numerator / denominator

        derivative_numerator = input_delta.pow(2) * (
            input_derivatives_plus_one * theta.pow(2)
            + 2 * input_delta * theta_one_minus_theta
            + input_derivatives * (1 - theta).pow(2)
        )
        logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)

        return outputs, logabsdet


def _searchsorted(bin_locations, inputs, eps=1e-6):
    bin_locations[..., -1] += eps
    return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
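# Editor's usage sketch (not part of the original file): the forward and
# inverse spline passes should round-trip, and their log-determinants should
# be negatives of each other. Shapes follow nflows conventions: with "linear"
# tails, the unnormalized derivatives have num_bins - 1 entries and are
# padded to num_bins + 1 knot derivatives inside the function.
if __name__ == "__main__":
    torch.manual_seed(0)
    B, num_bins = 4, 8
    x = torch.rand(B) * 2 - 1  # inside the default tail_bound of 1.0
    w = torch.randn(B, num_bins)
    h = torch.randn(B, num_bins)
    d = torch.randn(B, num_bins - 1)
    y, ld = piecewise_rational_quadratic_transform(x, w, h, d, tails="linear")
    x2, ld_inv = piecewise_rational_quadratic_transform(
        y, w, h, d, inverse=True, tails="linear"
    )
    assert torch.allclose(x, x2, atol=1e-4)
    assert torch.allclose(ld, -ld_inv, atol=1e-4)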
@@ -1,265 +0,0 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
import torch.nn as nn
from lhotse.dataset.sampling.base import CutSampler
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter


# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/utils/get_random_segments.py
def get_random_segments(
    x: torch.Tensor,
    x_lengths: torch.Tensor,
    segment_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Get random segments.

    Args:
        x (Tensor): Input tensor (B, C, T).
        x_lengths (Tensor): Length tensor (B,).
        segment_size (int): Segment size.

    Returns:
        Tensor: Segmented tensor (B, C, segment_size).
        Tensor: Start index tensor (B,).

    """
    b, c, t = x.size()
    max_start_idx = x_lengths - segment_size
    max_start_idx[max_start_idx < 0] = 0
    start_idxs = (torch.rand([b]).to(x.device) * max_start_idx).to(
        dtype=torch.long,
    )
    segments = get_segments(x, start_idxs, segment_size)

    return segments, start_idxs


# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/utils/get_random_segments.py
def get_segments(
    x: torch.Tensor,
    start_idxs: torch.Tensor,
    segment_size: int,
) -> torch.Tensor:
    """Get segments.

    Args:
        x (Tensor): Input tensor (B, C, T).
        start_idxs (Tensor): Start index tensor (B,).
        segment_size (int): Segment size.

    Returns:
        Tensor: Segmented tensor (B, C, segment_size).

    """
    b, c, t = x.size()
    segments = x.new_zeros(b, c, segment_size)
    for i, start_idx in enumerate(start_idxs):
        segments[i] = x[i, :, start_idx : start_idx + segment_size]
    return segments


# from https://github.com/jaywalnut310/vits/blob/main/commons.py
def intersperse(sequence, item=0):
    result = [item] * (len(sequence) * 2 + 1)
    result[1::2] = sequence
    return result


# from https://github.com/jaywalnut310/vits/blob/main/utils.py
MATPLOTLIB_FLAG = False


def plot_feature(spectrogram):
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib

        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger("matplotlib")
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt
    import numpy as np

    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()

    fig.canvas.draw()
    # np.fromstring is deprecated; np.frombuffer reads the canvas bytes directly.
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close()
    return data


class MetricsTracker(collections.defaultdict):
    def __init__(self):
        # Passing the type 'int' to the base-class constructor
        # makes undefined items default to int() which is zero.
        # This class plays the role of a metrics tracker.
        # It can record many metrics, including but not limited to loss.
        super(MetricsTracker, self).__init__(int)

    def __add__(self, other: "MetricsTracker") -> "MetricsTracker":
        ans = MetricsTracker()
        for k, v in self.items():
            ans[k] = v
        for k, v in other.items():
            ans[k] = ans[k] + v
        return ans

    def __mul__(self, alpha: float) -> "MetricsTracker":
        ans = MetricsTracker()
        for k, v in self.items():
            ans[k] = v * alpha
        return ans

    def __str__(self) -> str:
        ans = ""
        for k, v in self.norm_items():
            norm_value = "%.4g" % v
            ans += str(k) + "=" + str(norm_value) + ", "
        samples = "%.2f" % self["samples"]
        ans += "over " + str(samples) + " samples."
        return ans

    def norm_items(self) -> List[Tuple[str, float]]:
        """
        Returns a list of pairs, like:
          [('loss_1', 0.1), ('loss_2', 0.07)]
        """
        samples = self["samples"] if "samples" in self else 1
        ans = []
        for k, v in self.items():
            if k == "samples":
                continue
            norm_value = float(v) / samples
            ans.append((k, norm_value))
        return ans

    def reduce(self, device):
        """
        Sum the metrics over all processes with torch.distributed, so that
        every process ends up with the total.
        """
        keys = sorted(self.keys())
        s = torch.tensor([float(self[k]) for k in keys], device=device)
        dist.all_reduce(s, op=dist.ReduceOp.SUM)
        for k, v in zip(keys, s.cpu().tolist()):
            self[k] = v

    def write_summary(
        self,
        tb_writer: SummaryWriter,
        prefix: str,
        batch_idx: int,
    ) -> None:
        """Add logging information to a TensorBoard writer.

        Args:
            tb_writer: a TensorBoard writer
            prefix: a prefix for the name of the loss, e.g. "train/valid_",
                or "train/current_"
            batch_idx: The current batch index, used as the x-axis of the plot.
        """
        for k, v in self.norm_items():
            tb_writer.add_scalar(prefix + k, v, batch_idx)


# checkpoint saving and loading
LRSchedulerType = torch.optim.lr_scheduler._LRScheduler


def save_checkpoint(
    filename: Path,
    model: Union[nn.Module, DDP],
    params: Optional[Dict[str, Any]] = None,
    optimizer_g: Optional[Optimizer] = None,
    optimizer_d: Optional[Optimizer] = None,
    scheduler_g: Optional[LRSchedulerType] = None,
    scheduler_d: Optional[LRSchedulerType] = None,
    scaler: Optional[GradScaler] = None,
    sampler: Optional[CutSampler] = None,
    rank: int = 0,
) -> None:
    """Save training information to a file.

    Args:
      filename:
        The checkpoint filename.
      model:
        The model to be saved. We only save its `state_dict()`.
      params:
        User defined parameters, e.g., epoch, loss.
      optimizer_g:
        The optimizer for the generator used in training.
        Its `state_dict` will be saved.
      optimizer_d:
        The optimizer for the discriminator used in training.
        Its `state_dict` will be saved.
      scheduler_g:
        The learning rate scheduler for the generator used in training.
        Its `state_dict` will be saved.
      scheduler_d:
        The learning rate scheduler for the discriminator used in training.
        Its `state_dict` will be saved.
      scaler:
        The GradScaler to be saved. We only save its `state_dict()`.
      sampler:
        The sampler used in training. We only save its `state_dict()`.
      rank:
        Used in DDP. We save checkpoints only on the node whose rank is 0.
    Returns:
      Return None.
    """
    if rank != 0:
        return

    logging.info(f"Saving checkpoint to {filename}")

    if isinstance(model, DDP):
        model = model.module

    checkpoint = {
        "model": model.state_dict(),
        "optimizer_g": optimizer_g.state_dict() if optimizer_g is not None else None,
        "optimizer_d": optimizer_d.state_dict() if optimizer_d is not None else None,
        "scheduler_g": scheduler_g.state_dict() if scheduler_g is not None else None,
        "scheduler_d": scheduler_d.state_dict() if scheduler_d is not None else None,
        "grad_scaler": scaler.state_dict() if scaler is not None else None,
        "sampler": sampler.state_dict() if sampler is not None else None,
    }

    if params:
        for k, v in params.items():
            assert k not in checkpoint
            checkpoint[k] = v

    torch.save(checkpoint, filename)
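# Editor's usage sketch (not part of the original file): exercising the
# segment helpers, intersperse, and the metrics tracker with dummy values.
if __name__ == "__main__":
    x = torch.arange(24.0).reshape(2, 2, 6)  # (B, C, T)
    x_lengths = torch.tensor([6, 4])
    segs, starts = get_random_segments(x, x_lengths, segment_size=3)
    assert segs.shape == (2, 2, 3)

    # intersperse places the pad item before, between, and after the tokens:
    assert intersperse([1, 2, 3]) == [0, 1, 0, 2, 0, 3, 0]

    m = MetricsTracker()
    m["loss"] = 2.0
    m["samples"] = 4
    assert m.norm_items() == [("loss", 0.5)]  # normalized by sample count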
@@ -1,609 +0,0 @@
# based on https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/vits.py

# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""VITS module for GAN-TTS task."""

from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn as nn
from torch.cuda.amp import autocast

from hifigan import (
    HiFiGANMultiPeriodDiscriminator,
    HiFiGANMultiScaleDiscriminator,
    HiFiGANMultiScaleMultiPeriodDiscriminator,
    HiFiGANPeriodDiscriminator,
    HiFiGANScaleDiscriminator,
)
from loss import (
    DiscriminatorAdversarialLoss,
    FeatureMatchLoss,
    GeneratorAdversarialLoss,
    KLDivergenceLoss,
    MelSpectrogramLoss,
)
from utils import get_segments
from generator import VITSGenerator


AVAILABLE_GENERATORS = {
    "vits_generator": VITSGenerator,
}
AVAILABLE_DISCRIMINATORS = {
    "hifigan_period_discriminator": HiFiGANPeriodDiscriminator,
    "hifigan_scale_discriminator": HiFiGANScaleDiscriminator,
    "hifigan_multi_period_discriminator": HiFiGANMultiPeriodDiscriminator,
    "hifigan_multi_scale_discriminator": HiFiGANMultiScaleDiscriminator,
    "hifigan_multi_scale_multi_period_discriminator": HiFiGANMultiScaleMultiPeriodDiscriminator,  # NOQA
}


class VITS(nn.Module):
    """Implement VITS, `Conditional Variational Autoencoder with Adversarial
    Learning for End-to-End Text-to-Speech`."""

    def __init__(
        self,
        # generator related
        vocab_size: int,
        feature_dim: int = 513,
        sampling_rate: int = 22050,
        generator_type: str = "vits_generator",
        generator_params: Dict[str, Any] = {
            "hidden_channels": 192,
            "spks": None,
            "langs": None,
            "spk_embed_dim": None,
            "global_channels": -1,
            "segment_size": 32,
            "text_encoder_attention_heads": 2,
            "text_encoder_ffn_expand": 4,
            "text_encoder_cnn_module_kernel": 5,
            "text_encoder_blocks": 6,
            "text_encoder_dropout_rate": 0.1,
            "decoder_kernel_size": 7,
            "decoder_channels": 512,
            "decoder_upsample_scales": [8, 8, 2, 2],
            "decoder_upsample_kernel_sizes": [16, 16, 4, 4],
            "decoder_resblock_kernel_sizes": [3, 7, 11],
            "decoder_resblock_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
            "use_weight_norm_in_decoder": True,
            "posterior_encoder_kernel_size": 5,
            "posterior_encoder_layers": 16,
            "posterior_encoder_stacks": 1,
            "posterior_encoder_base_dilation": 1,
            "posterior_encoder_dropout_rate": 0.0,
            "use_weight_norm_in_posterior_encoder": True,
            "flow_flows": 4,
            "flow_kernel_size": 5,
            "flow_base_dilation": 1,
            "flow_layers": 4,
            "flow_dropout_rate": 0.0,
            "use_weight_norm_in_flow": True,
            "use_only_mean_in_flow": True,
            "stochastic_duration_predictor_kernel_size": 3,
            "stochastic_duration_predictor_dropout_rate": 0.5,
            "stochastic_duration_predictor_flows": 4,
            "stochastic_duration_predictor_dds_conv_layers": 3,
        },
        # discriminator related
        discriminator_type: str = "hifigan_multi_scale_multi_period_discriminator",
        discriminator_params: Dict[str, Any] = {
            "scales": 1,
            "scale_downsample_pooling": "AvgPool1d",
            "scale_downsample_pooling_params": {
                "kernel_size": 4,
                "stride": 2,
                "padding": 2,
            },
            "scale_discriminator_params": {
                "in_channels": 1,
                "out_channels": 1,
                "kernel_sizes": [15, 41, 5, 3],
                "channels": 128,
                "max_downsample_channels": 1024,
                "max_groups": 16,
                "bias": True,
                "downsample_scales": [2, 2, 4, 4, 1],
                "nonlinear_activation": "LeakyReLU",
                "nonlinear_activation_params": {"negative_slope": 0.1},
                "use_weight_norm": True,
                "use_spectral_norm": False,
            },
            "follow_official_norm": False,
            "periods": [2, 3, 5, 7, 11],
            "period_discriminator_params": {
                "in_channels": 1,
                "out_channels": 1,
                "kernel_sizes": [5, 3],
                "channels": 32,
                "downsample_scales": [3, 3, 3, 3, 1],
                "max_downsample_channels": 1024,
                "bias": True,
                "nonlinear_activation": "LeakyReLU",
                "nonlinear_activation_params": {"negative_slope": 0.1},
                "use_weight_norm": True,
                "use_spectral_norm": False,
            },
        },
        # loss related
        generator_adv_loss_params: Dict[str, Any] = {
            "average_by_discriminators": False,
            "loss_type": "mse",
        },
        discriminator_adv_loss_params: Dict[str, Any] = {
            "average_by_discriminators": False,
            "loss_type": "mse",
        },
        feat_match_loss_params: Dict[str, Any] = {
            "average_by_discriminators": False,
            "average_by_layers": False,
            "include_final_outputs": True,
        },
        mel_loss_params: Dict[str, Any] = {
            "frame_shift": 256,
            "frame_length": 1024,
            "n_mels": 80,
        },
        lambda_adv: float = 1.0,
        lambda_mel: float = 45.0,
        lambda_feat_match: float = 2.0,
        lambda_dur: float = 1.0,
        lambda_kl: float = 1.0,
        cache_generator_outputs: bool = True,
    ):
        """Initialize VITS module.

        Args:
            vocab_size (int): Input vocabulary size.
            feature_dim (int): Acoustic feature dimension. The actual output
                channels will be 1, since VITS is an end-to-end text-to-wave
                model, but feature_dim is used to indicate the input acoustic
                feature dimension for compatibility.
            sampling_rate (int): Sampling rate; not used for training, but
                referred to when saving waveforms during inference.
            generator_type (str): Generator type.
            generator_params (Dict[str, Any]): Parameter dict for generator.
            discriminator_type (str): Discriminator type.
            discriminator_params (Dict[str, Any]): Parameter dict for discriminator.
            generator_adv_loss_params (Dict[str, Any]): Parameter dict for generator
                adversarial loss.
            discriminator_adv_loss_params (Dict[str, Any]): Parameter dict for
                discriminator adversarial loss.
            feat_match_loss_params (Dict[str, Any]): Parameter dict for feat match loss.
            mel_loss_params (Dict[str, Any]): Parameter dict for mel loss.
            lambda_adv (float): Loss scaling coefficient for adversarial loss.
            lambda_mel (float): Loss scaling coefficient for mel spectrogram loss.
            lambda_feat_match (float): Loss scaling coefficient for feat match loss.
            lambda_dur (float): Loss scaling coefficient for duration loss.
            lambda_kl (float): Loss scaling coefficient for KL divergence loss.
            cache_generator_outputs (bool): Whether to cache generator outputs.

        """
        super().__init__()

        # define modules
        generator_class = AVAILABLE_GENERATORS[generator_type]
        if generator_type == "vits_generator":
            # NOTE(kan-bayashi): Update parameters for compatibility. The
            # vocabulary size and the input acoustic feature dimension are
            # decided automatically from the input data.
            generator_params.update(vocabs=vocab_size, aux_channels=feature_dim)
        self.generator = generator_class(
            **generator_params,
        )
        discriminator_class = AVAILABLE_DISCRIMINATORS[discriminator_type]
        self.discriminator = discriminator_class(
            **discriminator_params,
        )
        self.generator_adv_loss = GeneratorAdversarialLoss(
            **generator_adv_loss_params,
        )
        self.discriminator_adv_loss = DiscriminatorAdversarialLoss(
            **discriminator_adv_loss_params,
        )
        self.feat_match_loss = FeatureMatchLoss(
            **feat_match_loss_params,
        )
        mel_loss_params.update(sampling_rate=sampling_rate)
        self.mel_loss = MelSpectrogramLoss(
            **mel_loss_params,
        )
        self.kl_loss = KLDivergenceLoss()

        # coefficients
        self.lambda_adv = lambda_adv
        self.lambda_mel = lambda_mel
        self.lambda_kl = lambda_kl
        self.lambda_feat_match = lambda_feat_match
        self.lambda_dur = lambda_dur

        # cache
        self.cache_generator_outputs = cache_generator_outputs
        self._cache = None

        # store sampling rate for saving wav file
        # (not used for the training)
        self.sampling_rate = sampling_rate

        # store parameters for test compatibility
        self.spks = self.generator.spks
        self.langs = self.generator.langs
        self.spk_embed_dim = self.generator.spk_embed_dim

    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        feats: torch.Tensor,
        feats_lengths: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        return_sample: bool = False,
        sids: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
        forward_generator: bool = True,
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """Perform generator or discriminator forward.

        Args:
            text (Tensor): Text index tensor (B, T_text).
            text_lengths (Tensor): Text length tensor (B,).
            feats (Tensor): Feature tensor (B, T_feats, aux_channels).
            feats_lengths (Tensor): Feature length tensor (B,).
            speech (Tensor): Speech waveform tensor (B, T_wav).
            speech_lengths (Tensor): Speech length tensor (B,).
            return_sample (bool): Whether to return generated samples in the
                stats (generator turn only).
            sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
            spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
            lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
            forward_generator (bool): Whether to forward generator.

        Returns:
            - loss (Tensor): Loss scalar tensor.
            - stats (Dict[str, float]): Statistics to be monitored.
        """
        if forward_generator:
            return self._forward_generator(
                text=text,
                text_lengths=text_lengths,
                feats=feats,
                feats_lengths=feats_lengths,
                speech=speech,
                speech_lengths=speech_lengths,
                return_sample=return_sample,
                sids=sids,
                spembs=spembs,
                lids=lids,
            )
        else:
            return self._forward_discriminator(
                text=text,
                text_lengths=text_lengths,
                feats=feats,
                feats_lengths=feats_lengths,
                speech=speech,
                speech_lengths=speech_lengths,
                sids=sids,
                spembs=spembs,
                lids=lids,
            )

    def _forward_generator(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        feats: torch.Tensor,
        feats_lengths: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        return_sample: bool = False,
        sids: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """Perform generator forward.

        Args:
            text (Tensor): Text index tensor (B, T_text).
            text_lengths (Tensor): Text length tensor (B,).
            feats (Tensor): Feature tensor (B, T_feats, aux_channels).
            feats_lengths (Tensor): Feature length tensor (B,).
            speech (Tensor): Speech waveform tensor (B, T_wav).
            speech_lengths (Tensor): Speech length tensor (B,).
            sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
            spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
            lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).

        Returns:
            * loss (Tensor): Loss scalar tensor.
            * stats (Dict[str, float]): Statistics to be monitored.
        """
        # setup
        feats = feats.transpose(1, 2)
        speech = speech.unsqueeze(1)

        # calculate generator outputs
        reuse_cache = True
        if not self.cache_generator_outputs or self._cache is None:
            reuse_cache = False
            outs = self.generator(
                text=text,
                text_lengths=text_lengths,
                feats=feats,
                feats_lengths=feats_lengths,
                sids=sids,
                spembs=spembs,
                lids=lids,
            )
        else:
            outs = self._cache

        # store cache
        if self.training and self.cache_generator_outputs and not reuse_cache:
            self._cache = outs

        # parse outputs
        speech_hat_, dur_nll, _, start_idxs, _, z_mask, outs_ = outs
        _, z_p, m_p, logs_p, _, logs_q = outs_
        speech_ = get_segments(
            x=speech,
            start_idxs=start_idxs * self.generator.upsample_factor,
            segment_size=self.generator.segment_size * self.generator.upsample_factor,
        )

        # calculate discriminator outputs
        p_hat = self.discriminator(speech_hat_)
        with torch.no_grad():
            # do not store discriminator gradient in generator turn
            p = self.discriminator(speech_)

        # calculate losses
        with autocast(enabled=False):
            if not return_sample:
                mel_loss = self.mel_loss(speech_hat_, speech_)
            else:
                mel_loss, (mel_hat_, mel_) = self.mel_loss(
                    speech_hat_, speech_, return_mel=True
                )
            kl_loss = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask)
            dur_loss = torch.sum(dur_nll.float())
            adv_loss = self.generator_adv_loss(p_hat)
            feat_match_loss = self.feat_match_loss(p_hat, p)

            mel_loss = mel_loss * self.lambda_mel
            kl_loss = kl_loss * self.lambda_kl
            dur_loss = dur_loss * self.lambda_dur
            adv_loss = adv_loss * self.lambda_adv
            feat_match_loss = feat_match_loss * self.lambda_feat_match
            loss = mel_loss + kl_loss + dur_loss + adv_loss + feat_match_loss

        stats = dict(
            generator_loss=loss.item(),
            generator_mel_loss=mel_loss.item(),
            generator_kl_loss=kl_loss.item(),
            generator_dur_loss=dur_loss.item(),
            generator_adv_loss=adv_loss.item(),
            generator_feat_match_loss=feat_match_loss.item(),
        )

        if return_sample:
            stats["returned_sample"] = (
                speech_hat_[0].data.cpu().numpy(),
                speech_[0].data.cpu().numpy(),
                mel_hat_[0].data.cpu().numpy(),
                mel_[0].data.cpu().numpy(),
            )

        # reset cache
        if reuse_cache or not self.training:
            self._cache = None

        return loss, stats

    def _forward_discriminator(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        feats: torch.Tensor,
        feats_lengths: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        sids: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """Perform discriminator forward.

        Args:
            text (Tensor): Text index tensor (B, T_text).
            text_lengths (Tensor): Text length tensor (B,).
            feats (Tensor): Feature tensor (B, T_feats, aux_channels).
            feats_lengths (Tensor): Feature length tensor (B,).
            speech (Tensor): Speech waveform tensor (B, T_wav).
            speech_lengths (Tensor): Speech length tensor (B,).
            sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
            spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
            lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).

        Returns:
            * loss (Tensor): Loss scalar tensor.
            * stats (Dict[str, float]): Statistics to be monitored.
        """
        # setup
        feats = feats.transpose(1, 2)
        speech = speech.unsqueeze(1)

        # calculate generator outputs
        reuse_cache = True
        if not self.cache_generator_outputs or self._cache is None:
            reuse_cache = False
            outs = self.generator(
                text=text,
                text_lengths=text_lengths,
                feats=feats,
                feats_lengths=feats_lengths,
                sids=sids,
                spembs=spembs,
                lids=lids,
            )
        else:
            outs = self._cache

        # store cache
        if self.cache_generator_outputs and not reuse_cache:
            self._cache = outs

        # parse outputs
        speech_hat_, _, _, start_idxs, *_ = outs
        speech_ = get_segments(
            x=speech,
            start_idxs=start_idxs * self.generator.upsample_factor,
            segment_size=self.generator.segment_size * self.generator.upsample_factor,
        )

        # calculate discriminator outputs
        p_hat = self.discriminator(speech_hat_.detach())
        p = self.discriminator(speech_)

        # calculate losses
        with autocast(enabled=False):
            real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p)
            loss = real_loss + fake_loss

        stats = dict(
            discriminator_loss=loss.item(),
            discriminator_real_loss=real_loss.item(),
            discriminator_fake_loss=fake_loss.item(),
        )

        # reset cache
        if reuse_cache or not self.training:
            self._cache = None

        return loss, stats

    def inference(
        self,
        text: torch.Tensor,
        feats: Optional[torch.Tensor] = None,
        sids: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
        durations: Optional[torch.Tensor] = None,
        noise_scale: float = 0.667,
        noise_scale_dur: float = 0.8,
        alpha: float = 1.0,
        max_len: Optional[int] = None,
        use_teacher_forcing: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run inference for a single sample.

        Args:
            text (Tensor): Input text index tensor (T_text,).
            feats (Tensor): Feature tensor (T_feats, aux_channels).
            sids (Tensor): Speaker index tensor (1,).
            spembs (Optional[Tensor]): Speaker embedding tensor (spk_embed_dim,).
            lids (Tensor): Language index tensor (1,).
            durations (Tensor): Ground-truth duration tensor (T_text,).
            noise_scale (float): Noise scale value for flow.
            noise_scale_dur (float): Noise scale value for duration predictor.
            alpha (float): Alpha parameter to control the speed of generated speech.
            max_len (Optional[int]): Maximum length.
            use_teacher_forcing (bool): Whether to use teacher forcing.

        Returns:
            * wav (Tensor): Generated waveform tensor (T_wav,).
            * att_w (Tensor): Monotonic attention weight tensor (T_feats, T_text).
            * duration (Tensor): Predicted duration tensor (T_text,).
        """
        # setup
        text = text[None]
        text_lengths = torch.tensor(
            [text.size(1)],
            dtype=torch.long,
            device=text.device,
        )
        if sids is not None:
            sids = sids.view(1)
        if lids is not None:
            lids = lids.view(1)
        if durations is not None:
            durations = durations.view(1, 1, -1)

        # inference
        if use_teacher_forcing:
            assert feats is not None
            feats = feats[None].transpose(1, 2)
            feats_lengths = torch.tensor(
                [feats.size(2)],
                dtype=torch.long,
                device=feats.device,
            )
            wav, att_w, dur = self.generator.inference(
                text=text,
                text_lengths=text_lengths,
                feats=feats,
                feats_lengths=feats_lengths,
                sids=sids,
                spembs=spembs,
                lids=lids,
                max_len=max_len,
                use_teacher_forcing=use_teacher_forcing,
            )
        else:
            wav, att_w, dur = self.generator.inference(
                text=text,
                text_lengths=text_lengths,
                sids=sids,
                spembs=spembs,
                lids=lids,
                dur=durations,
                noise_scale=noise_scale,
                noise_scale_dur=noise_scale_dur,
                alpha=alpha,
                max_len=max_len,
            )
        return wav.view(-1), att_w[0], dur[0]

    def inference_batch(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        sids: Optional[torch.Tensor] = None,
        durations: Optional[torch.Tensor] = None,
        noise_scale: float = 0.667,
        noise_scale_dur: float = 0.8,
        alpha: float = 1.0,
        max_len: Optional[int] = None,
        use_teacher_forcing: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run inference for one batch.

        Args:
            text (Tensor): Input text index tensor (B, T_text).
            text_lengths (Tensor): Input text length tensor (B,).
            sids (Tensor): Speaker index tensor (B,).
            noise_scale (float): Noise scale value for flow.
            noise_scale_dur (float): Noise scale value for duration predictor.
            alpha (float): Alpha parameter to control the speed of generated speech.
            max_len (Optional[int]): Maximum length.

        Returns:
            * wav (Tensor): Generated waveform tensor (B, T_wav).
            * att_w (Tensor): Monotonic attention weight tensor (B, T_feats, T_text).
            * duration (Tensor): Predicted duration tensor (B, T_text).
        """
        # inference
        wav, att_w, dur = self.generator.inference(
            text=text,
            text_lengths=text_lengths,
            sids=sids,
            noise_scale=noise_scale,
            noise_scale_dur=noise_scale_dur,
            alpha=alpha,
            max_len=max_len,
        )
        return wav, att_w, dur
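# Editor's usage sketch (not part of the original file): one training
# iteration alternates a generator step (forward_generator=True) and a
# discriminator step (forward_generator=False), each driving its own
# optimizer. The names `batch`, `optimizer_g`, and `optimizer_d` below are
# placeholders; the actual loop lives in the recipe's training script.
#
#     model = VITS(vocab_size=..., feature_dim=513)
#
#     loss_g, stats_g = model(**batch, forward_generator=True)
#     optimizer_g.zero_grad()
#     loss_g.backward()
#     optimizer_g.step()
#
#     loss_d, stats_d = model(**batch, forward_generator=False)
#     optimizer_d.zero_grad()
#     loss_d.backward()
#     optimizer_d.step()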
@@ -1,349 +0,0 @@
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/wavenet/wavenet.py

# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""WaveNet modules.

This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.

"""

import logging
import math
from typing import Optional, Tuple

import torch
import torch.nn.functional as F


class WaveNet(torch.nn.Module):
    """WaveNet with global conditioning."""

    def __init__(
        self,
        in_channels: int = 1,
        out_channels: int = 1,
        kernel_size: int = 3,
        layers: int = 30,
        stacks: int = 3,
        base_dilation: int = 2,
        residual_channels: int = 64,
        aux_channels: int = -1,
        gate_channels: int = 128,
        skip_channels: int = 64,
        global_channels: int = -1,
        dropout_rate: float = 0.0,
        bias: bool = True,
        use_weight_norm: bool = True,
        use_first_conv: bool = False,
        use_last_conv: bool = False,
        scale_residual: bool = False,
        scale_skip_connect: bool = False,
    ):
        """Initialize WaveNet module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Kernel size of dilated convolution.
            layers (int): Number of residual block layers.
            stacks (int): Number of stacks, i.e., dilation cycles.
            base_dilation (int): Base dilation factor.
            residual_channels (int): Number of channels in residual conv.
            gate_channels (int): Number of channels in gated conv.
            skip_channels (int): Number of channels in skip conv.
            aux_channels (int): Number of channels for local conditioning feature.
            global_channels (int): Number of channels for global conditioning feature.
            dropout_rate (float): Dropout rate. 0.0 means no dropout applied.
            bias (bool): Whether to use bias parameter in conv layer.
            use_weight_norm (bool): Whether to use weight norm. If set to true, it will
                be applied to all of the conv layers.
            use_first_conv (bool): Whether to use the first conv layers.
            use_last_conv (bool): Whether to use the last conv layers.
            scale_residual (bool): Whether to scale the residual outputs.
            scale_skip_connect (bool): Whether to scale the skip connection outputs.

        """
        super().__init__()
        self.layers = layers
        self.stacks = stacks
        self.kernel_size = kernel_size
        self.base_dilation = base_dilation
        self.use_first_conv = use_first_conv
        self.use_last_conv = use_last_conv
        self.scale_skip_connect = scale_skip_connect

        # check the number of layers and stacks
        assert layers % stacks == 0
        layers_per_stack = layers // stacks

        # define first convolution
        if self.use_first_conv:
            self.first_conv = Conv1d1x1(in_channels, residual_channels, bias=True)

        # define residual blocks
        self.conv_layers = torch.nn.ModuleList()
        for layer in range(layers):
            dilation = base_dilation ** (layer % layers_per_stack)
            conv = ResidualBlock(
                kernel_size=kernel_size,
                residual_channels=residual_channels,
                gate_channels=gate_channels,
                skip_channels=skip_channels,
                aux_channels=aux_channels,
                global_channels=global_channels,
                dilation=dilation,
                dropout_rate=dropout_rate,
                bias=bias,
                scale_residual=scale_residual,
            )
            self.conv_layers += [conv]

        # define output layers
        if self.use_last_conv:
            self.last_conv = torch.nn.Sequential(
                torch.nn.ReLU(inplace=True),
                Conv1d1x1(skip_channels, skip_channels, bias=True),
                torch.nn.ReLU(inplace=True),
                Conv1d1x1(skip_channels, out_channels, bias=True),
            )

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()

    def forward(
        self,
        x: torch.Tensor,
        x_mask: Optional[torch.Tensor] = None,
        c: Optional[torch.Tensor] = None,
        g: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, 1, T) if use_first_conv else
                (B, residual_channels, T).
            x_mask (Optional[Tensor]): Mask tensor (B, 1, T).
            c (Optional[Tensor]): Local conditioning features (B, aux_channels, T).
            g (Optional[Tensor]): Global conditioning features (B, global_channels, 1).

        Returns:
            Tensor: Output tensor (B, out_channels, T) if use_last_conv else
                (B, residual_channels, T).

        """
        # encode to hidden representation
        if self.use_first_conv:
            x = self.first_conv(x)

        # residual block
        skips = 0.0
        for f in self.conv_layers:
            x, h = f(x, x_mask=x_mask, c=c, g=g)
            skips = skips + h
        x = skips
        if self.scale_skip_connect:
            x = x * math.sqrt(1.0 / len(self.conv_layers))

        # apply final layers
        if self.use_last_conv:
            x = self.last_conv(x)

        return x

    def remove_weight_norm(self):
        """Remove weight normalization from all of the layers."""

        def _remove_weight_norm(m: torch.nn.Module):
            try:
                logging.debug(f"Weight norm is removed from {m}.")
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization to all of the layers."""

        def _apply_weight_norm(m: torch.nn.Module):
            if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
                torch.nn.utils.weight_norm(m)
                logging.debug(f"Weight norm is applied to {m}.")

        self.apply(_apply_weight_norm)

    @staticmethod
    def _get_receptive_field_size(
        layers: int,
        stacks: int,
        kernel_size: int,
        base_dilation: int,
    ) -> int:
        assert layers % stacks == 0
        layers_per_cycle = layers // stacks
        dilations = [base_dilation ** (i % layers_per_cycle) for i in range(layers)]
        return (kernel_size - 1) * sum(dilations) + 1

    @property
    def receptive_field_size(self) -> int:
        """Return receptive field size."""
        return self._get_receptive_field_size(
            self.layers, self.stacks, self.kernel_size, self.base_dilation
        )


class Conv1d(torch.nn.Conv1d):
    """Conv1d module with customized initialization."""

    def __init__(self, *args, **kwargs):
        """Initialize Conv1d module."""
        super().__init__(*args, **kwargs)

    def reset_parameters(self):
        """Reset parameters."""
        torch.nn.init.kaiming_normal_(self.weight, nonlinearity="relu")
        if self.bias is not None:
            torch.nn.init.constant_(self.bias, 0.0)


class Conv1d1x1(Conv1d):
    """1x1 Conv1d with customized initialization."""

    def __init__(self, in_channels: int, out_channels: int, bias: bool):
        """Initialize 1x1 Conv1d module."""
        super().__init__(
            in_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=bias
        )


class ResidualBlock(torch.nn.Module):
    """Residual block module in WaveNet."""

    def __init__(
        self,
        kernel_size: int = 3,
        residual_channels: int = 64,
        gate_channels: int = 128,
        skip_channels: int = 64,
        aux_channels: int = 80,
        global_channels: int = -1,
        dropout_rate: float = 0.0,
        dilation: int = 1,
        bias: bool = True,
        scale_residual: bool = False,
    ):
        """Initialize ResidualBlock module.

        Args:
            kernel_size (int): Kernel size of dilation convolution layer.
            residual_channels (int): Number of channels for residual connection.
            gate_channels (int): Number of channels for gated activation.
            skip_channels (int): Number of channels for skip connection.
            aux_channels (int): Number of local conditioning channels.
            global_channels (int): Number of global conditioning channels.
            dropout_rate (float): Dropout probability.
            dilation (int): Dilation factor.
            bias (bool): Whether to add bias parameter in convolution layers.
            scale_residual (bool): Whether to scale the residual outputs.

        """
        super().__init__()
        self.dropout_rate = dropout_rate
        self.residual_channels = residual_channels
        self.skip_channels = skip_channels
        self.scale_residual = scale_residual

        # check
        assert (kernel_size - 1) % 2 == 0, "Even kernel sizes are not supported."
        assert gate_channels % 2 == 0

        # dilation conv
        padding = (kernel_size - 1) // 2 * dilation
        self.conv = Conv1d(
            residual_channels,
            gate_channels,
            kernel_size,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )

        # local conditioning
        if aux_channels > 0:
            self.conv1x1_aux = Conv1d1x1(aux_channels, gate_channels, bias=False)
        else:
            self.conv1x1_aux = None

        # global conditioning
        if global_channels > 0:
            self.conv1x1_glo = Conv1d1x1(global_channels, gate_channels, bias=False)
        else:
            self.conv1x1_glo = None

        # conv output is split into two groups
        gate_out_channels = gate_channels // 2

        # NOTE(kan-bayashi): concat two convs into a single conv for the efficiency
        # (integrate res 1x1 + skip 1x1 convs)
        self.conv1x1_out = Conv1d1x1(
            gate_out_channels, residual_channels + skip_channels, bias=bias
        )

    def forward(
        self,
        x: torch.Tensor,
        x_mask: Optional[torch.Tensor] = None,
        c: Optional[torch.Tensor] = None,
        g: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, residual_channels, T).
            x_mask (Optional[Tensor]): Mask tensor (B, 1, T).
            c (Optional[Tensor]): Local conditioning tensor (B, aux_channels, T).
            g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).

        Returns:
            Tensor: Output tensor for residual connection (B, residual_channels, T).
            Tensor: Output tensor for skip connection (B, skip_channels, T).

        """
        residual = x
        x = F.dropout(x, p=self.dropout_rate, training=self.training)
        x = self.conv(x)

        # split into two parts for gated activation
        splitdim = 1
        xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim)

        # local conditioning
        if c is not None:
            c = self.conv1x1_aux(c)
            ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim)
            xa, xb = xa + ca, xb + cb

        # global conditioning
        if g is not None:
            g = self.conv1x1_glo(g)
            ga, gb = g.split(g.size(splitdim) // 2, dim=splitdim)
            xa, xb = xa + ga, xb + gb

        x = torch.tanh(xa) * torch.sigmoid(xb)

        # residual + skip 1x1 conv
        x = self.conv1x1_out(x)
        if x_mask is not None:
            x = x * x_mask

        # split integrated conv results
        x, s = x.split([self.residual_channels, self.skip_channels], dim=1)

        # for residual connection
        x = x + residual
        if self.scale_residual:
            x = x * math.sqrt(0.5)

        return x, s
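# Editor's usage sketch (not part of the original file): an unconditioned
# WaveNet without first/last convolutions, so input and output both live in
# residual/skip channel space. With skip_channels == residual_channels the
# output shape matches the input shape.
if __name__ == "__main__":
    net = WaveNet(
        layers=4,
        stacks=2,
        residual_channels=8,
        gate_channels=16,
        skip_channels=8,
        aux_channels=-1,
        use_first_conv=False,
        use_last_conv=False,
    )
    x = torch.randn(3, 8, 50)  # (B, residual_channels, T)
    y = net(x)
    assert y.shape == (3, 8, 50)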