mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-06 07:34:21 +00:00)

modify training script, clean codes

This commit is contained in:
parent 8d09f8e6bf
commit 04c6ecbaa1
@ -1,80 +0,0 @@
#!/usr/bin/env python3
# Copyright      2023  Xiaomi Corp.        (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script splits the LJSpeech dataset cuts into three sets:
  - training, 12500
  - validation, 100
  - test, 500
The numbers are from https://arxiv.org/pdf/2106.06103.pdf

Usage example:
  python3 ./local/split_subsets.py ./data/spectrogram
"""


import argparse
import logging
import random
from pathlib import Path

from lhotse import load_manifest_lazy


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "manifest_dir",
        type=Path,
        default=Path("data/spectrogram"),
        help="Path to the manifest file",
    )

    return parser.parse_args()


def main():
    args = get_args()

    manifest_dir = Path(args.manifest_dir)
    prefix = "ljspeech"
    suffix = "jsonl.gz"
    # all_cuts = load_manifest_lazy(manifest_dir / f"{prefix}_cuts_all.{suffix}")
    all_cuts = load_manifest_lazy(
        manifest_dir / f"{prefix}_cuts_all_phonemized.{suffix}"
    )

    cut_ids = list(all_cuts.ids)
    random.shuffle(cut_ids)

    train_cuts = all_cuts.subset(cut_ids=cut_ids[:12500])
    valid_cuts = all_cuts.subset(cut_ids=cut_ids[12500 : 12500 + 100])
    test_cuts = all_cuts.subset(cut_ids=cut_ids[12500 + 100 :])
    assert len(train_cuts) == 12500, f"expected 12500 cuts for training but got {len(train_cuts)}"
    assert len(valid_cuts) == 100, f"expected 100 cuts for validation but got {len(valid_cuts)}"
    assert len(test_cuts) == 500, f"expected 500 cuts for test but got {len(test_cuts)}"

    train_cuts.to_file(manifest_dir / f"{prefix}_cuts_train.{suffix}")
    valid_cuts.to_file(manifest_dir / f"{prefix}_cuts_valid.{suffix}")
    test_cuts.to_file(manifest_dir / f"{prefix}_cuts_test.{suffix}")

    logging.info("Split into three sets: training (12500), validation (100), and test (500)")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)

    main()
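Aside (not part of the commit): random.shuffle above is unseeded, so the train/valid/test split changes on every run. A minimal sketch of a reproducible variant, with an assumed fixed seed:

import random

def shuffled_cut_ids(all_cuts, seed=42):
    # Deterministically shuffle the cut ids; the seed value is an assumption,
    # any fixed integer makes the split reproducible across runs.
    rng = random.Random(seed)
    cut_ids = list(all_cuts.ids)
    rng.shuffle(cut_ids)
    return cut_ids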
@ -9,8 +9,7 @@ nj=1
stage=-1
stop_stage=100

# dl_dir=$PWD/download
dl_dir=/star-data/zengwei/download/ljspeech/
dl_dir=$PWD/download

. shared/parse_options.sh || exit 1

@ -66,22 +65,6 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
  fi
fi

# if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
#   log "Stage 3: Phonemize the transcripts for LJSpeech"
#   if [ ! -e data/spectrogram/.ljspeech_phonemized.done ]; then
#     ./local/phonemize_text.py data/spectrogram
#     touch data/spectrogram/.ljspeech_phonemized.done
#   fi
# fi

# if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
#   log "Stage 4: Split the LJSpeech cuts into three sets"
#   if [ ! -e data/spectrogram/.ljspeech_split.done ]; then
#     ./local/split_subsets.py data/spectrogram
#     touch data/spectrogram/.ljspeech_split.done
#   fi
# fi

if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
  log "Stage 3: Split the LJSpeech cuts into train, valid and test sets"
  if [ ! -e data/spectrogram/.ljspeech_split.done ]; then
@ -94,6 +77,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
    lhotse subset --last 500 \
      data/spectrogram/ljspeech_cuts_validtest.jsonl.gz \
      data/spectrogram/ljspeech_cuts_test.jsonl.gz

    rm data/spectrogram/ljspeech_cuts_validtest.jsonl.gz

    n=$(( $(gunzip -c data/spectrogram/ljspeech_cuts_all.jsonl.gz | wc -l) - 600 ))

@ -1,161 +0,0 @@
import math
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


def convert_pad_shape(pad_shape):
    l = pad_shape[::-1]
    pad_shape = [item for sublist in l for item in sublist]
    return pad_shape


def intersperse(lst, item):
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result


def kl_divergence(m_p, logs_p, m_q, logs_q):
    """KL(P||Q)"""
    kl = (logs_q - logs_p) - 0.5
    kl += 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
    return kl
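A quick numerical check (illustrative, not in the commit) that kl_divergence above matches torch.distributions for Gaussians parameterized by mean and log-stddev:

import torch
import torch.distributions as D

m_p, logs_p = torch.tensor(0.3), torch.tensor(-0.2)
m_q, logs_q = torch.tensor(-0.1), torch.tensor(0.4)
expected = D.kl_divergence(D.Normal(m_p, logs_p.exp()), D.Normal(m_q, logs_q.exp()))
# kl_divergence is the function defined above
assert torch.allclose(expected, kl_divergence(m_p, logs_p, m_q, logs_q))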
def rand_gumbel(shape):
    """Sample from the Gumbel distribution, protect from overflows."""
    uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
    return -torch.log(-torch.log(uniform_samples))


def rand_gumbel_like(x):
    g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
    return g


def slice_segments(x, ids_str, segment_size=4):
    ret = torch.zeros_like(x[:, :, :segment_size])
    for i in range(x.size(0)):
        idx_str = ids_str[i]
        idx_end = idx_str + segment_size
        ret[i] = x[i, :, idx_str:idx_end]
    return ret


def rand_slice_segments(x, x_lengths=None, segment_size=4):
    b, d, t = x.size()
    if x_lengths is None:
        x_lengths = t
    ids_str_max = x_lengths - segment_size + 1
    ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
    ret = slice_segments(x, ids_str, segment_size)
    return ret, ids_str


def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    position = torch.arange(length, dtype=torch.float)
    num_timescales = channels // 2
    log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
        num_timescales - 1
    )
    inv_timescales = min_timescale * torch.exp(
        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
    )
    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
    signal = F.pad(signal, [0, 0, 0, channels % 2])
    signal = signal.view(1, channels, length)
    return signal


def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
    b, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return x + signal.to(dtype=x.dtype, device=x.device)


def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
    b, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)


def subsequent_mask(length):
    mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
    return mask


@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts


def convert_pad_shape(pad_shape):
    l = pad_shape[::-1]
    pad_shape = [item for sublist in l for item in sublist]
    return pad_shape


def shift_1d(x):
    x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
    return x


def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)


def generate_path(duration, mask):
    """
    duration: [b, 1, t_x]
    mask: [b, 1, t_y, t_x]
    """
    device = duration.device

    b, _, t_y, t_x = mask.shape
    cum_duration = torch.cumsum(duration, -1)

    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path.unsqueeze(1).transpose(2, 3) * mask
    return path


def clip_grad_value_(parameters, clip_value, norm_type=2):
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    if clip_value is not None:
        clip_value = float(clip_value)

    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
        if clip_value is not None:
            p.grad.data.clamp_(min=-clip_value, max=clip_value)
    total_norm = total_norm ** (1.0 / norm_type)
    return total_norm
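Usage sketch (illustrative, not in the commit) for the segment-slicing helpers above; the [batch, channels, frames] shapes follow the convention used throughout this file:

import torch

x = torch.randn(2, 192, 50)         # 2 items, 192 channels, 50 frames
x_lengths = torch.tensor([50, 30])  # valid frames per item

# Each item gets a random window of segment_size frames,
# starting no later than x_lengths - segment_size.
segments, ids_str = rand_slice_segments(x, x_lengths, segment_size=8)
assert segments.shape == (2, 192, 8)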
@ -1,4 +1,4 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/duration_predictor.py
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/duration_predictor.py

# Copyright 2021 Tomoki Hayashi
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

@ -1,416 +0,0 @@
# Copyright      2023  Xiaomi Corp.        (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Any, Dict, Optional, Tuple

import librosa
import numpy as np
import torch
from torch import nn

from icefall.utils import make_pad_mask


# From https://github.com/espnet/espnet/blob/master/espnet2/layers/stft.py
class Stft(nn.Module):
    def __init__(
        self,
        n_fft: int = 512,
        win_length: int = None,
        hop_length: int = 128,
        window: Optional[str] = "hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
    ):
        super().__init__()
        self.n_fft = n_fft
        if win_length is None:
            self.win_length = n_fft
        else:
            self.win_length = win_length
        self.hop_length = hop_length
        self.center = center
        self.normalized = normalized
        self.onesided = onesided
        if window is not None and not hasattr(torch, f"{window}_window"):
            raise ValueError(f"{window} window is not implemented")
        self.window = window

    def extra_repr(self):
        return (
            f"n_fft={self.n_fft}, "
            f"win_length={self.win_length}, "
            f"hop_length={self.hop_length}, "
            f"center={self.center}, "
            f"normalized={self.normalized}, "
            f"onesided={self.onesided}"
        )

    def forward(
        self, input: torch.Tensor, ilens: torch.Tensor = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """STFT forward function.

        Args:
            input: (Batch, Nsamples) or (Batch, Nsample, Channels)
            ilens: (Batch)
        Returns:
            output: (Batch, Frames, Freq, 2) or (Batch, Frames, Channels, Freq, 2)

        """
        bs = input.size(0)
        if input.dim() == 3:
            multi_channel = True
            # input: (Batch, Nsample, Channels) -> (Batch * Channels, Nsample)
            input = input.transpose(1, 2).reshape(-1, input.size(1))
        else:
            multi_channel = False

        # NOTE(kamo):
        # The default behaviour of torch.stft is compatible with librosa.stft
        # about padding and scaling.
        # Note that it's different from scipy.signal.stft

        # output: (Batch, Freq, Frames, 2=real_imag)
        # or (Batch, Channel, Freq, Frames, 2=real_imag)
        if self.window is not None:
            window_func = getattr(torch, f"{self.window}_window")
            window = window_func(
                self.win_length, dtype=input.dtype, device=input.device
            )
        else:
            window = None

        # For the compatibility of ARM devices, which do not support
        # torch.stft() due to the lack of MKL (on older pytorch versions),
        # there is an alternative replacement implementation with librosa.
        # Note: pytorch >= 1.10.0 now has native support for FFT and STFT
        # on all cpu targets including ARM.
        if input.is_cuda or torch.backends.mkl.is_available():
            stft_kwargs = dict(
                n_fft=self.n_fft,
                win_length=self.win_length,
                hop_length=self.hop_length,
                center=self.center,
                window=window,
                normalized=self.normalized,
                onesided=self.onesided,
            )
            stft_kwargs["return_complex"] = True
            output = torch.stft(input, **stft_kwargs)
            output = torch.view_as_real(output)
        else:
            if self.training:
                raise NotImplementedError(
                    "stft is implemented with librosa on this device, which does not "
                    "support the training mode."
                )

            # use stft_kwargs to flexibly control different PyTorch versions' kwargs
            # note: librosa does not support a win_length that is < n_fft
            # but the window can be manually padded (see below).
            stft_kwargs = dict(
                n_fft=self.n_fft,
                win_length=self.n_fft,
                hop_length=self.hop_length,
                center=self.center,
                window=window,
                pad_mode="reflect",
            )

            if window is not None:
                # pad the given window to n_fft
                n_pad_left = (self.n_fft - window.shape[0]) // 2
                n_pad_right = self.n_fft - window.shape[0] - n_pad_left
                stft_kwargs["window"] = torch.cat(
                    [torch.zeros(n_pad_left), window, torch.zeros(n_pad_right)], 0
                ).numpy()
            else:
                win_length = (
                    self.win_length if self.win_length is not None else self.n_fft
                )
                stft_kwargs["window"] = torch.ones(win_length)

            output = []
            # iterate over instances in a batch
            for i, instance in enumerate(input):
                stft = librosa.stft(input[i].numpy(), **stft_kwargs)
                output.append(torch.tensor(np.stack([stft.real, stft.imag], -1)))
            output = torch.stack(output, 0)
            if not self.onesided:
                len_conj = self.n_fft - output.shape[1]
                conj = output[:, 1 : 1 + len_conj].flip(1)
                conj[:, :, :, -1].data *= -1
                output = torch.cat([output, conj], 1)
            if self.normalized:
                output = output * (stft_kwargs["window"].shape[0] ** (-0.5))

        # output: (Batch, Freq, Frames, 2=real_imag)
        # -> (Batch, Frames, Freq, 2=real_imag)
        output = output.transpose(1, 2)
        if multi_channel:
            # output: (Batch * Channel, Frames, Freq, 2=real_imag)
            # -> (Batch, Frame, Channel, Freq, 2=real_imag)
            output = output.view(bs, -1, output.size(1), output.size(2), 2).transpose(
                1, 2
            )

        if ilens is not None:
            if self.center:
                pad = self.n_fft // 2
                ilens = ilens + 2 * pad

            olens = (
                torch.div(ilens - self.n_fft, self.hop_length, rounding_mode="trunc")
                + 1
            )
            output.masked_fill_(make_pad_mask(olens), 0.0)
        else:
            olens = None

        return output, olens

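A minimal usage sketch for the Stft module above (illustrative; the sample count is made up):

import torch

stft = Stft(n_fft=1024, hop_length=256)
wav = torch.randn(2, 16000)  # (Batch, Nsamples)
spec, olens = stft(wav)      # olens is None when no ilens is passed
# spec: (Batch, Frames, Freq, 2); the last dim holds real/imag parts,
# here (2, 63, 513, 2) with center padding.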
# From https://github.com/espnet/espnet/blob/master/espnet2/tts/feats_extract/linear_spectrogram.py
class LinearSpectrogram(nn.Module):
    """Linear amplitude spectrogram.

    Stft -> amplitude-spec
    """

    def __init__(
        self,
        n_fft: int = 1024,
        win_length: int = None,
        hop_length: int = 256,
        window: Optional[str] = "hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
    ):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.stft = Stft(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window=window,
            center=center,
            normalized=normalized,
            onesided=onesided,
        )
        self.n_fft = n_fft

    def output_size(self) -> int:
        return self.n_fft // 2 + 1

    def get_parameters(self) -> Dict[str, Any]:
        """Return the parameters required by Vocoder."""
        return dict(
            n_fft=self.n_fft,
            n_shift=self.hop_length,
            win_length=self.win_length,
            window=self.window,
        )

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # 1. Stft: time -> time-freq
        input_stft, feats_lens = self.stft(input, input_lengths)

        assert input_stft.dim() >= 4, input_stft.shape
        # "2" refers to the real/imag parts of Complex
        assert input_stft.shape[-1] == 2, input_stft.shape

        # STFT -> Power spectrum -> Amp spectrum
        # input_stft: (..., F, 2) -> (..., F)
        input_power = input_stft[..., 0] ** 2 + input_stft[..., 1] ** 2
        input_amp = torch.sqrt(torch.clamp(input_power, min=1.0e-10))
        return input_amp, feats_lens


# From https://github.com/espnet/espnet/blob/master/espnet2/layers/log_mel.py
class LogMel(nn.Module):
    """Convert STFT to fbank feats

    The arguments are the same as librosa.filters.mel

    Args:
        fs: number > 0 [scalar] sampling rate of the incoming signal
        n_fft: int > 0 [scalar] number of FFT components
        n_mels: int > 0 [scalar] number of Mel bands to generate
        fmin: float >= 0 [scalar] lowest frequency (in Hz)
        fmax: float >= 0 [scalar] highest frequency (in Hz).
            If `None`, use `fmax = fs / 2.0`
        htk: use HTK formula instead of Slaney
    """

    def __init__(
        self,
        fs: int = 16000,
        n_fft: int = 512,
        n_mels: int = 80,
        fmin: float = None,
        fmax: float = None,
        htk: bool = False,
        log_base: float = None,
    ):
        super().__init__()

        fmin = 0 if fmin is None else fmin
        fmax = fs / 2 if fmax is None else fmax
        _mel_options = dict(
            sr=fs,
            n_fft=n_fft,
            n_mels=n_mels,
            fmin=fmin,
            fmax=fmax,
            htk=htk,
        )
        self.mel_options = _mel_options
        self.log_base = log_base

        # Note(kamo): The mel matrix of librosa is different from kaldi.
        melmat = librosa.filters.mel(**_mel_options)
        # melmat: (D2, D1) -> (D1, D2)
        self.register_buffer("melmat", torch.from_numpy(melmat.T).float())

    def extra_repr(self):
        return ", ".join(f"{k}={v}" for k, v in self.mel_options.items())

    def forward(
        self,
        feat: torch.Tensor,
        ilens: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
        mel_feat = torch.matmul(feat, self.melmat)
        mel_feat = torch.clamp(mel_feat, min=1e-10)

        if self.log_base is None:
            logmel_feat = mel_feat.log()
        elif self.log_base == 2.0:
            logmel_feat = mel_feat.log2()
        elif self.log_base == 10.0:
            logmel_feat = mel_feat.log10()
        else:
            logmel_feat = mel_feat.log() / torch.log(self.log_base)

        # Zero padding
        if ilens is not None:
            logmel_feat = logmel_feat.masked_fill(make_pad_mask(ilens), 0.0)
        else:
            ilens = feat.new_full(
                [feat.size(0)], fill_value=feat.size(1), dtype=torch.long
            )
        return logmel_feat, ilens


# From https://github.com/espnet/espnet/blob/master/espnet2/tts/feats_extract/log_mel_fbank.py
class LogMelFbank(nn.Module):
    """Conventional frontend structure for TTS.

    Stft -> amplitude-spec -> Log-Mel-Fbank
    """

    def __init__(
        self,
        fs: int = 16000,
        n_fft: int = 1024,
        win_length: int = None,
        hop_length: int = 256,
        window: Optional[str] = "hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
        n_mels: int = 80,
        fmin: Optional[int] = 80,
        fmax: Optional[int] = 7600,
        htk: bool = False,
        log_base: Optional[float] = 10.0,
    ):
        super().__init__()

        self.fs = fs
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.fmin = fmin
        self.fmax = fmax

        self.stft = Stft(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window=window,
            center=center,
            normalized=normalized,
            onesided=onesided,
        )

        self.logmel = LogMel(
            fs=fs,
            n_fft=n_fft,
            n_mels=n_mels,
            fmin=fmin,
            fmax=fmax,
            htk=htk,
            log_base=log_base,
        )

    def output_size(self) -> int:
        return self.n_mels

    def get_parameters(self) -> Dict[str, Any]:
        """Return the parameters required by Vocoder"""
        return dict(
            fs=self.fs,
            n_fft=self.n_fft,
            n_shift=self.hop_length,
            window=self.window,
            n_mels=self.n_mels,
            win_length=self.win_length,
            fmin=self.fmin,
            fmax=self.fmax,
        )

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # 1. Domain-conversion: e.g. Stft: time -> time-freq
        input_stft, feats_lens = self.stft(input, input_lengths)

        assert input_stft.dim() >= 4, input_stft.shape
        # "2" refers to the real/imag parts of Complex
        assert input_stft.shape[-1] == 2, input_stft.shape

        # NOTE(kamo): We use different definition for log-spec between TTS and ASR
        #   TTS: log_10(abs(stft))
        #   ASR: log_e(power(stft))

        # input_stft: (..., F, 2) -> (..., F)
        input_power = input_stft[..., 0] ** 2 + input_stft[..., 1] ** 2
        input_amp = torch.sqrt(torch.clamp(input_power, min=1.0e-10))
        input_feats, _ = self.logmel(input_amp, feats_lens)
        return input_feats, feats_lens
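A corresponding usage sketch for LogMelFbank (illustrative; 22050 Hz matches LJSpeech audio, but these values are assumptions, not the recipe's config):

import torch

fbank = LogMelFbank(fs=22050, n_fft=1024, hop_length=256, n_mels=80)
wav = torch.randn(2, 22050)   # one second of audio per item
feats, feats_lens = fbank(wav)
# feats: (Batch, Frames, n_mels) log10 mel features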
@ -1,4 +1,5 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/flow.py
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/flow.py

# Copyright 2021 Tomoki Hayashi
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

@ -1,4 +1,4 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/generator.py
# based on https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/generator.py

# Copyright 2021 Tomoki Hayashi
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

@ -1,4 +1,4 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/hifigan.py
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/hifigan.py

# Copyright 2021 Tomoki Hayashi
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

@ -1,7 +1,6 @@
#!/usr/bin/env python3
#
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
#                                                 Zengwei Yao)
# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -17,118 +16,34 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script performs model inference on the test set.

Usage:
(1) greedy search
./zipformer/decode.py \
    --epoch 28 \
    --avg 15 \
./vits/infer.py \
    --epoch 1000 \
    --exp-dir ./zipformer/exp \
    --max-duration 600 \
    --decoding-method greedy_search

(2) beam search (not recommended)
./zipformer/decode.py \
    --epoch 28 \
    --avg 15 \
    --exp-dir ./zipformer/exp \
    --max-duration 600 \
    --decoding-method beam_search \
    --beam-size 4

(3) modified beam search
./zipformer/decode.py \
    --epoch 28 \
    --avg 15 \
    --exp-dir ./zipformer/exp \
    --max-duration 600 \
    --decoding-method modified_beam_search \
    --beam-size 4

(4) fast beam search (one best)
./zipformer/decode.py \
    --epoch 28 \
    --avg 15 \
    --exp-dir ./zipformer/exp \
    --max-duration 600 \
    --decoding-method fast_beam_search \
    --beam 20.0 \
    --max-contexts 8 \
    --max-states 64

(5) fast beam search (nbest)
./zipformer/decode.py \
    --epoch 28 \
    --avg 15 \
    --exp-dir ./zipformer/exp \
    --max-duration 600 \
    --decoding-method fast_beam_search_nbest \
    --beam 20.0 \
    --max-contexts 8 \
    --max-states 64 \
    --num-paths 200 \
    --nbest-scale 0.5

(6) fast beam search (nbest oracle WER)
./zipformer/decode.py \
    --epoch 28 \
    --avg 15 \
    --exp-dir ./zipformer/exp \
    --max-duration 600 \
    --decoding-method fast_beam_search_nbest_oracle \
    --beam 20.0 \
    --max-contexts 8 \
    --max-states 64 \
    --num-paths 200 \
    --nbest-scale 0.5

(7) fast beam search (with LG)
./zipformer/decode.py \
    --epoch 28 \
    --avg 15 \
    --exp-dir ./zipformer/exp \
    --max-duration 600 \
    --decoding-method fast_beam_search_nbest_LG \
    --beam 20.0 \
    --max-contexts 8 \
    --max-states 64
    --max-duration 500
"""


import argparse
import logging
import math
import os
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from typing import List

import k2
import torch
import torch.nn as nn
import torchaudio

from train import get_model, get_params, prepare_input
from train import get_model, get_params
from tokenizer import Tokenizer

from icefall.checkpoint import (
    average_checkpoints,
    find_checkpoints,
    load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
    AttributeDict,
    make_pad_mask,
    setup_logger,
    store_transcripts,
    str2bool,
    write_error_stats,
)
from icefall.checkpoint import load_checkpoint
from icefall.utils import AttributeDict, setup_logger
from tts_datamodule import LJSpeechTtsDataModule

LOG_EPS = math.log(1e-10)


def get_parser():
    parser = argparse.ArgumentParser(
@ -138,35 +53,16 @@ def get_parser():
    parser.add_argument(
        "--epoch",
        type=int,
        default=30,
        default=1000,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="zipformer/exp",
        default="vits/exp",
        help="The experiment dir",
    )

@ -174,7 +70,7 @@ def get_parser():
        "--tokens",
        type=str,
        default="data/tokens.txt",
        help="""Path to tokens.txt.""",
        help="""Path to vocabulary.""",
    )

    return parser
@ -185,8 +81,9 @@ def infer_dataset(
    params: AttributeDict,
    model: nn.Module,
    tokenizer: Tokenizer,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
) -> None:
    """Decode dataset.
    The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`.

    Args:
      dl:
@ -195,20 +92,8 @@ def infer_dataset(
        It is returned by :func:`get_params`.
      model:
        The neural model.
      sp:
        The BPE model.
      word_table:
        The word symbol table.
      decoding_graph:
        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
        only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
        fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
    Returns:
      Return a dict, whose key may be "greedy_search" if greedy search
      is used, or it may be "beam_7" if beam size of 7 is used.
      Its value is a list of tuples. Each tuple contains two elements:
      The first is the reference transcript, and the second is the
      predicted result.
      tokenizer:
        Used to convert text to phonemes.
    """
    # Background worker that saves audio to disk.
    def _save_worker(
@ -233,7 +118,7 @@ def infer_dataset(

    device = next(model.parameters()).device
    num_cuts = 0
    log_interval = 10
    log_interval = 5

    try:
        num_batches = len(dl)
@ -242,7 +127,6 @@ def infer_dataset(

        futures = []
        with ThreadPoolExecutor(max_workers=1) as executor:
            # We only want one background worker so that serialization is deterministic.
            for batch_idx, batch in enumerate(dl):
                batch_size = len(batch["text"])

@ -253,7 +137,7 @@ def infer_dataset(
                tokens_lens = row_splits[1:] - row_splits[:-1]
                tokens = tokens.to(device)
                tokens_lens = tokens_lens.to(device)
                # a tensor of shape (B, T)
                # tensor of shape (B, T)
                tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)

                audio = batch["audio"]
@ -265,9 +149,6 @@ def infer_dataset(
                # convert to samples
                audio_lens_pred = (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist()

                # import pdb
                # pdb.set_trace()

                futures.append(
                    executor.submit(
                        _save_worker, batch_size, cut_ids, audio, audio_pred, audio_lens, audio_lens_pred
@ -295,10 +176,7 @@ def main():
    params = get_params()
    params.update(vars(args))

    if params.iter > 0:
        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
    else:
        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
    params.suffix = f"epoch-{params.epoch}"

    params.res_dir = params.exp_dir / "infer" / params.suffix
    params.save_wav_dir = params.res_dir / "wav"
@ -322,40 +200,16 @@ def main():
    logging.info("About to create model")
    model = get_model(params)

    if params.iter > 0:
        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
            : params.avg
        ]
        if len(filenames) == 0:
            raise ValueError(
                f"No checkpoints found for"
                f" --iter {params.iter}, --avg {params.avg}"
            )
        elif len(filenames) < params.avg:
            raise ValueError(
                f"Not enough checkpoints ({len(filenames)}) found for"
                f" --iter {params.iter}, --avg {params.avg}"
            )
        logging.info(f"averaging {filenames}")
        model.to(device)
        model.load_state_dict(average_checkpoints(filenames, device=device))
    elif params.avg == 1:
        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    else:
        start = params.epoch - params.avg + 1
        filenames = []
        for i in range(start, params.epoch + 1):
            if i >= 1:
                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
        logging.info(f"averaging {filenames}")
        model.to(device)
        model.load_state_dict(average_checkpoints(filenames, device=device))
    load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)

    model.to(device)
    model.eval()

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")
    num_param_g = sum([p.numel() for p in model.generator.parameters()])
    logging.info(f"Number of parameters in generator: {num_param_g}")
    num_param_d = sum([p.numel() for p in model.discriminator.parameters()])
    logging.info(f"Number of parameters in discriminator: {num_param_d}")
    logging.info(f"Total number of parameters: {num_param_g + num_param_d}")

    # we need cut ids to display recognition results.
    args.return_cuts = True
@ -371,17 +225,8 @@ def main():
        tokenizer=tokenizer,
    )

    # save_results(
    #     params=params,
    #     test_set_name=test_set,
    #     results_dict=results_dict,
    # )

    logging.info("Done!")


# torch.set_num_threads(1)
# torch.set_num_interop_threads(1)

if __name__ == "__main__":
    main()
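A worked example (illustrative) of the duration-to-samples conversion used in infer_dataset above, assuming a 256-sample frame shift:

import torch

frame_shift = 256                            # samples per frame (assumed value)
durations = torch.tensor([[2.0, 3.0, 1.5]])  # predicted frames per token
num_samples = (durations.sum(1) * frame_shift).to(torch.int64)
print(num_samples)  # tensor([1664]): 6.5 frames * 256 samples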
@ -1,4 +1,4 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/loss.py
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/loss.py

# Copyright 2021 Tomoki Hayashi
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
@ -9,7 +9,7 @@ This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.

"""

from typing import List, Optional, Tuple, Union
from typing import List, Tuple, Union

import torch
import torch.distributions as D
@ -266,7 +266,7 @@ class MelSpectrogramLoss(torch.nn.Module):
        return mel_loss


# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/loss.py
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/loss.py

"""VITS-related loss modules.

@ -1,534 +0,0 @@
|
||||
import copy
|
||||
import math
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
import commons
|
||||
import modules
|
||||
import attentions
|
||||
import monotonic_align
|
||||
|
||||
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
|
||||
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
||||
from commons import init_weights, get_padding
|
||||
|
||||
|
||||
class StochasticDurationPredictor(nn.Module):
|
||||
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0):
|
||||
super().__init__()
|
||||
filter_channels = in_channels # it needs to be removed from future version.
|
||||
self.in_channels = in_channels
|
||||
self.filter_channels = filter_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.p_dropout = p_dropout
|
||||
self.n_flows = n_flows
|
||||
self.gin_channels = gin_channels
|
||||
|
||||
self.log_flow = modules.Log()
|
||||
self.flows = nn.ModuleList()
|
||||
self.flows.append(modules.ElementwiseAffine(2))
|
||||
for i in range(n_flows):
|
||||
self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
|
||||
self.flows.append(modules.Flip())
|
||||
|
||||
self.post_pre = nn.Conv1d(1, filter_channels, 1)
|
||||
self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
|
||||
self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
|
||||
self.post_flows = nn.ModuleList()
|
||||
self.post_flows.append(modules.ElementwiseAffine(2))
|
||||
for i in range(4):
|
||||
self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
|
||||
self.post_flows.append(modules.Flip())
|
||||
|
||||
self.pre = nn.Conv1d(in_channels, filter_channels, 1)
|
||||
self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
|
||||
self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
|
||||
if gin_channels != 0:
|
||||
self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
|
||||
|
||||
def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
|
||||
x = torch.detach(x)
|
||||
x = self.pre(x)
|
||||
if g is not None:
|
||||
g = torch.detach(g)
|
||||
x = x + self.cond(g)
|
||||
x = self.convs(x, x_mask)
|
||||
x = self.proj(x) * x_mask
|
||||
|
||||
if not reverse:
|
||||
flows = self.flows
|
||||
assert w is not None
|
||||
|
||||
logdet_tot_q = 0
|
||||
h_w = self.post_pre(w)
|
||||
h_w = self.post_convs(h_w, x_mask)
|
||||
h_w = self.post_proj(h_w) * x_mask
|
||||
e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
|
||||
z_q = e_q
|
||||
for flow in self.post_flows:
|
||||
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
|
||||
logdet_tot_q += logdet_q
|
||||
z_u, z1 = torch.split(z_q, [1, 1], 1)
|
||||
u = torch.sigmoid(z_u) * x_mask
|
||||
z0 = (w - u) * x_mask
|
||||
logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1,2])
|
||||
logq = torch.sum(-0.5 * (math.log(2*math.pi) + (e_q**2)) * x_mask, [1,2]) - logdet_tot_q
|
||||
|
||||
logdet_tot = 0
|
||||
z0, logdet = self.log_flow(z0, x_mask)
|
||||
logdet_tot += logdet
|
||||
z = torch.cat([z0, z1], 1)
|
||||
for flow in flows:
|
||||
z, logdet = flow(z, x_mask, g=x, reverse=reverse)
|
||||
logdet_tot = logdet_tot + logdet
|
||||
nll = torch.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - logdet_tot
|
||||
return nll + logq # [b]
|
||||
else:
|
||||
flows = list(reversed(self.flows))
|
||||
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
|
||||
z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
|
||||
for flow in flows:
|
||||
z = flow(z, x_mask, g=x, reverse=reverse)
|
||||
z0, z1 = torch.split(z, [1, 1], 1)
|
||||
logw = z0
|
||||
return logw
|
||||
|
||||
|
||||
class DurationPredictor(nn.Module):
|
||||
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.filter_channels = filter_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.p_dropout = p_dropout
|
||||
self.gin_channels = gin_channels
|
||||
|
||||
self.drop = nn.Dropout(p_dropout)
|
||||
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size//2)
|
||||
self.norm_1 = modules.LayerNorm(filter_channels)
|
||||
self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size//2)
|
||||
self.norm_2 = modules.LayerNorm(filter_channels)
|
||||
self.proj = nn.Conv1d(filter_channels, 1, 1)
|
||||
|
||||
if gin_channels != 0:
|
||||
self.cond = nn.Conv1d(gin_channels, in_channels, 1)
|
||||
|
||||
def forward(self, x, x_mask, g=None):
|
||||
x = torch.detach(x)
|
||||
if g is not None:
|
||||
g = torch.detach(g)
|
||||
x = x + self.cond(g)
|
||||
x = self.conv_1(x * x_mask)
|
||||
x = torch.relu(x)
|
||||
x = self.norm_1(x)
|
||||
x = self.drop(x)
|
||||
x = self.conv_2(x * x_mask)
|
||||
x = torch.relu(x)
|
||||
x = self.norm_2(x)
|
||||
x = self.drop(x)
|
||||
x = self.proj(x * x_mask)
|
||||
return x * x_mask
|
||||
|
||||
|
||||
class TextEncoder(nn.Module):
|
||||
def __init__(self,
|
||||
n_vocab,
|
||||
out_channels,
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
p_dropout):
|
||||
super().__init__()
|
||||
self.n_vocab = n_vocab
|
||||
self.out_channels = out_channels
|
||||
self.hidden_channels = hidden_channels
|
||||
self.filter_channels = filter_channels
|
||||
self.n_heads = n_heads
|
||||
self.n_layers = n_layers
|
||||
self.kernel_size = kernel_size
|
||||
self.p_dropout = p_dropout
|
||||
|
||||
self.emb = nn.Embedding(n_vocab, hidden_channels)
|
||||
nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
|
||||
|
||||
self.encoder = attentions.Encoder(
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
p_dropout)
|
||||
self.proj= nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
||||
|
||||
def forward(self, x, x_lengths):
|
||||
x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
|
||||
x = torch.transpose(x, 1, -1) # [b, h, t]
|
||||
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
|
||||
|
||||
x = self.encoder(x * x_mask, x_mask)
|
||||
stats = self.proj(x) * x_mask
|
||||
|
||||
m, logs = torch.split(stats, self.out_channels, dim=1)
|
||||
return x, m, logs, x_mask
|
||||
|
||||
|
||||
class ResidualCouplingBlock(nn.Module):
|
||||
def __init__(self,
|
||||
channels,
|
||||
hidden_channels,
|
||||
kernel_size,
|
||||
dilation_rate,
|
||||
n_layers,
|
||||
n_flows=4,
|
||||
gin_channels=0):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.hidden_channels = hidden_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.dilation_rate = dilation_rate
|
||||
self.n_layers = n_layers
|
||||
self.n_flows = n_flows
|
||||
self.gin_channels = gin_channels
|
||||
|
||||
self.flows = nn.ModuleList()
|
||||
for i in range(n_flows):
|
||||
self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True))
|
||||
self.flows.append(modules.Flip())
|
||||
|
||||
def forward(self, x, x_mask, g=None, reverse=False):
|
||||
if not reverse:
|
||||
for flow in self.flows:
|
||||
x, _ = flow(x, x_mask, g=g, reverse=reverse)
|
||||
else:
|
||||
for flow in reversed(self.flows):
|
||||
x = flow(x, x_mask, g=g, reverse=reverse)
|
||||
return x
|
||||
|
||||
|
||||
class PosteriorEncoder(nn.Module):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
hidden_channels,
|
||||
kernel_size,
|
||||
dilation_rate,
|
||||
n_layers,
|
||||
gin_channels=0):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.hidden_channels = hidden_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.dilation_rate = dilation_rate
|
||||
self.n_layers = n_layers
|
||||
self.gin_channels = gin_channels
|
||||
|
||||
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
|
||||
self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
|
||||
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
||||
|
||||
def forward(self, x, x_lengths, g=None):
|
||||
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
|
||||
x = self.pre(x) * x_mask
|
||||
x = self.enc(x, x_mask, g=g)
|
||||
stats = self.proj(x) * x_mask
|
||||
m, logs = torch.split(stats, self.out_channels, dim=1)
|
||||
z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
|
||||
return z, m, logs, x_mask
|
||||
|
||||
|
||||
class Generator(torch.nn.Module):
|
||||
def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
|
||||
super(Generator, self).__init__()
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
self.num_upsamples = len(upsample_rates)
|
||||
self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
|
||||
resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2
|
||||
|
||||
self.ups = nn.ModuleList()
|
||||
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
||||
self.ups.append(weight_norm(
|
||||
ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
|
||||
k, u, padding=(k-u)//2)))
|
||||
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = upsample_initial_channel//(2**(i+1))
|
||||
for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
||||
self.resblocks.append(resblock(ch, k, d))
|
||||
|
||||
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
|
||||
self.ups.apply(init_weights)
|
||||
|
||||
if gin_channels != 0:
|
||||
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
|
||||
|
||||
def forward(self, x, g=None):
|
||||
x = self.conv_pre(x)
|
||||
if g is not None:
|
||||
x = x + self.cond(g)
|
||||
|
||||
for i in range(self.num_upsamples):
|
||||
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
||||
x = self.ups[i](x)
|
||||
xs = None
|
||||
for j in range(self.num_kernels):
|
||||
if xs is None:
|
||||
xs = self.resblocks[i*self.num_kernels+j](x)
|
||||
else:
|
||||
xs += self.resblocks[i*self.num_kernels+j](x)
|
||||
x = xs / self.num_kernels
|
||||
x = F.leaky_relu(x)
|
||||
x = self.conv_post(x)
|
||||
x = torch.tanh(x)
|
||||
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
print('Removing weight norm...')
|
||||
for l in self.ups:
|
||||
remove_weight_norm(l)
|
||||
for l in self.resblocks:
|
||||
l.remove_weight_norm()
|
||||
|
||||
|
||||
class DiscriminatorP(torch.nn.Module):
|
||||
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
|
||||
super(DiscriminatorP, self).__init__()
|
||||
self.period = period
|
||||
self.use_spectral_norm = use_spectral_norm
|
||||
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
||||
self.convs = nn.ModuleList([
|
||||
norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
|
||||
norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
|
||||
norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
|
||||
norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
|
||||
norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
|
||||
])
|
||||
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
||||
|
||||
def forward(self, x):
|
||||
fmap = []
|
||||
|
||||
# 1d to 2d
|
||||
b, c, t = x.shape
|
||||
if t % self.period != 0: # pad first
|
||||
n_pad = self.period - (t % self.period)
|
||||
x = F.pad(x, (0, n_pad), "reflect")
|
||||
t = t + n_pad
|
||||
x = x.view(b, c, t // self.period, self.period)
|
||||
|
||||
for l in self.convs:
|
||||
x = l(x)
|
||||
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
||||
fmap.append(x)
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
x = torch.flatten(x, 1, -1)
|
||||
|
||||
return x, fmap
|
||||
|
||||
|
||||
class DiscriminatorS(torch.nn.Module):
|
||||
def __init__(self, use_spectral_norm=False):
|
||||
super(DiscriminatorS, self).__init__()
|
||||
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
||||
self.convs = nn.ModuleList([
|
||||
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
|
||||
norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
|
||||
norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
|
||||
norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
|
||||
norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
|
||||
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
|
||||
])
|
||||
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
|
||||
|
||||
def forward(self, x):
|
||||
fmap = []
|
||||
|
||||
for l in self.convs:
|
||||
x = l(x)
|
||||
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
||||
fmap.append(x)
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
x = torch.flatten(x, 1, -1)
|
||||
|
||||
return x, fmap
|
||||
|
||||
|
||||
class MultiPeriodDiscriminator(torch.nn.Module):
|
||||
def __init__(self, use_spectral_norm=False):
|
||||
super(MultiPeriodDiscriminator, self).__init__()
|
||||
periods = [2,3,5,7,11]
|
||||
|
||||
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
|
||||
discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
|
||||
self.discriminators = nn.ModuleList(discs)
|
||||
|
||||
def forward(self, y, y_hat):
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
fmap_gs = []
|
||||
for i, d in enumerate(self.discriminators):
|
||||
y_d_r, fmap_r = d(y)
|
||||
y_d_g, fmap_g = d(y_hat)
|
||||
y_d_rs.append(y_d_r)
|
||||
y_d_gs.append(y_d_g)
|
||||
fmap_rs.append(fmap_r)
|
||||
fmap_gs.append(fmap_g)
|
||||
|
||||
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||
|
||||
|
||||
|
||||
class SynthesizerTrn(nn.Module):
|
||||
"""
|
||||
Synthesizer for Training
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
n_vocab,
|
||||
spec_channels,
|
||||
segment_size,
|
||||
inter_channels,
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
p_dropout,
|
||||
resblock,
|
||||
resblock_kernel_sizes,
|
||||
resblock_dilation_sizes,
|
||||
upsample_rates,
|
||||
upsample_initial_channel,
|
||||
upsample_kernel_sizes,
|
||||
n_speakers=0,
|
||||
gin_channels=0,
|
||||
use_sdp=True,
|
||||
**kwargs):
|
||||
|
||||
super().__init__()
|
||||
self.n_vocab = n_vocab
|
||||
self.spec_channels = spec_channels
|
||||
self.inter_channels = inter_channels
|
||||
self.hidden_channels = hidden_channels
|
||||
self.filter_channels = filter_channels
|
||||
self.n_heads = n_heads
|
||||
self.n_layers = n_layers
|
||||
self.kernel_size = kernel_size
|
||||
self.p_dropout = p_dropout
|
||||
self.resblock = resblock
|
||||
self.resblock_kernel_sizes = resblock_kernel_sizes
|
||||
self.resblock_dilation_sizes = resblock_dilation_sizes
|
||||
self.upsample_rates = upsample_rates
|
||||
self.upsample_initial_channel = upsample_initial_channel
|
||||
self.upsample_kernel_sizes = upsample_kernel_sizes
|
||||
self.segment_size = segment_size
|
||||
self.n_speakers = n_speakers
|
||||
self.gin_channels = gin_channels
|
||||
|
||||
self.use_sdp = use_sdp
|
||||
|
||||
self.enc_p = TextEncoder(n_vocab,
|
||||
inter_channels,
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
p_dropout)
|
||||
self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
|
||||
self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
|
||||
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
|
||||
|
||||
if use_sdp:
|
||||
self.dp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels)
|
||||
else:
|
||||
self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels)
|
||||
|
||||
if n_speakers > 1:
|
||||
self.emb_g = nn.Embedding(n_speakers, gin_channels)
|
||||
|
||||
def forward(self, x, x_lengths, y, y_lengths, sid=None):
|
||||
|
||||
x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
|
||||
if self.n_speakers > 0:
|
||||
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
|
||||
else:
|
||||
g = None
|
||||
|
||||
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
|
||||
z_p = self.flow(z, y_mask, g=g)
|
||||
|
||||
with torch.no_grad():
|
||||
# negative cross-entropy
|
||||
s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t]
|
||||
neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True) # [b, 1, t_s]
|
||||
neg_cent2 = torch.matmul(-0.5 * (z_p ** 2).transpose(1, 2), s_p_sq_r) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
|
||||
neg_cent3 = torch.matmul(z_p.transpose(1, 2), (m_p * s_p_sq_r)) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
|
||||
neg_cent4 = torch.sum(-0.5 * (m_p ** 2) * s_p_sq_r, [1], keepdim=True) # [b, 1, t_s]
|
||||
neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
|
||||
|
||||
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
|
||||
attn = monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach()
|
||||
|
||||
w = attn.sum(2)
|
||||
if self.use_sdp:
|
||||
l_length = self.dp(x, x_mask, w, g=g)
|
||||
l_length = l_length / torch.sum(x_mask)
|
||||
else:
|
||||
logw_ = torch.log(w + 1e-6) * x_mask
|
||||
logw = self.dp(x, x_mask, g=g)
|
||||
l_length = torch.sum((logw - logw_)**2, [1,2]) / torch.sum(x_mask) # for averaging
|
||||
|
||||
# expand prior
|
||||
m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
|
||||
logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)
|
||||
o = self.dec(z_slice, g=g)
|
||||
return o, l_length, attn, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
|
||||
|
||||
def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None):
|
||||
x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
|
||||
if self.n_speakers > 0:
|
||||
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
|
||||
else:
|
||||
g = None
|
||||
|
||||
if self.use_sdp:
|
||||
logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w)
|
||||
else:
|
||||
logw = self.dp(x, x_mask, g=g)
|
||||
w = torch.exp(logw) * x_mask * length_scale
|
||||
w_ceil = torch.ceil(w)
|
||||
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
|
||||
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype)
|
||||
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
|
||||
attn = commons.generate_path(w_ceil, attn_mask)
|
||||
|
||||
m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
|
||||
logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
|
||||
|
||||
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
|
||||
z = self.flow(z_p, y_mask, g=g, reverse=True)
|
||||
o = self.dec((z * y_mask)[:,:,:max_len], g=g)
|
||||
return o, attn, y_mask, (z, z_p, m_p, logs_p)

    def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
        assert self.n_speakers > 0, "n_speakers has to be larger than 0."
        g_src = self.emb_g(sid_src).unsqueeze(-1)
        g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src)
        z_p = self.flow(z, y_mask, g=g_src)
        z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
        o_hat = self.dec(z_hat * y_mask, g=g_tgt)
        return o_hat, y_mask, (z, z_p, z_hat)
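        # Conversion works because the forward flow (conditioned on the source
        # embedding) maps the posterior z into the roughly speaker-independent
        # prior space z_p, and the inverse flow (conditioned on the target
        # embedding) maps it back into a target-speaker latent before decoding.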

@@ -1,4 +1,4 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/posterior_encoder.py
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/posterior_encoder.py

# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

@@ -1,4 +1,4 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/residual_coupling.py
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/residual_coupling.py

# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

@@ -1,17 +0,0 @@
# https://github.com/jaywalnut310/vits/blob/main/text/symbols.py
""" from https://github.com/keithito/tacotron """

'''
Defines the set of symbols used in text input to the model.
'''
_pad = '_'
_punctuation = ';:,.!?¡¿—…"«»“” '
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"


# Export all symbols:
symbol_table = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)

# Special symbol ids
SPACE_ID = symbol_table.index(" ")
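
For orientation, a small sketch of how these symbols map to ids (the actual `symbol_to_id` dict is built in utils.py further below; ids simply follow the list order, so the pad symbol gets id 0):

    symbol_to_id = {s: i for i, s in enumerate(symbol_table)}
    assert symbol_to_id['_'] == 0         # _pad comes first, so it doubles as the pad/blank id
    assert symbol_table[SPACE_ID] == ' '  # SPACE_ID indexes the space character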

@@ -20,6 +20,7 @@
This code is based on
  - https://github.com/jaywalnut310/vits
  - https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/text_encoder.py
  - https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/transducer_stateless/conformer.py
"""

import copy

@@ -67,6 +68,7 @@ class TextEncoder(torch.nn.Module):
        self.emb = torch.nn.Embedding(vocabs, d_model)
        torch.nn.init.normal_(self.emb.weight, 0.0, d_model**-0.5)
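        # Initializing the embedding weights with std d_model**-0.5 follows
        # the usual Transformer convention for embedding scaling, keeping the
        # embedding magnitudes well behaved at the start of training.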

        # We use conformer as text encoder
        self.encoder = Transformer(
            d_model=d_model,
            num_heads=num_heads,

@@ -18,7 +18,6 @@ from typing import Dict, List

import g2p_en
import tacotron_cleaner.cleaners

from utils import intersperse

@@ -1,9 +1,5 @@
#!/usr/bin/env python3
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
#                                            Wei Kang,
#                                            Mingshuang Luo,
#                                            Zengwei Yao,
#                                            Daniel Povey)
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#

@@ -22,9 +18,10 @@

import argparse
import logging
import numpy as np
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional, Tuple, Union

import k2
import torch

@@ -36,26 +33,17 @@ from torch.optim import Optimizer
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from tts_datamodule import LJSpeechTtsDataModule

from icefall import diagnostics
from icefall.checkpoint import load_checkpoint, remove_checkpoints
from icefall.checkpoint import load_checkpoint
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
    AttributeDict,
    setup_logger,
    str2bool,
)
from icefall.utils import AttributeDict, setup_logger, str2bool

from tokenizer import Tokenizer
from utils import (
    MetricsTracker,
    plot_feature,
    save_checkpoint,
    save_checkpoint_with_global_batch_idx,
)
from tts_datamodule import LJSpeechTtsDataModule
from utils import MetricsTracker, plot_feature, save_checkpoint
from vits import VITS

LRSchedulerType = torch.optim.lr_scheduler._LRScheduler

@@ -90,7 +78,7 @@ def get_parser():
    parser.add_argument(
        "--num-epochs",
        type=int,
        default=30,
        default=1000,
        help="Number of epochs to train.",
    )

@@ -104,15 +92,6 @@ def get_parser():
        """,
    )

    parser.add_argument(
        "--start-batch",
        type=int,
        default=0,
        help="""If positive, --start-epoch is ignored and
        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
        """,
    )

    parser.add_argument(
        "--exp-dir",
        type=str,

@@ -127,7 +106,7 @@ def get_parser():
        "--tokens",
        type=str,
        default="data/tokens.txt",
        help="""Path to tokens.txt.""",
        help="""Path to vocabulary.""",
    )

    parser.add_argument(

@@ -158,24 +137,11 @@ def get_parser():
    parser.add_argument(
        "--save-every-n",
        type=int,
        default=4000,
        help="""Save checkpoint after processing this number of batches
        default=20,
        help="""Save checkpoint after processing this number of epochs
        periodically. We save checkpoint to exp-dir/ whenever
        params.batch_idx_train % save_every_n == 0. The checkpoint filename
        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
        end of each epoch where `xxx` is the epoch number counting from 1.
        """,
    )

    parser.add_argument(
        "--keep-last-k",
        type=int,
        default=30,
        help="""Only keep this number of checkpoints on disk.
        For instance, if it is 3, there are only 3 checkpoints
        in the exp-dir with filenames `checkpoint-xxx.pt`.
        It does not affect checkpoints with name `epoch-xxx.pt`.
        params.cur_epoch % save_every_n == 0. The checkpoint filename
        has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'
        """,
    )

@@ -218,8 +184,6 @@ def get_params() -> AttributeDict:

    - log_interval: Print training loss if batch_idx % log_interval is 0

    - reset_interval: Reset statistics if batch_idx % reset_interval is 0

    - valid_interval: Run validation if batch_idx % valid_interval is 0

    - feature_dim: The model input dim. It has to match the one used

@@ -242,18 +206,14 @@ def get_params() -> AttributeDict:
            "best_train_epoch": -1,
            "best_valid_epoch": -1,
            "batch_idx_train": -1,  # 0
            "log_interval": 10,
            "draw_interval": 500,
            # "reset_interval": 200,
            "log_interval": 50,
            "valid_interval": 200,
            "env_info": get_env_info(),
            "sampling_rate": 22050,
            "frame_shift": 256,
            "frame_length": 1024,
            "feature_dim": 513,  # 1024 // 2 + 1, 1024 is fft_length
            "mel_loss_params": {
                "n_mels": 80,
            },
            "n_mels": 80,
            "lambda_adv": 1.0,  # loss scaling coefficient for adversarial loss
            "lambda_mel": 45.0,  # loss scaling coefficient for Mel loss
            "lambda_feat_match": 2.0,  # loss scaling coefficient for feat match loss

@@ -270,9 +230,7 @@ def load_checkpoint_if_available(
) -> Optional[Dict[str, Any]]:
    """Load checkpoint from file.

    If params.start_batch is positive, it will load the checkpoint from
    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
    params.start_epoch is larger than 1, it will load the checkpoint from
    If params.start_epoch is larger than 1, it will load the checkpoint from
    `params.start_epoch - 1`.

    Apart from loading state dict for `model` and `optimizer` it also updates

@@ -287,9 +245,7 @@
    Returns:
      Return a dict containing previously saved training info.
    """
    if params.start_batch > 0:
        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
    elif params.start_epoch > 1:
    if params.start_epoch > 1:
        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
    else:
        return None

@@ -308,19 +264,15 @@
    for k in keys:
        params[k] = saved_params[k]

    if params.start_batch > 0:
        if "cur_epoch" in saved_params:
            params["start_epoch"] = saved_params["cur_epoch"]

    return saved_params


def get_model(params: AttributeDict) -> nn.Module:
    mel_loss_params = params.mel_loss_params
    mel_loss_params.update(
        frame_length=params.frame_length,
        frame_shift=params.frame_shift,
    )
    mel_loss_params = {
        "n_mels": params.n_mels,
        "frame_length": params.frame_length,
        "frame_shift": params.frame_shift,
    }
    model = VITS(
        vocab_size=params.vocab_size,
        feature_dim=params.feature_dim,

@@ -381,18 +333,22 @@ def train_one_epoch(
        It is returned by :func:`get_params`.
      model:
        The model for training.
      optimizer:
        The optimizer we are using.
      scheduler:
        The learning rate scheduler, we call step() every step.
      tokenizer:
        Used to convert text to phonemes.
      optimizer_g:
        The optimizer for generator.
      optimizer_d:
        The optimizer for discriminator.
      scheduler_g:
        The learning rate scheduler for generator, we call step() every epoch.
      scheduler_d:
        The learning rate scheduler for discriminator, we call step() every epoch.
      train_dl:
        Dataloader for the training dataset.
      valid_dl:
        Dataloader for the validation dataset.
      scaler:
        The scaler used for mixed precision training.
      model_avg:
        The stored model averaged from the start of training.
      tb_writer:
        Writer to write log messages to tensorboard.
      world_size:

@@ -404,7 +360,7 @@
    model.train()
    device = model.device if isinstance(model, DDP) else next(model.parameters()).device

    # used to summarize the stats over iterations
    # used to summarize the stats over iterations in one epoch
    tot_loss = MetricsTracker()

    saved_bad_model = False

@@ -433,7 +389,6 @@
        loss_info = MetricsTracker()
        loss_info['samples'] = batch_size

        return_sample = params.batch_idx_train % params.log_interval == 0
        try:
            with autocast(enabled=params.use_fp16):
                # forward discriminator

@@ -463,13 +418,11 @@
                    speech=audio,
                    speech_lengths=audio_lens,
                    forward_generator=True,
                    return_sample=return_sample,
                    return_sample=params.batch_idx_train % params.log_interval == 0,
                )
            for k, v in stats_g.items():
                if "return_sample" not in k:
                if "returned_sample" not in k:
                    loss_info[k] = v * batch_size
            if return_sample:
                speech_hat_, speech_, mel_hat_, mel_ = stats_g["return_sample"]
            # update generator
            optimizer_g.zero_grad()
            scaler.scale(loss_g).backward()

@@ -477,7 +430,6 @@
            scaler.update()

            # summary stats
            # tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
            tot_loss = tot_loss + loss_info
        except:  # noqa
            save_bad_model()

@@ -486,37 +438,12 @@
        if params.print_diagnostics and batch_idx == 5:
            return

        if (
            params.batch_idx_train > 0
            and params.batch_idx_train % params.save_every_n == 0
        ):
            save_checkpoint_with_global_batch_idx(
                out_dir=params.exp_dir,
                global_batch_idx=params.batch_idx_train,
                model=model,
                params=params,
                optimizer_g=optimizer_g,
                optimizer_d=optimizer_d,
                scheduler_g=scheduler_g,
                scheduler_d=scheduler_d,
                sampler=train_dl.sampler,
                scaler=scaler,
                rank=rank,
            )
            remove_checkpoints(
                out_dir=params.exp_dir,
                topk=params.keep_last_k,
                rank=rank,
            )

        # if batch_idx % 100 == 0 and params.use_fp16:
        if params.batch_idx_train % 100 == 0 and params.use_fp16:
            # If the grad scale was less than 1, try increasing it. The _growth_interval
            # of the grad scaler is configurable, but we can't configure it to have different
            # behavior depending on the current grad scale.
            cur_grad_scale = scaler._scale.item()

            # if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
            if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0):
                scaler.update(cur_grad_scale * 2.0)
            if cur_grad_scale < 0.01:

@@ -530,7 +457,6 @@
                    f"grad_scale is too small, exiting: {cur_grad_scale}"
                )

        # if batch_idx % params.log_interval == 0:
        if params.batch_idx_train % params.log_interval == 0:
            cur_lr_g = max(scheduler_g.get_last_lr())
            cur_lr_d = max(scheduler_d.get_last_lr())

@@ -561,7 +487,8 @@
                tb_writer.add_scalar(
                    "train/grad_scale", cur_grad_scale, params.batch_idx_train
                )
            if return_sample:
            if "returned_sample" in stats_g:
                speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"]
                tb_writer.add_audio(
                    "train/speech_hat_", speech_hat_, params.batch_idx_train, params.sampling_rate
                )

@@ -575,7 +502,6 @@
                    "train/mel_", plot_feature(mel_), params.batch_idx_train, dataformats='HWC'
                )

        # if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
        if params.batch_idx_train % params.valid_interval == 0 and not params.print_diagnostics:
            logging.info("Computing validation loss")
            valid_info, (speech_hat, speech) = compute_validation_loss(

@@ -615,14 +541,14 @@ def compute_validation_loss(
    valid_dl: torch.utils.data.DataLoader,
    world_size: int = 1,
    rank: int = 0,
) -> MetricsTracker:
) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]:
    """Run the validation process."""
    model.eval()
    device = model.device if isinstance(model, DDP) else next(model.parameters()).device

    # used to summarize the stats over iterations
    tot_loss = MetricsTracker()
    return_sample = None
    returned_sample = None

    with torch.no_grad():
        for batch_idx, batch in enumerate(valid_dl):

@@ -667,12 +593,14 @@
            # infer for first batch:
            if batch_idx == 0 and rank == 0:
                inner_model = model.module if isinstance(model, DDP) else model
                audio_pred, _, duration = inner_model.inference(text=tokens[0, :tokens_lens[0].item()])
                audio_pred, _, duration = inner_model.inference(
                    text=tokens[0, :tokens_lens[0].item()]
                )
                audio_pred = audio_pred.data.cpu().numpy()
                audio_len_pred = (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item()
                assert audio_len_pred == len(audio_pred), (audio_len_pred, len(audio_pred))
                audio_gt = audio[0, :audio_lens[0].item()].data.cpu().numpy()
                return_sample = (audio_pred, audio_gt)
                returned_sample = (audio_pred, audio_gt)

    if world_size > 1:
        tot_loss.reduce(device)

@@ -682,7 +610,7 @@
        params.best_valid_epoch = params.cur_epoch
        params.best_valid_loss = loss_value

    return tot_loss, return_sample
    return tot_loss, returned_sample


def scan_pessimistic_batches_for_oom(

@@ -805,18 +733,10 @@ def run(rank, world_size, args):
        model = DDP(model, device_ids=[rank], find_unused_parameters=True)

    optimizer_g = torch.optim.AdamW(
        generator.parameters(),
        lr=params.lr,
        betas=(0.8, 0.99),
        eps=1e-9,
        # weight_decay=0,
        generator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9
    )
    optimizer_d = torch.optim.AdamW(
        discriminator.parameters(),
        lr=params.lr,
        betas=(0.8, 0.99),
        eps=1e-9,
        # weight_decay=0,
        discriminator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9
    )

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875)
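    # ExponentialLR multiplies the learning rate by gamma after every epoch,
    # so after N epochs lr = lr0 * 0.999875**N; for the default 1000-epoch run
    # that is 0.999875**1000 ≈ 0.8825, i.e. a gentle ~12% decay overall.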

@@ -852,16 +772,8 @@ def run(rank, world_size, args):

    train_cuts = ljspeech.train_cuts()

    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
        # We only load the sampler's state dict when it loads a checkpoint
        # saved in the middle of an epoch
        sampler_state_dict = checkpoints["sampler"]
    else:
        sampler_state_dict = None

    def remove_short_and_long_utt(c: Cut):
        # Keep only utterances with duration between 1 second and 20 seconds

        # You should use ../local/display_manifest_statistics.py to get
        # an utterance duration distribution for your dataset to select
        # the threshold

@@ -870,13 +782,10 @@ def run(rank, world_size, args):
            #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
            # )
            return False

        return True

    train_cuts = train_cuts.filter(remove_short_and_long_utt)
    train_dl = ljspeech.train_dataloaders(
        train_cuts, sampler_state_dict=sampler_state_dict
    )
    train_dl = ljspeech.train_dataloaders(train_cuts)

    valid_cuts = ljspeech.valid_cuts()
    valid_dl = ljspeech.valid_dataloaders(valid_cuts)

@@ -902,11 +811,11 @@ def run(rank, world_size, args):
        fix_random_seed(params.seed + epoch - 1)
        train_dl.sampler.set_epoch(epoch - 1)

        params.cur_epoch = epoch

        if tb_writer is not None:
            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

        params.cur_epoch = epoch

        train_one_epoch(
            params=params,
            model=model,

@@ -927,27 +836,28 @@ def run(rank, world_size, args):
            diagnostic.print_diagnostics()
            break

        filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
        save_checkpoint(
            filename=filename,
            params=params,
            model=model,
            optimizer_g=optimizer_g,
            optimizer_d=optimizer_d,
            scheduler_g=scheduler_g,
            scheduler_d=scheduler_d,
            sampler=train_dl.sampler,
            scaler=scaler,
            rank=rank,
        )
        if rank == 0:
            if params.best_train_epoch == params.cur_epoch:
                best_train_filename = params.exp_dir / "best-train-loss.pt"
                copyfile(src=filename, dst=best_train_filename)
        if epoch % params.save_every_n == 0:
            filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
            save_checkpoint(
                filename=filename,
                params=params,
                model=model,
                optimizer_g=optimizer_g,
                optimizer_d=optimizer_d,
                scheduler_g=scheduler_g,
                scheduler_d=scheduler_d,
                sampler=train_dl.sampler,
                scaler=scaler,
                rank=rank,
            )
            if rank == 0:
                if params.best_train_epoch == params.cur_epoch:
                    best_train_filename = params.exp_dir / "best-train-loss.pt"
                    copyfile(src=filename, dst=best_train_filename)

            if params.best_valid_epoch == params.cur_epoch:
                best_valid_filename = params.exp_dir / "best-valid-loss.pt"
                copyfile(src=filename, dst=best_valid_filename)
                if params.best_valid_epoch == params.cur_epoch:
                    best_valid_filename = params.exp_dir / "best-valid-loss.pt"
                    copyfile(src=filename, dst=best_valid_filename)

        # step per epoch
        scheduler_g.step()

@@ -1,4 +1,5 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/transform.py
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/transform.py

"""Flow-related transformation.

This code is derived from https://github.com/bayesiains/nflows.

@@ -1,32 +1,35 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/utils/get_random_segments.py

# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Function to get random segments."""

#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple, Union
import collections
import logging
import re
import warnings

import numpy as np
import torch
import torch.nn as nn
import torch.distributed as dist
from lhotse.dataset.sampling.base import CutSampler
from pathlib import Path
from phonemizer import phonemize
from symbols import symbol_table
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils.rnn import pad_sequence
from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter
from unidecode import unidecode


# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/utils/get_random_segments.py
def get_random_segments(
    x: torch.Tensor,
    x_lengths: torch.Tensor,

@@ -55,6 +58,7 @@ def get_random_segments(
    return segments, start_idxs


# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/utils/get_random_segments.py
def get_segments(
    x: torch.Tensor,
    start_idxs: torch.Tensor,

@@ -78,195 +82,41 @@ def get_segments(
    return segments
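
A sketch of how these two helpers relate, assuming the espnet semantics the URLs above point at (the bodies are elided by this hunk):

    # get_random_segments(x, x_lengths, segment_size) draws one random start
    # per batch item with start <= length - segment_size; then
    # get_segments(x, start_idxs, segment_size) gathers
    #   segments[i] = x[i, :, start_idxs[i] : start_idxs[i] + segment_size]
    # so the decoder only ever sees short, randomly placed latent windows.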


# https://github.com/espnet/espnet/blob/master/espnet2/torch_utils/device_funcs.py
def force_gatherable(data, device):
    """Change object to gatherable in torch.nn.DataParallel recursively

    The difference from to_device() is changing to torch.Tensor if float or int
    value is found.

    The restriction to the returned value in DataParallel:
        The object must be
        - torch.cuda.Tensor
        - 1 or more dimension. 0-dimension-tensor sends warning.
        or a list, tuple, dict.

    """
    if isinstance(data, dict):
        return {k: force_gatherable(v, device) for k, v in data.items()}
    # DataParallel can't handle NamedTuple well
    elif isinstance(data, tuple) and type(data) is not tuple:
        return type(data)(*[force_gatherable(o, device) for o in data])
    elif isinstance(data, (list, tuple, set)):
        return type(data)(force_gatherable(v, device) for v in data)
    elif isinstance(data, np.ndarray):
        return force_gatherable(torch.from_numpy(data), device)
    elif isinstance(data, torch.Tensor):
        if data.dim() == 0:
            # To 1-dim array
            data = data[None]
        return data.to(device)
    elif isinstance(data, float):
        return torch.tensor([data], dtype=torch.float, device=device)
    elif isinstance(data, int):
        return torch.tensor([data], dtype=torch.long, device=device)
    elif data is None:
        return None
    else:
        warnings.warn(f"{type(data)} may not be gatherable by DataParallel")
        return data
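

# For example, force_gatherable({"loss": 0.5, "count": 3}, device) returns
# {"loss": tensor([0.5]), "count": tensor([3])} on the given device: scalars
# are promoted to 1-dim tensors so DataParallel can gather them.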


# The following codes are based on https://github.com/jaywalnut310/vits

# Regular expression matching whitespace:
_whitespace_re = re.compile(r'\s+')

# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
    ('mrs', 'misess'),
    ('mr', 'mister'),
    ('dr', 'doctor'),
    ('st', 'saint'),
    ('co', 'company'),
    ('jr', 'junior'),
    ('maj', 'major'),
    ('gen', 'general'),
    ('drs', 'doctors'),
    ('rev', 'reverend'),
    ('lt', 'lieutenant'),
    ('hon', 'honorable'),
    ('sgt', 'sergeant'),
    ('capt', 'captain'),
    ('esq', 'esquire'),
    ('ltd', 'limited'),
    ('col', 'colonel'),
    ('ft', 'fort'),
]]


def expand_abbreviations(text):
    for regex, replacement in _abbreviations:
        text = re.sub(regex, replacement, text)
    return text


def lowercase(text):
    return text.lower()


def collapse_whitespace(text):
    return re.sub(_whitespace_re, ' ', text)


def convert_to_ascii(text):
    return unidecode(text)


def text_clean(text):
    '''Pipeline for English text, including abbreviation expansion, punctuation and stress.

    Returns:
      A string of phonemes.
    '''
    text = convert_to_ascii(text)
    text = lowercase(text)
    text = expand_abbreviations(text)
    phonemes = phonemize(
        text,
        language='en-us',
        backend='espeak',
        strip=True,
        preserve_punctuation=True,
        with_stress=True,
    )
    phonemes = collapse_whitespace(phonemes)
    return phonemes
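

# For example, text_clean("Dr. Smith reads.") first becomes
# "doctor smith reads." after ASCII folding, lowercasing and abbreviation
# expansion, and is then phonemized into an IPA string with stress marks
# (the exact symbols depend on the installed espeak backend version).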


# Mappings from symbol to numeric ID and vice versa:
symbol_to_id = {s: i for i, s in enumerate(symbol_table)}
id_to_symbol = {i: s for i, s in enumerate(symbol_table)}


# def text_to_sequence(text: str) -> List[int]:
#     '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
#     '''
#     cleaned_text = text_clean(text)
#     sequence = [symbol_to_id[symbol] for symbol in cleaned_text]
#     return sequence
#
#
# def sequence_to_text(sequence: List[int]) -> str:
#     '''Converts a sequence of IDs back to a string'''
#     result = ''.join(id_to_symbol[symbol_id] for symbol_id in sequence)
#     return result


# from https://github.com/jaywalnut310/vits/blob/main/commons.py
def intersperse(sequence, item=0):
    result = [item] * (len(sequence) * 2 + 1)
    result[1::2] = sequence
    return result
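

# For example, intersperse([5, 3, 8]) == [0, 5, 0, 3, 0, 8, 0]: a blank token
# (id 0) is placed before, between, and after the real symbols, which is the
# convention the VITS text encoder expects.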


def prepare_token_batch(
    texts: List[str],
    phonemes: Optional[List[str]] = None,
    intersperse_blank: bool = True,
    blank_id: int = 0,
    pad_id: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, bool]:
    """Convert a list of text strings into a batch of symbol tokens with padding.

    Args:
      texts: list of text strings
      intersperse_blank: whether to intersperse blank tokens in the converted token sequence.
      blank_id: index of blank token
      pad_id: padding index
    """
    if phonemes is None:
        # normalize text
        normalized_texts = []
        for text in texts:
            text = convert_to_ascii(text)
            text = lowercase(text)
            text = expand_abbreviations(text)
            normalized_texts.append(text)

        # convert to phonemes
        phonemes = phonemize(
            normalized_texts,
            language='en-us',
            backend='espeak',
            strip=True,
            preserve_punctuation=True,
            with_stress=True,
        )
        phonemes = [collapse_whitespace(sequence) for sequence in phonemes]

    # convert to symbol ids
    lengths = []
    sequences = []
    skip = False
    for idx, sequence in enumerate(phonemes):
        try:
            sequence = [symbol_to_id[symbol] for symbol in sequence]
        except Exception:
            # print(texts[idx])
            # print(normalized_texts[idx])
            print(phonemes[idx])
            skip = True
        if intersperse_blank:
            sequence = intersperse(sequence, blank_id)
        try:
            sequences.append(torch.tensor(sequence, dtype=torch.int64))
        except Exception:
            print(sequence)
            skip = True
        lengths.append(len(sequence))

    sequences = pad_sequence(sequences, batch_first=True, padding_value=pad_id)
    lengths = torch.tensor(lengths, dtype=torch.int64)
    return sequences, lengths, skip
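

# For example, prepare_token_batch(["Hi!", "OK."]) phonemizes both strings,
# intersperses blank id 0 around every symbol, and returns a padded
# [2, 2*T_max+1] id tensor plus per-item lengths and a skip flag marking
# batches that contained symbols missing from the table.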


# from https://github.com/jaywalnut310/vits/blob/main/utils.py
MATPLOTLIB_FLAG = False


def plot_feature(spectrogram):
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib
        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger('matplotlib')
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt
    import numpy as np

    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower",
                   interpolation='none')
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()

    fig.canvas.draw()
    # np.frombuffer replaces the deprecated np.fromstring(..., sep='')
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close()
    return data


class MetricsTracker(collections.defaultdict):

@@ -413,106 +263,3 @@ def save_checkpoint(
            checkpoint[k] = v

    torch.save(checkpoint, filename)

def save_checkpoint_with_global_batch_idx(
    out_dir: Path,
    global_batch_idx: int,
    model: Union[nn.Module, DDP],
    params: Optional[Dict[str, Any]] = None,
    optimizer_g: Optional[Optimizer] = None,
    optimizer_d: Optional[Optimizer] = None,
    scheduler_g: Optional[LRSchedulerType] = None,
    scheduler_d: Optional[LRSchedulerType] = None,
    scaler: Optional[GradScaler] = None,
    sampler: Optional[CutSampler] = None,
    rank: int = 0,
):
    """Save training info after processing given number of batches.

    Args:
      out_dir:
        The directory to save the checkpoint.
      global_batch_idx:
        The number of batches processed so far from the very start of the
        training. The saved checkpoint will have the following filename:
            f'out_dir / checkpoint-{global_batch_idx}.pt'
      model:
        The neural network model whose `state_dict` will be saved in the
        checkpoint.
      params:
        A dict of training configurations to be saved.
      optimizer_g:
        The optimizer for generator used in the training.
        Its `state_dict` will be saved.
      optimizer_d:
        The optimizer for discriminator used in the training.
        Its `state_dict` will be saved.
      scheduler_g:
        The learning rate scheduler for generator used in the training.
        Its `state_dict` will be saved.
      scheduler_d:
        The learning rate scheduler for discriminator used in the training.
        Its `state_dict` will be saved.
      scaler:
        The scaler used for mixed precision training. Its `state_dict` will
        be saved.
      sampler:
        The sampler used in the training dataset.
      rank:
        The rank ID used in DDP training of the current node. Set it to 0
        if DDP is not used.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    filename = out_dir / f"checkpoint-{global_batch_idx}.pt"
    save_checkpoint(
        filename=filename,
        model=model,
        params=params,
        optimizer_g=optimizer_g,
        optimizer_d=optimizer_d,
        scheduler_g=scheduler_g,
        scheduler_d=scheduler_d,
        scaler=scaler,
        sampler=sampler,
        rank=rank,
    )


# def plot_feature(feature):
#     """
#     Display the feature matrix as an image. Requires matplotlib to be installed.
#     """
#     import matplotlib.pyplot as plt
#
#     feature = np.flip(feature.transpose(1, 0), 0)
#     return plt.matshow(feature)

MATPLOTLIB_FLAG = False


def plot_feature(spectrogram):
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib
        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger('matplotlib')
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt
    import numpy as np

    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower",
                   interpolation='none')
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()

    fig.canvas.draw()
    # np.frombuffer replaces the deprecated np.fromstring(..., sep='')
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close()
    return data
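
A usage sketch matching how train.py above consumes this helper (the returned HxWx3 uint8 array is what tensorboard's HWC data format expects):

    # img = plot_feature(mel)  # mel: [n_mels, n_frames] numpy array
    # tb_writer.add_image("train/mel_", img, step, dataformats='HWC')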


@@ -1,11 +1,11 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/vits.py
# based on https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/vits.py

# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""VITS module for GAN-TTS task."""

from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn as nn

@@ -247,7 +247,7 @@ class VITS(nn.Module):
        spembs: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
        forward_generator: bool = True,
    ) -> Dict[str, Any]:
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """Perform generator forward.

        Args:

@@ -263,12 +263,8 @@
            forward_generator (bool): Whether to forward generator.

        Returns:
            Dict[str, Any]:
                - loss (Tensor): Loss scalar tensor.
                - stats (Dict[str, float]): Statistics to be monitored.
                - weight (Tensor): Weight tensor to summarize losses.
                - optim_idx (int): Optimizer index (0 for G and 1 for D).

            - loss (Tensor): Loss scalar tensor.
            - stats (Dict[str, float]): Statistics to be monitored.
        """
        if forward_generator:
            return self._forward_generator(

@@ -308,7 +304,7 @@ class VITS(nn.Module):
        sids: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
    ) -> Dict[str, Any]:
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """Perform generator forward.

        Args:

@@ -323,12 +319,8 @@
            lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).

        Returns:
            Dict[str, Any]:
                * loss (Tensor): Loss scalar tensor.
                * stats (Dict[str, float]): Statistics to be monitored.
                * weight (Tensor): Weight tensor to summarize losses.
                * optim_idx (int): Optimizer index (0 for G and 1 for D).

            * loss (Tensor): Loss scalar tensor.
            * stats (Dict[str, float]): Statistics to be monitored.
        """
        # setup
        feats = feats.transpose(1, 2)

@@ -399,7 +391,7 @@ class VITS(nn.Module):
            )

            if return_sample:
                stats["return_sample"] = (
                stats["returned_sample"] = (
                    speech_hat_[0].data.cpu().numpy(),
                    speech_[0].data.cpu().numpy(),
                    mel_hat_[0].data.cpu().numpy(),

@@ -423,7 +415,7 @@ class VITS(nn.Module):
        sids: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
    ) -> Dict[str, Any]:
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """Perform discriminator forward.

        Args:

@@ -438,12 +430,8 @@ class VITS(nn.Module):
            lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).

        Returns:
            Dict[str, Any]:
                * loss (Tensor): Loss scalar tensor.
                * stats (Dict[str, float]): Statistics to be monitored.
                * weight (Tensor): Weight tensor to summarize losses.
                * optim_idx (int): Optimizer index (0 for G and 1 for D).

            * loss (Tensor): Loss scalar tensor.
            * stats (Dict[str, float]): Statistics to be monitored.
        """
        # setup
        feats = feats.transpose(1, 2)

@@ -511,8 +499,8 @@ class VITS(nn.Module):
        alpha: float = 1.0,
        max_len: Optional[int] = None,
        use_teacher_forcing: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """Run inference.
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run inference for single sample.

        Args:
            text (Tensor): Input text index tensor (T_text,).

@@ -528,11 +516,9 @@
            use_teacher_forcing (bool): Whether to use teacher forcing.

        Returns:
            Dict[str, Tensor]:
                * wav (Tensor): Generated waveform tensor (T_wav,).
                * att_w (Tensor): Monotonic attention weight tensor (T_feats, T_text).
                * duration (Tensor): Predicted duration tensor (T_text,).

            * wav (Tensor): Generated waveform tensor (T_wav,).
            * att_w (Tensor): Monotonic attention weight tensor (T_feats, T_text).
            * duration (Tensor): Predicted duration tensor (T_text,).
        """
        # setup
        text = text[None]

@@ -593,8 +579,8 @@ class VITS(nn.Module):
        alpha: float = 1.0,
        max_len: Optional[int] = None,
        use_teacher_forcing: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """Run inference.
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run inference for one batch.

        Args:
            text (Tensor): Input text index tensor (B, T_text).

@@ -605,11 +591,9 @@
            max_len (Optional[int]): Maximum length.

        Returns:
            Dict[str, Tensor]:
                * wav (Tensor): Generated waveform tensor (B, T_wav).
                * att_w (Tensor): Monotonic attention weight tensor (B, T_feats, T_text).
                * duration (Tensor): Predicted duration tensor (B, T_text).

            * wav (Tensor): Generated waveform tensor (B, T_wav).
            * att_w (Tensor): Monotonic attention weight tensor (B, T_feats, T_text).
            * duration (Tensor): Predicted duration tensor (B, T_text).
        """
        # inference
        wav, att_w, dur = self.generator.inference(

@@ -1,4 +1,4 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/wavenet/wavenet.py
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/wavenet/wavenet.py

# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)