mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)

modify training script, clean codes

This commit is contained in:
parent 8d09f8e6bf
commit 04c6ecbaa1
@@ -1,80 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-This script splits the LJSpeech dataset cuts into three sets:
-  - training, 12500
-  - validation, 100
-  - test, 500
-The numbers are from https://arxiv.org/pdf/2106.06103.pdf
-
-Usage example:
-  python3 ./local/split_subsets.py ./data/spectrogram
-"""
-
-import argparse
-import logging
-import random
-from pathlib import Path
-
-from lhotse import load_manifest_lazy
-
-
-def get_args():
-    parser = argparse.ArgumentParser()
-
-    parser.add_argument(
-        "manifest_dir",
-        type=Path,
-        default=Path("data/spectrogram"),
-        help="Path to the manifest file",
-    )
-
-    return parser.parse_args()
-
-
-def main():
-    args = get_args()
-
-    manifest_dir = Path(args.manifest_dir)
-    prefix = "ljspeech"
-    suffix = "jsonl.gz"
-    # all_cuts = load_manifest_lazy(manifest_dir / f"{prefix}_cuts_all.{suffix}")
-    all_cuts = load_manifest_lazy(manifest_dir / f"{prefix}_cuts_all_phonemized.{suffix}")
-
-    cut_ids = list(all_cuts.ids)
-    random.shuffle(cut_ids)
-
-    train_cuts = all_cuts.subset(cut_ids=cut_ids[:12500])
-    valid_cuts = all_cuts.subset(cut_ids=cut_ids[12500 : 12500 + 100])
-    test_cuts = all_cuts.subset(cut_ids=cut_ids[12500 + 100 :])
-    assert len(train_cuts) == 12500, f"expected 12500 cuts for training but got {len(train_cuts)}"
-    assert len(valid_cuts) == 100, f"expected 100 cuts for validation but got {len(valid_cuts)}"
-    assert len(test_cuts) == 500, f"expected 500 cuts for test but got {len(test_cuts)}"
-
-    train_cuts.to_file(manifest_dir / f"{prefix}_cuts_train.{suffix}")
-    valid_cuts.to_file(manifest_dir / f"{prefix}_cuts_valid.{suffix}")
-    test_cuts.to_file(manifest_dir / f"{prefix}_cuts_test.{suffix}")
-
-    logging.info("Split into three sets: training (12500), validation (100), and test (500)")
-
-
-if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-
-    logging.basicConfig(format=formatter, level=logging.INFO)
-
-    main()
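Note: the deleted script above never seeds random.shuffle, so the resulting split differs from run to run. A minimal reproducibility sketch; the seed value here is an arbitrary illustration, not part of the original recipe:

import random

random.seed(42)  # hypothetical fixed seed; any constant makes the split reproducible
cut_ids = [f"cut-{i}" for i in range(13100)]  # 12500 train + 100 valid + 500 test
random.shuffle(cut_ids)
train, valid, test = cut_ids[:12500], cut_ids[12500:12600], cut_ids[12600:]
assert (len(train), len(valid), len(test)) == (12500, 100, 500)
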
@@ -9,8 +9,7 @@ nj=1
 stage=-1
 stop_stage=100

-# dl_dir=$PWD/download
-dl_dir=/star-data/zengwei/download/ljspeech/
+dl_dir=$PWD/download

 . shared/parse_options.sh || exit 1

@@ -66,22 +65,6 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
   fi
 fi

-# if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
-#   log "Stage 3: Phonemize the transcripts for LJSpeech"
-#   if [ ! -e data/spectrogram/.ljspeech_phonemized.done ]; then
-#     ./local/phonemize_text.py data/spectrogram
-#     touch data/spectrogram/.ljspeech_phonemized.done
-#   fi
-# fi
-
-# if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
-#   log "Stage 4: Split the LJSpeech cuts into three sets"
-#   if [ ! -e data/spectrogram/.ljspeech_split.done ]; then
-#     ./local/split_subsets.py data/spectrogram
-#     touch data/spectrogram/.ljspeech_split.done
-#   fi
-# fi
-
 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
   log "Stage 3: Split the LJSpeech cuts into train, valid and test sets"
   if [ ! -e data/spectrogram/.ljspeech_split.done ]; then
@@ -94,6 +77,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
     lhotse subset --last 500 \
       data/spectrogram/ljspeech_cuts_validtest.jsonl.gz \
       data/spectrogram/ljspeech_cuts_test.jsonl.gz

     rm data/spectrogram/ljspeech_cuts_validtest.jsonl.gz

     n=$(( $(gunzip -c data/spectrogram/ljspeech_cuts_all.jsonl.gz | wc -l) - 600 ))

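For reference, the Stage 3 hunk above shows only part of the CLI-based split that replaces the deleted local/split_subsets.py. A hedged sketch of the same logic using lhotse's Python API instead; file names follow the recipe's naming, and the 600 held-out cuts are 100 validation plus 500 test:

from lhotse import load_manifest_lazy

all_cuts = load_manifest_lazy("data/spectrogram/ljspeech_cuts_all.jsonl.gz").to_eager()
validtest = all_cuts.subset(last=600)            # hold out the last 600 cuts
valid = validtest.subset(first=100)
test = validtest.subset(last=500)
train = all_cuts.subset(first=len(all_cuts) - 600)

valid.to_file("data/spectrogram/ljspeech_cuts_valid.jsonl.gz")
test.to_file("data/spectrogram/ljspeech_cuts_test.jsonl.gz")
train.to_file("data/spectrogram/ljspeech_cuts_train.jsonl.gz")
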
@@ -1,161 +0,0 @@
-import math
-import numpy as np
-import torch
-from torch import nn
-from torch.nn import functional as F
-
-
-def init_weights(m, mean=0.0, std=0.01):
-    classname = m.__class__.__name__
-    if classname.find("Conv") != -1:
-        m.weight.data.normal_(mean, std)
-
-
-def get_padding(kernel_size, dilation=1):
-    return int((kernel_size*dilation - dilation)/2)
-
-
-def convert_pad_shape(pad_shape):
-    l = pad_shape[::-1]
-    pad_shape = [item for sublist in l for item in sublist]
-    return pad_shape
-
-
-def intersperse(lst, item):
-    result = [item] * (len(lst) * 2 + 1)
-    result[1::2] = lst
-    return result
-
-
-def kl_divergence(m_p, logs_p, m_q, logs_q):
-    """KL(P||Q)"""
-    kl = (logs_q - logs_p) - 0.5
-    kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. * logs_q)
-    return kl
-
-
-def rand_gumbel(shape):
-    """Sample from the Gumbel distribution, protect from overflows."""
-    uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
-    return -torch.log(-torch.log(uniform_samples))
-
-
-def rand_gumbel_like(x):
-    g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
-    return g
-
-
-def slice_segments(x, ids_str, segment_size=4):
-    ret = torch.zeros_like(x[:, :, :segment_size])
-    for i in range(x.size(0)):
-        idx_str = ids_str[i]
-        idx_end = idx_str + segment_size
-        ret[i] = x[i, :, idx_str:idx_end]
-    return ret
-
-
-def rand_slice_segments(x, x_lengths=None, segment_size=4):
-    b, d, t = x.size()
-    if x_lengths is None:
-        x_lengths = t
-    ids_str_max = x_lengths - segment_size + 1
-    ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
-    ret = slice_segments(x, ids_str, segment_size)
-    return ret, ids_str
-
-
-def get_timing_signal_1d(
-        length, channels, min_timescale=1.0, max_timescale=1.0e4):
-    position = torch.arange(length, dtype=torch.float)
-    num_timescales = channels // 2
-    log_timescale_increment = (
-        math.log(float(max_timescale) / float(min_timescale)) /
-        (num_timescales - 1))
-    inv_timescales = min_timescale * torch.exp(
-        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment)
-    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
-    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
-    signal = F.pad(signal, [0, 0, 0, channels % 2])
-    signal = signal.view(1, channels, length)
-    return signal
-
-
-def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
-    b, channels, length = x.size()
-    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
-    return x + signal.to(dtype=x.dtype, device=x.device)
-
-
-def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
-    b, channels, length = x.size()
-    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
-    return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
-
-
-def subsequent_mask(length):
-    mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
-    return mask
-
-
-@torch.jit.script
-def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
-    n_channels_int = n_channels[0]
-    in_act = input_a + input_b
-    t_act = torch.tanh(in_act[:, :n_channels_int, :])
-    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
-    acts = t_act * s_act
-    return acts
-
-
-def convert_pad_shape(pad_shape):
-    l = pad_shape[::-1]
-    pad_shape = [item for sublist in l for item in sublist]
-    return pad_shape
-
-
-def shift_1d(x):
-    x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
-    return x
-
-
-def sequence_mask(length, max_length=None):
-    if max_length is None:
-        max_length = length.max()
-    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
-    return x.unsqueeze(0) < length.unsqueeze(1)
-
-
-def generate_path(duration, mask):
-    """
-    duration: [b, 1, t_x]
-    mask: [b, 1, t_y, t_x]
-    """
-    device = duration.device
-
-    b, _, t_y, t_x = mask.shape
-    cum_duration = torch.cumsum(duration, -1)
-
-    cum_duration_flat = cum_duration.view(b * t_x)
-    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
-    path = path.view(b, t_x, t_y)
-    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
-    path = path.unsqueeze(1).transpose(2,3) * mask
-    return path
-
-
-def clip_grad_value_(parameters, clip_value, norm_type=2):
-    if isinstance(parameters, torch.Tensor):
-        parameters = [parameters]
-    parameters = list(filter(lambda p: p.grad is not None, parameters))
-    norm_type = float(norm_type)
-    if clip_value is not None:
-        clip_value = float(clip_value)
-
-    total_norm = 0
-    for p in parameters:
-        param_norm = p.grad.data.norm(norm_type)
-        total_norm += param_norm.item() ** norm_type
-        if clip_value is not None:
-            p.grad.data.clamp_(min=-clip_value, max=clip_value)
-    total_norm = total_norm ** (1. / norm_type)
-    return total_norm
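The least obvious helper in the deleted file is generate_path, which expands per-token durations into a hard monotonic alignment of shape (b, 1, t_y, t_x). A small sketch of the same mapping built directly with repeat_interleave, useful as a mental model and as a cross-check against path.argmax(-1):

import torch

durations = torch.tensor([2, 1, 3])                    # frames per token, t_x = 3
frame_to_token = torch.repeat_interleave(torch.arange(3), durations)
print(frame_to_token)                                  # tensor([0, 0, 1, 2, 2, 2])

# generate_path returns this mapping one-hot encoded over t_x, so
# path[0, 0].argmax(-1) would equal frame_to_token for this input.
path = torch.nn.functional.one_hot(frame_to_token, num_classes=3)
print(path.shape)                                      # (t_y, t_x) == (6, 3)
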
@@ -1,4 +1,4 @@
-# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/duration_predictor.py
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/duration_predictor.py

 # Copyright 2021 Tomoki Hayashi
 # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

@@ -1,416 +0,0 @@
-# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from typing import Any, Dict, Optional, Tuple
-
-import librosa
-import numpy as np
-import torch
-from torch import nn
-
-from icefall.utils import make_pad_mask
-
-
-# From https://github.com/espnet/espnet/blob/master/espnet2/layers/stft.py
-class Stft(nn.Module):
-    def __init__(
-        self,
-        n_fft: int = 512,
-        win_length: int = None,
-        hop_length: int = 128,
-        window: Optional[str] = "hann",
-        center: bool = True,
-        normalized: bool = False,
-        onesided: bool = True,
-    ):
-        super().__init__()
-        self.n_fft = n_fft
-        if win_length is None:
-            self.win_length = n_fft
-        else:
-            self.win_length = win_length
-        self.hop_length = hop_length
-        self.center = center
-        self.normalized = normalized
-        self.onesided = onesided
-        if window is not None and not hasattr(torch, f"{window}_window"):
-            raise ValueError(f"{window} window is not implemented")
-        self.window = window
-
-    def extra_repr(self):
-        return (
-            f"n_fft={self.n_fft}, "
-            f"win_length={self.win_length}, "
-            f"hop_length={self.hop_length}, "
-            f"center={self.center}, "
-            f"normalized={self.normalized}, "
-            f"onesided={self.onesided}"
-        )
-
-    def forward(
-        self, input: torch.Tensor, ilens: torch.Tensor = None
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        """STFT forward function.
-
-        Args:
-            input: (Batch, Nsamples) or (Batch, Nsample, Channels)
-            ilens: (Batch)
-        Returns:
-            output: (Batch, Frames, Freq, 2) or (Batch, Frames, Channels, Freq, 2)
-
-        """
-        bs = input.size(0)
-        if input.dim() == 3:
-            multi_channel = True
-            # input: (Batch, Nsample, Channels) -> (Batch * Channels, Nsample)
-            input = input.transpose(1, 2).reshape(-1, input.size(1))
-        else:
-            multi_channel = False
-
-        # NOTE(kamo):
-        #   The default behaviour of torch.stft is compatible with librosa.stft
-        #   about padding and scaling.
-        #   Note that it's different from scipy.signal.stft
-
-        # output: (Batch, Freq, Frames, 2=real_imag)
-        # or (Batch, Channel, Freq, Frames, 2=real_imag)
-        if self.window is not None:
-            window_func = getattr(torch, f"{self.window}_window")
-            window = window_func(
-                self.win_length, dtype=input.dtype, device=input.device
-            )
-        else:
-            window = None
-
-        # For the compatibility of ARM devices, which do not support
-        # torch.stft() due to the lack of MKL (on older pytorch versions),
-        # there is an alternative replacement implementation with librosa.
-        # Note: pytorch >= 1.10.0 now has native support for FFT and STFT
-        # on all cpu targets including ARM.
-        if input.is_cuda or torch.backends.mkl.is_available():
-            stft_kwargs = dict(
-                n_fft=self.n_fft,
-                win_length=self.win_length,
-                hop_length=self.hop_length,
-                center=self.center,
-                window=window,
-                normalized=self.normalized,
-                onesided=self.onesided,
-            )
-            stft_kwargs["return_complex"] = True
-            output = torch.stft(input, **stft_kwargs)
-            output = torch.view_as_real(output)
-        else:
-            if self.training:
-                raise NotImplementedError(
-                    "stft is implemented with librosa on this device, which does not "
-                    "support the training mode."
-                )
-
-            # use stft_kwargs to flexibly control different PyTorch versions' kwargs
-            # note: librosa does not support a win_length that is < n_fft
-            # but the window can be manually padded (see below).
-            stft_kwargs = dict(
-                n_fft=self.n_fft,
-                win_length=self.n_fft,
-                hop_length=self.hop_length,
-                center=self.center,
-                window=window,
-                pad_mode="reflect",
-            )
-
-            if window is not None:
-                # pad the given window to n_fft
-                n_pad_left = (self.n_fft - window.shape[0]) // 2
-                n_pad_right = self.n_fft - window.shape[0] - n_pad_left
-                stft_kwargs["window"] = torch.cat(
-                    [torch.zeros(n_pad_left), window, torch.zeros(n_pad_right)], 0
-                ).numpy()
-            else:
-                win_length = (
-                    self.win_length if self.win_length is not None else self.n_fft
-                )
-                stft_kwargs["window"] = torch.ones(win_length)
-
-            output = []
-            # iterate over instances in a batch
-            for i, instance in enumerate(input):
-                stft = librosa.stft(input[i].numpy(), **stft_kwargs)
-                output.append(torch.tensor(np.stack([stft.real, stft.imag], -1)))
-            output = torch.stack(output, 0)
-            if not self.onesided:
-                len_conj = self.n_fft - output.shape[1]
-                conj = output[:, 1 : 1 + len_conj].flip(1)
-                conj[:, :, :, -1].data *= -1
-                output = torch.cat([output, conj], 1)
-            if self.normalized:
-                output = output * (stft_kwargs["window"].shape[0] ** (-0.5))
-
-        # output: (Batch, Freq, Frames, 2=real_imag)
-        # -> (Batch, Frames, Freq, 2=real_imag)
-        output = output.transpose(1, 2)
-        if multi_channel:
-            # output: (Batch * Channel, Frames, Freq, 2=real_imag)
-            # -> (Batch, Frame, Channel, Freq, 2=real_imag)
-            output = output.view(bs, -1, output.size(1), output.size(2), 2).transpose(
-                1, 2
-            )
-
-        if ilens is not None:
-            if self.center:
-                pad = self.n_fft // 2
-                ilens = ilens + 2 * pad
-
-            olens = (
-                torch.div(ilens - self.n_fft, self.hop_length, rounding_mode="trunc")
-                + 1
-            )
-            output.masked_fill_(make_pad_mask(olens), 0.0)
-        else:
-            olens = None
-
-        return output, olens
-
-
-# From https://github.com/espnet/espnet/blob/master/espnet2/tts/feats_extract/linear_spectrogram.py
-class LinearSpectrogram(nn.Module):
-    """Linear amplitude spectrogram.
-
-    Stft -> amplitude-spec
-    """
-
-    def __init__(
-        self,
-        n_fft: int = 1024,
-        win_length: int = None,
-        hop_length: int = 256,
-        window: Optional[str] = "hann",
-        center: bool = True,
-        normalized: bool = False,
-        onesided: bool = True,
-    ):
-        super().__init__()
-        self.n_fft = n_fft
-        self.hop_length = hop_length
-        self.win_length = win_length
-        self.window = window
-        self.stft = Stft(
-            n_fft=n_fft,
-            win_length=win_length,
-            hop_length=hop_length,
-            window=window,
-            center=center,
-            normalized=normalized,
-            onesided=onesided,
-        )
-        self.n_fft = n_fft
-
-    def output_size(self) -> int:
-        return self.n_fft // 2 + 1
-
-    def get_parameters(self) -> Dict[str, Any]:
-        """Return the parameters required by Vocoder."""
-        return dict(
-            n_fft=self.n_fft,
-            n_shift=self.hop_length,
-            win_length=self.win_length,
-            window=self.window,
-        )
-
-    def forward(
-        self, input: torch.Tensor, input_lengths: torch.Tensor = None
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # 1. Stft: time -> time-freq
-        input_stft, feats_lens = self.stft(input, input_lengths)
-
-        assert input_stft.dim() >= 4, input_stft.shape
-        # "2" refers to the real/imag parts of Complex
-        assert input_stft.shape[-1] == 2, input_stft.shape
-
-        # STFT -> Power spectrum -> Amp spectrum
-        # input_stft: (..., F, 2) -> (..., F)
-        input_power = input_stft[..., 0] ** 2 + input_stft[..., 1] ** 2
-        input_amp = torch.sqrt(torch.clamp(input_power, min=1.0e-10))
-        return input_amp, feats_lens
-
-
-# From https://github.com/espnet/espnet/blob/master/espnet2/layers/log_mel.py
-class LogMel(nn.Module):
-    """Convert STFT to fbank feats
-
-    The arguments is same as librosa.filters.mel
-
-    Args:
-        fs: number > 0 [scalar] sampling rate of the incoming signal
-        n_fft: int > 0 [scalar] number of FFT components
-        n_mels: int > 0 [scalar] number of Mel bands to generate
-        fmin: float >= 0 [scalar] lowest frequency (in Hz)
-        fmax: float >= 0 [scalar] highest frequency (in Hz).
-            If `None`, use `fmax = fs / 2.0`
-        htk: use HTK formula instead of Slaney
-    """
-
-    def __init__(
-        self,
-        fs: int = 16000,
-        n_fft: int = 512,
-        n_mels: int = 80,
-        fmin: float = None,
-        fmax: float = None,
-        htk: bool = False,
-        log_base: float = None,
-    ):
-        super().__init__()
-
-        fmin = 0 if fmin is None else fmin
-        fmax = fs / 2 if fmax is None else fmax
-        _mel_options = dict(
-            sr=fs,
-            n_fft=n_fft,
-            n_mels=n_mels,
-            fmin=fmin,
-            fmax=fmax,
-            htk=htk,
-        )
-        self.mel_options = _mel_options
-        self.log_base = log_base
-
-        # Note(kamo): The mel matrix of librosa is different from kaldi.
-        melmat = librosa.filters.mel(**_mel_options)
-        # melmat: (D2, D1) -> (D1, D2)
-        self.register_buffer("melmat", torch.from_numpy(melmat.T).float())
-
-    def extra_repr(self):
-        return ", ".join(f"{k}={v}" for k, v in self.mel_options.items())
-
-    def forward(
-        self,
-        feat: torch.Tensor,
-        ilens: torch.Tensor = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
-        mel_feat = torch.matmul(feat, self.melmat)
-        mel_feat = torch.clamp(mel_feat, min=1e-10)
-
-        if self.log_base is None:
-            logmel_feat = mel_feat.log()
-        elif self.log_base == 2.0:
-            logmel_feat = mel_feat.log2()
-        elif self.log_base == 10.0:
-            logmel_feat = mel_feat.log10()
-        else:
-            logmel_feat = mel_feat.log() / torch.log(self.log_base)
-
-        # Zero padding
-        if ilens is not None:
-            logmel_feat = logmel_feat.masked_fill(make_pad_mask(ilens), 0.0)
-        else:
-            ilens = feat.new_full(
-                [feat.size(0)], fill_value=feat.size(1), dtype=torch.long
-            )
-        return logmel_feat, ilens
-
-
-# From https://github.com/espnet/espnet/blob/master/espnet2/tts/feats_extract/log_mel_fbank.py
-class LogMelFbank(nn.Module):
-    """Conventional frontend structure for TTS.
-
-    Stft -> amplitude-spec -> Log-Mel-Fbank
-    """
-
-    def __init__(
-        self,
-        fs: int = 16000,
-        n_fft: int = 1024,
-        win_length: int = None,
-        hop_length: int = 256,
-        window: Optional[str] = "hann",
-        center: bool = True,
-        normalized: bool = False,
-        onesided: bool = True,
-        n_mels: int = 80,
-        fmin: Optional[int] = 80,
-        fmax: Optional[int] = 7600,
-        htk: bool = False,
-        log_base: Optional[float] = 10.0,
-    ):
-        super().__init__()
-
-        self.fs = fs
-        self.n_mels = n_mels
-        self.n_fft = n_fft
-        self.hop_length = hop_length
-        self.win_length = win_length
-        self.window = window
-        self.fmin = fmin
-        self.fmax = fmax
-
-        self.stft = Stft(
-            n_fft=n_fft,
-            win_length=win_length,
-            hop_length=hop_length,
-            window=window,
-            center=center,
-            normalized=normalized,
-            onesided=onesided,
-        )
-
-        self.logmel = LogMel(
-            fs=fs,
-            n_fft=n_fft,
-            n_mels=n_mels,
-            fmin=fmin,
-            fmax=fmax,
-            htk=htk,
-            log_base=log_base,
-        )
-
-    def output_size(self) -> int:
-        return self.n_mels
-
-    def get_parameters(self) -> Dict[str, Any]:
-        """Return the parameters required by Vocoder"""
-        return dict(
-            fs=self.fs,
-            n_fft=self.n_fft,
-            n_shift=self.hop_length,
-            window=self.window,
-            n_mels=self.n_mels,
-            win_length=self.win_length,
-            fmin=self.fmin,
-            fmax=self.fmax,
-        )
-
-    def forward(
-        self, input: torch.Tensor, input_lengths: torch.Tensor = None
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # 1. Domain-conversion: e.g. Stft: time -> time-freq
-        input_stft, feats_lens = self.stft(input, input_lengths)
-
-        assert input_stft.dim() >= 4, input_stft.shape
-        # "2" refers to the real/imag parts of Complex
-        assert input_stft.shape[-1] == 2, input_stft.shape
-
-        # NOTE(kamo): We use different definition for log-spec between TTS and ASR
-        #   TTS: log_10(abs(stft))
-        #   ASR: log_e(power(stft))
-
-        # input_stft: (..., F, 2) -> (..., F)
-        input_power = input_stft[..., 0] ** 2 + input_stft[..., 1] ** 2
-        input_amp = torch.sqrt(torch.clamp(input_power, min=1.0e-10))
-        input_feats, _ = self.logmel(input_amp, feats_lens)
-        return input_feats, feats_lens
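A hedged usage sketch for the removed LogMelFbank frontend, assuming the class definitions above; argument values mirror the class defaults, adjusted to LJSpeech's 22050 Hz sampling rate, and the input tensors are synthetic:

import torch

fbank = LogMelFbank(fs=22050, n_fft=1024, hop_length=256, n_mels=80)
wav = torch.randn(2, 22050)             # (batch, samples): one second of audio
lens = torch.tensor([22050, 16000])     # true lengths before padding
feats, feat_lens = fbank(wav, lens)     # feats: (batch, frames, n_mels)
print(feats.shape, feat_lens)
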
@@ -1,4 +1,5 @@
-# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/flow.py
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/flow.py

 # Copyright 2021 Tomoki Hayashi
 # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

@@ -1,4 +1,4 @@
-# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/generator.py
+# based on https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/generator.py

 # Copyright 2021 Tomoki Hayashi
 # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

@@ -1,4 +1,4 @@
-# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/hifigan.py
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/hifigan.py

 # Copyright 2021 Tomoki Hayashi
 # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

@@ -1,7 +1,6 @@
 #!/usr/bin/env python3
 #
-# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
-#                                                 Zengwei Yao)
+# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -17,118 +16,34 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
+This script performs model inference on test set.

 Usage:
-(1) greedy search
-./zipformer/decode.py \
-    --epoch 28 \
-    --avg 15 \
+./vits/infer.py \
+    --epoch 1000 \
     --exp-dir ./zipformer/exp \
-    --max-duration 600 \
-    --decoding-method greedy_search
+    --max-duration 500

-(2) beam search (not recommended)
-./zipformer/decode.py \
-    --epoch 28 \
-    --avg 15 \
-    --exp-dir ./zipformer/exp \
-    --max-duration 600 \
-    --decoding-method beam_search \
-    --beam-size 4
-
-(3) modified beam search
-./zipformer/decode.py \
-    --epoch 28 \
-    --avg 15 \
-    --exp-dir ./zipformer/exp \
-    --max-duration 600 \
-    --decoding-method modified_beam_search \
-    --beam-size 4
-
-(4) fast beam search (one best)
-./zipformer/decode.py \
-    --epoch 28 \
-    --avg 15 \
-    --exp-dir ./zipformer/exp \
-    --max-duration 600 \
-    --decoding-method fast_beam_search \
-    --beam 20.0 \
-    --max-contexts 8 \
-    --max-states 64
-
-(5) fast beam search (nbest)
-./zipformer/decode.py \
-    --epoch 28 \
-    --avg 15 \
-    --exp-dir ./zipformer/exp \
-    --max-duration 600 \
-    --decoding-method fast_beam_search_nbest \
-    --beam 20.0 \
-    --max-contexts 8 \
-    --max-states 64 \
-    --num-paths 200 \
-    --nbest-scale 0.5
-
-(6) fast beam search (nbest oracle WER)
-./zipformer/decode.py \
-    --epoch 28 \
-    --avg 15 \
-    --exp-dir ./zipformer/exp \
-    --max-duration 600 \
-    --decoding-method fast_beam_search_nbest_oracle \
-    --beam 20.0 \
-    --max-contexts 8 \
-    --max-states 64 \
-    --num-paths 200 \
-    --nbest-scale 0.5
-
-(7) fast beam search (with LG)
-./zipformer/decode.py \
-    --epoch 28 \
-    --avg 15 \
-    --exp-dir ./zipformer/exp \
-    --max-duration 600 \
-    --decoding-method fast_beam_search_nbest_LG \
-    --beam 20.0 \
-    --max-contexts 8 \
-    --max-states 64
 """


 import argparse
 import logging
-import math
-import os
-from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple
+from typing import List

 import k2
 import torch
 import torch.nn as nn
 import torchaudio

-from train import get_model, get_params, prepare_input
+from train import get_model, get_params
 from tokenizer import Tokenizer

-from icefall.checkpoint import (
-    average_checkpoints,
-    find_checkpoints,
-    load_checkpoint,
-)
-from icefall.lexicon import Lexicon
-from icefall.utils import (
-    AttributeDict,
-    make_pad_mask,
-    setup_logger,
-    store_transcripts,
-    str2bool,
-    write_error_stats,
-)
+from icefall.checkpoint import load_checkpoint
+from icefall.utils import AttributeDict, setup_logger
 from tts_datamodule import LJSpeechTtsDataModule

-LOG_EPS = math.log(1e-10)
-

 def get_parser():
     parser = argparse.ArgumentParser(
@@ -138,35 +53,16 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=30,
+        default=1000,
         help="""It specifies the checkpoint to use for decoding.
         Note: Epoch counts from 1.
-        You can specify --avg to use more checkpoints for model averaging.""",
-    )
-
-    parser.add_argument(
-        "--iter",
-        type=int,
-        default=0,
-        help="""If positive, --epoch is ignored and it
-        will use the checkpoint exp_dir/checkpoint-iter.pt.
-        You can specify --avg to use more checkpoints for model averaging.
         """,
     )

-    parser.add_argument(
-        "--avg",
-        type=int,
-        default=15,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch' and '--iter'",
-    )
-
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="zipformer/exp",
+        default="vits/exp",
         help="The experiment dir",
     )

@@ -174,7 +70,7 @@ def get_parser():
         "--tokens",
         type=str,
         default="data/tokens.txt",
-        help="""Path to tokens.txt.""",
+        help="""Path to vocabulary.""",
     )

     return parser
@@ -185,8 +81,9 @@ def infer_dataset(
     params: AttributeDict,
     model: nn.Module,
     tokenizer: Tokenizer,
-) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+) -> None:
     """Decode dataset.
+    The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`.

     Args:
       dl:
@@ -195,20 +92,8 @@ def infer_dataset(
         It is returned by :func:`get_params`.
       model:
         The neural model.
-      sp:
-        The BPE model.
-      word_table:
-        The word symbol table.
-      decoding_graph:
-        The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
-        only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
-        fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
-    Returns:
-      Return a dict, whose key may be "greedy_search" if greedy search
-      is used, or it may be "beam_7" if beam size of 7 is used.
-      Its value is a list of tuples. Each tuple contains two elements:
-      The first is the reference transcript, and the second is the
-      predicted result.
+      tokenizer:
+        Used to convert text to phonemes.
     """
     # Background worker that saves audio to disk.
     def _save_worker(
@@ -233,7 +118,7 @@ def infer_dataset(

     device = next(model.parameters()).device
     num_cuts = 0
-    log_interval = 10
+    log_interval = 5

     try:
         num_batches = len(dl)
@@ -242,7 +127,6 @@ def infer_dataset(

     futures = []
     with ThreadPoolExecutor(max_workers=1) as executor:
-        # We only want one background worker so that serialization is deterministic.
         for batch_idx, batch in enumerate(dl):
             batch_size = len(batch["text"])

@@ -253,7 +137,7 @@ def infer_dataset(
             tokens_lens = row_splits[1:] - row_splits[:-1]
             tokens = tokens.to(device)
             tokens_lens = tokens_lens.to(device)
-            # a tensor of shape (B, T)
+            # tensor of shape (B, T)
             tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)

             audio = batch["audio"]
@@ -265,9 +149,6 @@ def infer_dataset(
             # convert to samples
             audio_lens_pred = (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist()

-            # import pdb
-            # pdb.set_trace()
-
             futures.append(
                 executor.submit(
                     _save_worker, batch_size, cut_ids, audio, audio_pred, audio_lens, audio_lens_pred
@@ -295,10 +176,7 @@ def main():
    params = get_params()
    params.update(vars(args))

-    if params.iter > 0:
-        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
-    else:
-        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+    params.suffix = f"epoch-{params.epoch}"

    params.res_dir = params.exp_dir / "infer" / params.suffix
    params.save_wav_dir = params.res_dir / "wav"
@@ -322,40 +200,16 @@ def main():
    logging.info("About to create model")
    model = get_model(params)

-    if params.iter > 0:
-        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-            : params.avg
-        ]
-        if len(filenames) == 0:
-            raise ValueError(
-                f"No checkpoints found for"
-                f" --iter {params.iter}, --avg {params.avg}"
-            )
-        elif len(filenames) < params.avg:
-            raise ValueError(
-                f"Not enough checkpoints ({len(filenames)}) found for"
-                f" --iter {params.iter}, --avg {params.avg}"
-            )
-        logging.info(f"averaging {filenames}")
-        model.to(device)
-        model.load_state_dict(average_checkpoints(filenames, device=device))
-    elif params.avg == 1:
-        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
-    else:
-        start = params.epoch - params.avg + 1
-        filenames = []
-        for i in range(start, params.epoch + 1):
-            if i >= 1:
-                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
-        logging.info(f"averaging {filenames}")
-        model.to(device)
-        model.load_state_dict(average_checkpoints(filenames, device=device))
+    load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)

    model.to(device)
    model.eval()

-    num_param = sum([p.numel() for p in model.parameters()])
-    logging.info(f"Number of model parameters: {num_param}")
+    num_param_g = sum([p.numel() for p in model.generator.parameters()])
+    logging.info(f"Number of parameters in generator: {num_param_g}")
+    num_param_d = sum([p.numel() for p in model.discriminator.parameters()])
+    logging.info(f"Number of parameters in discriminator: {num_param_d}")
+    logging.info(f"Total number of parameters: {num_param_g + num_param_d}")

    # we need cut ids to display recognition results.
    args.return_cuts = True
@@ -371,17 +225,8 @@ def main():
        tokenizer=tokenizer,
    )

-    # save_results(
-    #     params=params,
-    #     test_set_name=test_set,
-    #     results_dict=results_dict,
-    # )
-
    logging.info("Done!")


-# torch.set_num_threads(1)
-# torch.set_num_interop_threads(1)
-
 if __name__ == "__main__":
     main()

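The infer_dataset hunks above rely on a single-worker ThreadPoolExecutor so that GPU synthesis is not blocked on disk I/O while writes stay in a deterministic order. A standalone sketch of that pattern, in which generated_wavs and the output paths are hypothetical placeholders:

from concurrent.futures import ThreadPoolExecutor

import torch
import torchaudio

def save_audio(path: str, wav: torch.Tensor, sample_rate: int = 22050) -> None:
    # wav: (1, num_samples) float tensor on CPU; assumes the output dir exists
    torchaudio.save(path, wav, sample_rate)

generated_wavs = [torch.randn(1, 22050) for _ in range(3)]  # stand-in for model output
futures = []
with ThreadPoolExecutor(max_workers=1) as executor:  # one worker => serialized writes
    for i, wav in enumerate(generated_wavs):
        futures.append(executor.submit(save_audio, f"wav/{i}.wav", wav.cpu()))
    for f in futures:
        f.result()  # re-raise any exception from the background worker
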
@@ -1,4 +1,4 @@
-# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/loss.py
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/loss.py

 # Copyright 2021 Tomoki Hayashi
 # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
@@ -9,7 +9,7 @@ This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.

 """

-from typing import List, Optional, Tuple, Union
+from typing import List, Tuple, Union

 import torch
 import torch.distributions as D
@@ -266,7 +266,7 @@ class MelSpectrogramLoss(torch.nn.Module):
         return mel_loss


-# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/loss.py
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/loss.py

 """VITS-related loss modules.

@ -1,534 +0,0 @@
|
|||||||
import copy
|
|
||||||
import math
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
from torch.nn import functional as F
|
|
||||||
|
|
||||||
import commons
|
|
||||||
import modules
|
|
||||||
import attentions
|
|
||||||
import monotonic_align
|
|
||||||
|
|
||||||
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
|
|
||||||
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
|
||||||
from commons import init_weights, get_padding
|
|
||||||
|
|
||||||
|
|
||||||
class StochasticDurationPredictor(nn.Module):
|
|
||||||
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0):
|
|
||||||
super().__init__()
|
|
||||||
filter_channels = in_channels # it needs to be removed from future version.
|
|
||||||
self.in_channels = in_channels
|
|
||||||
self.filter_channels = filter_channels
|
|
||||||
self.kernel_size = kernel_size
|
|
||||||
self.p_dropout = p_dropout
|
|
||||||
self.n_flows = n_flows
|
|
||||||
self.gin_channels = gin_channels
|
|
||||||
|
|
||||||
self.log_flow = modules.Log()
|
|
||||||
self.flows = nn.ModuleList()
|
|
||||||
self.flows.append(modules.ElementwiseAffine(2))
|
|
||||||
for i in range(n_flows):
|
|
||||||
self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
|
|
||||||
self.flows.append(modules.Flip())
|
|
||||||
|
|
||||||
self.post_pre = nn.Conv1d(1, filter_channels, 1)
|
|
||||||
self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
|
|
||||||
self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
|
|
||||||
self.post_flows = nn.ModuleList()
|
|
||||||
self.post_flows.append(modules.ElementwiseAffine(2))
|
|
||||||
for i in range(4):
|
|
||||||
self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
|
|
||||||
self.post_flows.append(modules.Flip())
|
|
||||||
|
|
||||||
self.pre = nn.Conv1d(in_channels, filter_channels, 1)
|
|
||||||
self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
|
|
||||||
self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
|
|
||||||
if gin_channels != 0:
|
|
||||||
self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
|
|
||||||
|
|
||||||
def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
|
|
||||||
x = torch.detach(x)
|
|
||||||
x = self.pre(x)
|
|
||||||
if g is not None:
|
|
||||||
g = torch.detach(g)
|
|
||||||
x = x + self.cond(g)
|
|
||||||
x = self.convs(x, x_mask)
|
|
||||||
x = self.proj(x) * x_mask
|
|
||||||
|
|
||||||
if not reverse:
|
|
||||||
flows = self.flows
|
|
||||||
assert w is not None
|
|
||||||
|
|
||||||
logdet_tot_q = 0
|
|
||||||
h_w = self.post_pre(w)
|
|
||||||
h_w = self.post_convs(h_w, x_mask)
|
|
||||||
h_w = self.post_proj(h_w) * x_mask
|
|
||||||
e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
|
|
||||||
z_q = e_q
|
|
||||||
for flow in self.post_flows:
|
|
||||||
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
|
|
||||||
logdet_tot_q += logdet_q
|
|
||||||
z_u, z1 = torch.split(z_q, [1, 1], 1)
|
|
||||||
u = torch.sigmoid(z_u) * x_mask
|
|
||||||
z0 = (w - u) * x_mask
|
|
||||||
logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1,2])
|
|
||||||
logq = torch.sum(-0.5 * (math.log(2*math.pi) + (e_q**2)) * x_mask, [1,2]) - logdet_tot_q
|
|
||||||
|
|
||||||
logdet_tot = 0
|
|
||||||
z0, logdet = self.log_flow(z0, x_mask)
|
|
||||||
logdet_tot += logdet
|
|
||||||
z = torch.cat([z0, z1], 1)
|
|
||||||
for flow in flows:
|
|
||||||
z, logdet = flow(z, x_mask, g=x, reverse=reverse)
|
|
||||||
logdet_tot = logdet_tot + logdet
|
|
||||||
nll = torch.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - logdet_tot
|
|
||||||
return nll + logq # [b]
|
|
||||||
else:
|
|
||||||
flows = list(reversed(self.flows))
|
|
||||||
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
|
|
||||||
z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
|
|
||||||
for flow in flows:
|
|
||||||
z = flow(z, x_mask, g=x, reverse=reverse)
|
|
||||||
z0, z1 = torch.split(z, [1, 1], 1)
|
|
||||||
logw = z0
|
|
||||||
return logw
|
|
||||||
|
|
||||||
|
|
||||||
class DurationPredictor(nn.Module):
|
|
||||||
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.in_channels = in_channels
|
|
||||||
self.filter_channels = filter_channels
|
|
||||||
self.kernel_size = kernel_size
|
|
||||||
self.p_dropout = p_dropout
|
|
||||||
self.gin_channels = gin_channels
|
|
||||||
|
|
||||||
self.drop = nn.Dropout(p_dropout)
|
|
||||||
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size//2)
|
|
||||||
self.norm_1 = modules.LayerNorm(filter_channels)
|
|
||||||
self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size//2)
|
|
||||||
self.norm_2 = modules.LayerNorm(filter_channels)
|
|
||||||
self.proj = nn.Conv1d(filter_channels, 1, 1)
|
|
||||||
|
|
||||||
if gin_channels != 0:
|
|
||||||
self.cond = nn.Conv1d(gin_channels, in_channels, 1)
|
|
||||||
|
|
||||||
def forward(self, x, x_mask, g=None):
|
|
||||||
x = torch.detach(x)
|
|
||||||
if g is not None:
|
|
||||||
g = torch.detach(g)
|
|
||||||
x = x + self.cond(g)
|
|
||||||
x = self.conv_1(x * x_mask)
|
|
||||||
x = torch.relu(x)
|
|
||||||
x = self.norm_1(x)
|
|
||||||
x = self.drop(x)
|
|
||||||
x = self.conv_2(x * x_mask)
|
|
||||||
x = torch.relu(x)
|
|
||||||
x = self.norm_2(x)
|
|
||||||
x = self.drop(x)
|
|
||||||
x = self.proj(x * x_mask)
|
|
||||||
return x * x_mask
|
|
||||||
|
|
||||||
|
|
||||||
class TextEncoder(nn.Module):
|
|
||||||
def __init__(self,
|
|
||||||
n_vocab,
|
|
||||||
out_channels,
|
|
||||||
hidden_channels,
|
|
||||||
filter_channels,
|
|
||||||
n_heads,
|
|
||||||
n_layers,
|
|
||||||
kernel_size,
|
|
||||||
p_dropout):
|
|
||||||
super().__init__()
|
|
||||||
self.n_vocab = n_vocab
|
|
||||||
self.out_channels = out_channels
|
|
||||||
self.hidden_channels = hidden_channels
|
|
||||||
self.filter_channels = filter_channels
|
|
||||||
self.n_heads = n_heads
|
|
||||||
self.n_layers = n_layers
|
|
||||||
self.kernel_size = kernel_size
|
|
||||||
self.p_dropout = p_dropout
|
|
||||||
|
|
||||||
self.emb = nn.Embedding(n_vocab, hidden_channels)
|
|
||||||
nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
|
|
||||||
|
|
||||||
self.encoder = attentions.Encoder(
|
|
||||||
hidden_channels,
|
|
||||||
filter_channels,
|
|
||||||
n_heads,
|
|
||||||
n_layers,
|
|
||||||
kernel_size,
|
|
||||||
p_dropout)
|
|
||||||
self.proj= nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
|
||||||
|
|
||||||
def forward(self, x, x_lengths):
|
|
||||||
x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
|
|
||||||
x = torch.transpose(x, 1, -1) # [b, h, t]
|
|
||||||
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
|
|
||||||
|
|
||||||
x = self.encoder(x * x_mask, x_mask)
|
|
||||||
stats = self.proj(x) * x_mask
|
|
||||||
|
|
||||||
m, logs = torch.split(stats, self.out_channels, dim=1)
|
|
||||||
return x, m, logs, x_mask
|
|
||||||
|
|
||||||
|
|
||||||
class ResidualCouplingBlock(nn.Module):
|
|
||||||
def __init__(self,
|
|
||||||
channels,
|
|
||||||
hidden_channels,
|
|
||||||
kernel_size,
|
|
||||||
dilation_rate,
|
|
||||||
n_layers,
|
|
||||||
n_flows=4,
|
|
||||||
gin_channels=0):
|
|
||||||
super().__init__()
|
|
||||||
self.channels = channels
|
|
||||||
self.hidden_channels = hidden_channels
|
|
||||||
self.kernel_size = kernel_size
|
|
||||||
self.dilation_rate = dilation_rate
|
|
||||||
self.n_layers = n_layers
|
|
||||||
self.n_flows = n_flows
|
|
||||||
self.gin_channels = gin_channels
|
|
||||||
|
|
||||||
self.flows = nn.ModuleList()
|
|
||||||
for i in range(n_flows):
|
|
||||||
self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True))
|
|
||||||
self.flows.append(modules.Flip())
|
|
||||||
|
|
||||||
def forward(self, x, x_mask, g=None, reverse=False):
|
|
||||||
if not reverse:
|
|
||||||
for flow in self.flows:
|
|
||||||
x, _ = flow(x, x_mask, g=g, reverse=reverse)
|
|
||||||
else:
|
|
||||||
for flow in reversed(self.flows):
|
|
||||||
x = flow(x, x_mask, g=g, reverse=reverse)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class PosteriorEncoder(nn.Module):
|
|
||||||
def __init__(self,
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
hidden_channels,
|
|
||||||
kernel_size,
|
|
||||||
dilation_rate,
|
|
||||||
n_layers,
|
|
||||||
gin_channels=0):
|
|
||||||
super().__init__()
|
|
||||||
self.in_channels = in_channels
|
|
||||||
self.out_channels = out_channels
|
|
||||||
self.hidden_channels = hidden_channels
|
|
||||||
self.kernel_size = kernel_size
|
|
||||||
self.dilation_rate = dilation_rate
|
|
||||||
self.n_layers = n_layers
|
|
||||||
self.gin_channels = gin_channels
|
|
||||||
|
|
||||||
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
|
|
||||||
self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
|
|
||||||
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
|
||||||
|
|
||||||
def forward(self, x, x_lengths, g=None):
|
|
||||||
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
|
|
||||||
x = self.pre(x) * x_mask
|
|
||||||
x = self.enc(x, x_mask, g=g)
|
|
||||||
stats = self.proj(x) * x_mask
|
|
||||||
m, logs = torch.split(stats, self.out_channels, dim=1)
|
|
||||||
z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
|
|
||||||
return z, m, logs, x_mask
|
|
||||||
|
|
||||||
|
|
||||||
class Generator(torch.nn.Module):
|
|
||||||
def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
|
|
||||||
super(Generator, self).__init__()
|
|
||||||
self.num_kernels = len(resblock_kernel_sizes)
|
|
||||||
self.num_upsamples = len(upsample_rates)
|
|
||||||
self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
|
|
||||||
resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2
|
|
||||||
|
|
||||||
self.ups = nn.ModuleList()
|
|
||||||
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
|
||||||
self.ups.append(weight_norm(
|
|
||||||
ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
|
|
||||||
k, u, padding=(k-u)//2)))
|
|
||||||
|
|
||||||
self.resblocks = nn.ModuleList()
|
|
||||||
for i in range(len(self.ups)):
|
|
||||||
ch = upsample_initial_channel//(2**(i+1))
|
|
||||||
for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
|
||||||
self.resblocks.append(resblock(ch, k, d))
|
|
||||||
|
|
||||||
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
|
|
||||||
self.ups.apply(init_weights)
|
|
||||||
|
|
||||||
if gin_channels != 0:
|
|
||||||
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
|
|
||||||
|
|
||||||
def forward(self, x, g=None):
|
|
||||||
x = self.conv_pre(x)
|
|
||||||
if g is not None:
|
|
||||||
x = x + self.cond(g)
|
|
||||||
|
|
||||||
for i in range(self.num_upsamples):
|
|
||||||
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
|
||||||
x = self.ups[i](x)
|
|
||||||
xs = None
|
|
||||||
for j in range(self.num_kernels):
|
|
||||||
if xs is None:
|
|
||||||
xs = self.resblocks[i*self.num_kernels+j](x)
|
|
||||||
else:
|
|
||||||
xs += self.resblocks[i*self.num_kernels+j](x)
|
|
||||||
x = xs / self.num_kernels
|
|
||||||
x = F.leaky_relu(x)
|
|
||||||
x = self.conv_post(x)
|
|
||||||
x = torch.tanh(x)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
def remove_weight_norm(self):
|
|
||||||
print('Removing weight norm...')
|
|
||||||
for l in self.ups:
|
|
||||||
remove_weight_norm(l)
|
|
||||||
for l in self.resblocks:
|
|
||||||
l.remove_weight_norm()
|
|
||||||
|
|
||||||
|
|
||||||
class DiscriminatorP(torch.nn.Module):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        self.use_spectral_norm = use_spectral_norm
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        self.convs = nn.ModuleList([
            norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
            norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
            norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
            norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
            norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
        ])
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        fmap = []

        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap

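DiscriminatorP folds the 1-D waveform into a 2-D grid so that each column holds samples exactly one period apart; reflect padding first makes the length divisible by the period. A small demonstration of that fold:

import torch
import torch.nn.functional as F

period = 3
x = torch.arange(10.0).view(1, 1, 10)  # (batch, channel, time)
n_pad = period - x.size(2) % period    # 2 samples of reflect padding
x = F.pad(x, (0, n_pad), "reflect")
x2d = x.view(1, 1, -1, period)         # shape (1, 1, 4, 3)
print(x2d.shape)
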
class DiscriminatorS(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        self.convs = nn.ModuleList([
            norm_f(Conv1d(1, 16, 15, 1, padding=7)),
            norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
            norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
            norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
            norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
            norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
        ])
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        fmap = []

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap

class MultiPeriodDiscriminator(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(MultiPeriodDiscriminator, self).__init__()
        periods = [2, 3, 5, 7, 11]

        discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
        discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
        self.discriminators = nn.ModuleList(discs)

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            y_d_gs.append(y_d_g)
            fmap_rs.append(fmap_r)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs

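The multi-period discriminator returns per-discriminator logits for the real and generated waveforms plus their intermediate feature maps; the logits feed the adversarial losses and the feature maps feed a feature-matching loss. A hedged usage sketch (the loss helpers named in the comments are the usual HiFi-GAN-style ones, assumed to be defined elsewhere):

import torch

mpd = MultiPeriodDiscriminator()
y = torch.randn(1, 1, 8192)      # real waveform slice
y_hat = torch.randn(1, 1, 8192)  # generated waveform slice
y_d_rs, y_d_gs, fmap_rs, fmap_gs = mpd(y, y_hat.detach())
# discriminator step (hypothetical helper): loss_d, _, _ = discriminator_loss(y_d_rs, y_d_gs)
# generator step, on a fresh non-detached forward:
#   loss_fm = feature_loss(fmap_rs, fmap_gs)
#   loss_g, _ = generator_loss(y_d_gs)
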
class SynthesizerTrn(nn.Module):
    """
    Synthesizer for Training
    """

    def __init__(self,
                 n_vocab,
                 spec_channels,
                 segment_size,
                 inter_channels,
                 hidden_channels,
                 filter_channels,
                 n_heads,
                 n_layers,
                 kernel_size,
                 p_dropout,
                 resblock,
                 resblock_kernel_sizes,
                 resblock_dilation_sizes,
                 upsample_rates,
                 upsample_initial_channel,
                 upsample_kernel_sizes,
                 n_speakers=0,
                 gin_channels=0,
                 use_sdp=True,
                 **kwargs):

        super().__init__()
        self.n_vocab = n_vocab
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.n_speakers = n_speakers
        self.gin_channels = gin_channels

        self.use_sdp = use_sdp

        self.enc_p = TextEncoder(n_vocab,
                                 inter_channels,
                                 hidden_channels,
                                 filter_channels,
                                 n_heads,
                                 n_layers,
                                 kernel_size,
                                 p_dropout)
        self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes,
                             resblock_dilation_sizes, upsample_rates,
                             upsample_initial_channel, upsample_kernel_sizes,
                             gin_channels=gin_channels)
        self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels,
                                      5, 1, 16, gin_channels=gin_channels)
        self.flow = ResidualCouplingBlock(inter_channels, hidden_channels,
                                          5, 1, 4, gin_channels=gin_channels)

        if use_sdp:
            self.dp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4,
                                                  gin_channels=gin_channels)
        else:
            self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5,
                                        gin_channels=gin_channels)

        if n_speakers > 1:
            self.emb_g = nn.Embedding(n_speakers, gin_channels)

    def forward(self, x, x_lengths, y, y_lengths, sid=None):

        x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
        if self.n_speakers > 0:
            g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
        else:
            g = None

        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
        z_p = self.flow(z, y_mask, g=g)

        with torch.no_grad():
            # negative cross-entropy
            s_p_sq_r = torch.exp(-2 * logs_p)  # [b, d, t]
            neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True)  # [b, 1, t_s]
            neg_cent2 = torch.matmul(-0.5 * (z_p ** 2).transpose(1, 2), s_p_sq_r)  # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
            neg_cent3 = torch.matmul(z_p.transpose(1, 2), (m_p * s_p_sq_r))  # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
            neg_cent4 = torch.sum(-0.5 * (m_p ** 2) * s_p_sq_r, [1], keepdim=True)  # [b, 1, t_s]
            neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4

            attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
            attn = monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach()

        w = attn.sum(2)
        if self.use_sdp:
            l_length = self.dp(x, x_mask, w, g=g)
            l_length = l_length / torch.sum(x_mask)
        else:
            logw_ = torch.log(w + 1e-6) * x_mask
            logw = self.dp(x, x_mask, g=g)
            l_length = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(x_mask)  # for averaging

        # expand prior
        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)

        z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)
        o = self.dec(z_slice, g=g)
        return o, l_length, attn, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)

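Inside the torch.no_grad() block, neg_cent1 through neg_cent4 assemble the diagonal-Gaussian log-density of z_p under the text prior, expanded term by term so that each piece reduces to a cheap sum or matmul over the latent dimension d. Written out (a standard expansion, not something specific to this commit):

\log \mathcal{N}(z_p;\, m_p,\, e^{2\,\mathrm{logs}_p})
  = \sum_d \Big[ -\tfrac{1}{2}\log(2\pi) - \mathrm{logs}_p
                 - \tfrac{1}{2} z_p^2\, e^{-2\,\mathrm{logs}_p}
                 + z_p\, m_p\, e^{-2\,\mathrm{logs}_p}
                 - \tfrac{1}{2} m_p^2\, e^{-2\,\mathrm{logs}_p} \Big]

The four bracketed terms correspond to neg_cent1, neg_cent2, neg_cent3, and neg_cent4; monotonic_align.maximum_path then searches for the monotonic alignment that maximizes this score (the MAS step of VITS).
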
    def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None):
        x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
        if self.n_speakers > 0:
            g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
        else:
            g = None

        if self.use_sdp:
            logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w)
        else:
            logw = self.dp(x, x_mask, g=g)
        w = torch.exp(logw) * x_mask * length_scale
        w_ceil = torch.ceil(w)
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype)
        attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
        attn = commons.generate_path(w_ceil, attn_mask)

        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)  # [b, t', t], [b, t, d] -> [b, d, t']
        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)  # [b, t', t], [b, t, d] -> [b, d, t']

        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
        z = self.flow(z_p, y_mask, g=g, reverse=True)
        o = self.dec((z * y_mask)[:, :, :max_len], g=g)
        return o, attn, y_mask, (z, z_p, m_p, logs_p)

    def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
        assert self.n_speakers > 0, "n_speakers must be larger than 0."
        g_src = self.emb_g(sid_src).unsqueeze(-1)
        g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src)
        z_p = self.flow(z, y_mask, g=g_src)
        z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
        o_hat = self.dec(z_hat * y_mask, g=g_tgt)
        return o_hat, y_mask, (z, z_p, z_hat)

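voice_conversion exploits the invertibility of the flow: encode with the source speaker embedding, map forward into the speaker-independent prior space, then invert the flow under the target embedding and vocode. A hedged usage sketch (net_g, spec, and spec_lens are hypothetical placeholders for a trained multi-speaker model and a linear-spectrogram batch):

import torch

# speaker indices known to the embedding table (illustrative values)
sid_src = torch.LongTensor([0])
sid_tgt = torch.LongTensor([3])
# audio, _, _ = net_g.voice_conversion(spec, spec_lens, sid_src, sid_tgt)
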
@@ -1,4 +1,4 @@
-# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/posterior_encoder.py
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/posterior_encoder.py
 
 # Copyright 2021 Tomoki Hayashi
 # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
@@ -1,4 +1,4 @@
-# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/residual_coupling.py
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/residual_coupling.py
 
 # Copyright 2021 Tomoki Hayashi
 # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
@@ -1,17 +0,0 @@
-# https://github.com/jaywalnut310/vits/blob/main/text/symbols.py
-""" from https://github.com/keithito/tacotron """
-
-'''
-Defines the set of symbols used in text input to the model.
-'''
-_pad = '_'
-_punctuation = ';:,.!?¡¿—…"«»“” '
-_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
-
-
-# Export all symbols:
-symbol_table = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
-
-# Special symbol ids
-SPACE_ID = symbol_table.index(" ")
@@ -20,6 +20,7 @@
 This code is based on
   - https://github.com/jaywalnut310/vits
   - https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/text_encoder.py
+  - https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/transducer_stateless/conformer.py
 """
 
 import copy
@@ -67,6 +68,7 @@ class TextEncoder(torch.nn.Module):
         self.emb = torch.nn.Embedding(vocabs, d_model)
         torch.nn.init.normal_(self.emb.weight, 0.0, d_model**-0.5)
 
+        # We use conformer as text encoder
         self.encoder = Transformer(
             d_model=d_model,
             num_heads=num_heads,
@@ -18,7 +18,6 @@ from typing import Dict, List
 
 import g2p_en
 import tacotron_cleaner.cleaners
 
 from utils import intersperse
-
 
@@ -1,9 +1,5 @@
 #!/usr/bin/env python3
-# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
-#                                            Wei Kang,
-#                                            Mingshuang Luo,
-#                                            Zengwei Yao,
-#                                            Daniel Povey)
+# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -22,9 +18,10 @@
 
 import argparse
 import logging
+import numpy as np
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional, Tuple, Union
 
 import k2
 import torch
@@ -36,26 +33,17 @@ from torch.optim import Optimizer
 from torch.cuda.amp import GradScaler, autocast
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
-from tts_datamodule import LJSpeechTtsDataModule
 
 from icefall import diagnostics
-from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import load_checkpoint
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.hooks import register_inf_check_hooks
-from icefall.utils import (
-    AttributeDict,
-    setup_logger,
-    str2bool,
-)
+from icefall.utils import AttributeDict, setup_logger, str2bool
 
 from tokenizer import Tokenizer
-from utils import (
-    MetricsTracker,
-    plot_feature,
-    save_checkpoint,
-    save_checkpoint_with_global_batch_idx,
-)
+from tts_datamodule import LJSpeechTtsDataModule
+from utils import MetricsTracker, plot_feature, save_checkpoint
 from vits import VITS
 
 LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
@@ -90,7 +78,7 @@ def get_parser():
     parser.add_argument(
         "--num-epochs",
         type=int,
-        default=30,
+        default=1000,
         help="Number of epochs to train.",
     )
 
@@ -104,15 +92,6 @@ def get_parser():
         """,
     )
 
-    parser.add_argument(
-        "--start-batch",
-        type=int,
-        default=0,
-        help="""If positive, --start-epoch is ignored and
-        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
-        """,
-    )
-
     parser.add_argument(
         "--exp-dir",
         type=str,
@@ -127,7 +106,7 @@ def get_parser():
         "--tokens",
         type=str,
         default="data/tokens.txt",
-        help="""Path to tokens.txt.""",
+        help="""Path to vocabulary.""",
     )
 
     parser.add_argument(
@@ -158,24 +137,11 @@ def get_parser():
     parser.add_argument(
         "--save-every-n",
         type=int,
-        default=4000,
-        help="""Save checkpoint after processing this number of batches"
+        default=20,
+        help="""Save checkpoint after processing this number of epochs"
         periodically. We save checkpoint to exp-dir/ whenever
-        params.batch_idx_train % save_every_n == 0. The checkpoint filename
-        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
-        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
-        end of each epoch where `xxx` is the epoch number counting from 1.
-        """,
-    )
-
-    parser.add_argument(
-        "--keep-last-k",
-        type=int,
-        default=30,
-        help="""Only keep this number of checkpoints on disk.
-        For instance, if it is 3, there are only 3 checkpoints
-        in the exp-dir with filenames `checkpoint-xxx.pt`.
-        It does not affect checkpoints with name `epoch-xxx.pt`.
+        params.cur_epoch % save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'
         """,
     )
 
@@ -218,8 +184,6 @@ def get_params() -> AttributeDict:
 
     - log_interval:  Print training loss if batch_idx % log_interval` is 0
 
-    - reset_interval: Reset statistics if batch_idx % reset_interval is 0
-
     - valid_interval:  Run validation if batch_idx % valid_interval is 0
 
     - feature_dim: The model input dim. It has to match the one used
@@ -242,18 +206,14 @@ def get_params() -> AttributeDict:
         "best_train_epoch": -1,
         "best_valid_epoch": -1,
         "batch_idx_train": -1,  # 0
-        "log_interval": 10,
-        "draw_interval": 500,
-        # "reset_interval": 200,
+        "log_interval": 50,
         "valid_interval": 200,
         "env_info": get_env_info(),
         "sampling_rate": 22050,
         "frame_shift": 256,
         "frame_length": 1024,
         "feature_dim": 513,  # 1024 // 2 + 1, 1024 is fft_length
-        "mel_loss_params": {
-            "n_mels": 80,
-        },
+        "n_mels": 80,
         "lambda_adv": 1.0,  # loss scaling coefficient for adversarial loss
         "lambda_mel": 45.0,  # loss scaling coefficient for Mel loss
         "lambda_feat_match": 2.0,  # loss scaling coefficient for feat match loss
@@ -270,9 +230,7 @@ def load_checkpoint_if_available(
 ) -> Optional[Dict[str, Any]]:
     """Load checkpoint from file.
 
-    If params.start_batch is positive, it will load the checkpoint from
-    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
-    params.start_epoch is larger than 1, it will load the checkpoint from
+    If params.start_epoch is larger than 1, it will load the checkpoint from
     `params.start_epoch - 1`.
 
     Apart from loading state dict for `model` and `optimizer` it also updates
@@ -287,9 +245,7 @@ def load_checkpoint_if_available(
     Returns:
       Return a dict containing previously saved training info.
     """
-    if params.start_batch > 0:
-        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
-    elif params.start_epoch > 1:
+    if params.start_epoch > 1:
         filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
     else:
         return None
@@ -308,19 +264,15 @@ def load_checkpoint_if_available(
     for k in keys:
         params[k] = saved_params[k]
 
-    if params.start_batch > 0:
-        if "cur_epoch" in saved_params:
-            params["start_epoch"] = saved_params["cur_epoch"]
-
     return saved_params
 
 
 def get_model(params: AttributeDict) -> nn.Module:
-    mel_loss_params = params.mel_loss_params
-    mel_loss_params.update(
-        frame_length=params.frame_length,
-        frame_shift=params.frame_shift,
-    )
+    mel_loss_params = {
+        "n_mels": params.n_mels,
+        "frame_length": params.frame_length,
+        "frame_shift": params.frame_shift,
+    }
     model = VITS(
         vocab_size=params.vocab_size,
         feature_dim=params.feature_dim,
@@ -381,18 +333,22 @@ def train_one_epoch(
         It is returned by :func:`get_params`.
       model:
         The model for training.
-      optimizer:
-        The optimizer we are using.
-      scheduler:
-        The learning rate scheduler, we call step() every step.
+      tokenizer:
+        Used to convert text to phonemes.
+      optimizer_g:
+        The optimizer for generator.
+      optimizer_d:
+        The optimizer for discriminator.
+      scheduler_g:
+        The learning rate scheduler for generator, we call step() every epoch.
+      scheduler_d:
+        The learning rate scheduler for discriminator, we call step() every epoch.
       train_dl:
         Dataloader for the training dataset.
       valid_dl:
         Dataloader for the validation dataset.
       scaler:
        The scaler used for mix precision training.
-      model_avg:
-        The stored model averaged from the start of training.
      tb_writer:
        Writer to write log messages to tensorboard.
      world_size:
@@ -404,7 +360,7 @@ def train_one_epoch(
     model.train()
     device = model.device if isinstance(model, DDP) else next(model.parameters()).device
 
-    # used to summary the stats over iterations
+    # used to summary the stats over iterations in one epoch
     tot_loss = MetricsTracker()
 
     saved_bad_model = False
@@ -433,7 +389,6 @@ def train_one_epoch(
             loss_info = MetricsTracker()
             loss_info['samples'] = batch_size
 
-            return_sample = params.batch_idx_train % params.log_interval == 0
             try:
                 with autocast(enabled=params.use_fp16):
                     # forward discriminator
@@ -463,13 +418,11 @@ def train_one_epoch(
                         speech=audio,
                         speech_lengths=audio_lens,
                         forward_generator=True,
-                        return_sample=return_sample,
+                        return_sample=params.batch_idx_train % params.log_interval == 0,
                     )
                 for k, v in stats_g.items():
-                    if "return_sample" not in k:
+                    if "returned_sample" not in k:
                         loss_info[k] = v * batch_size
-                if return_sample:
-                    speech_hat_, speech_, mel_hat_, mel_ = stats_g["return_sample"]
                 # update generator
                 optimizer_g.zero_grad()
                 scaler.scale(loss_g).backward()
@@ -477,7 +430,6 @@ def train_one_epoch(
                 scaler.update()
 
                 # summary stats
-                # tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
                 tot_loss = tot_loss + loss_info
             except:  # noqa
                 save_bad_model()
@@ -486,37 +438,12 @@ def train_one_epoch(
             if params.print_diagnostics and batch_idx == 5:
                 return
 
-            if (
-                params.batch_idx_train > 0
-                and params.batch_idx_train % params.save_every_n == 0
-            ):
-                save_checkpoint_with_global_batch_idx(
-                    out_dir=params.exp_dir,
-                    global_batch_idx=params.batch_idx_train,
-                    model=model,
-                    params=params,
-                    optimizer_g=optimizer_g,
-                    optimizer_d=optimizer_d,
-                    scheduler_g=scheduler_g,
-                    scheduler_d=scheduler_d,
-                    sampler=train_dl.sampler,
-                    scaler=scaler,
-                    rank=rank,
-                )
-                remove_checkpoints(
-                    out_dir=params.exp_dir,
-                    topk=params.keep_last_k,
-                    rank=rank,
-                )
-
-            # if batch_idx % 100 == 0 and params.use_fp16:
             if params.batch_idx_train % 100 == 0 and params.use_fp16:
                 # If the grad scale was less than 1, try increasing it. The _growth_interval
                 # of the grad scaler is configurable, but we can't configure it to have different
                 # behavior depending on the current grad scale.
                 cur_grad_scale = scaler._scale.item()
 
-                # if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
                 if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0):
                     scaler.update(cur_grad_scale * 2.0)
                 if cur_grad_scale < 0.01:
@@ -530,7 +457,6 @@ def train_one_epoch(
                         f"grad_scale is too small, exiting: {cur_grad_scale}"
                     )
 
-            # if batch_idx % params.log_interval == 0:
             if params.batch_idx_train % params.log_interval == 0:
                 cur_lr_g = max(scheduler_g.get_last_lr())
                 cur_lr_d = max(scheduler_d.get_last_lr())
@@ -561,7 +487,8 @@ def train_one_epoch(
                 tb_writer.add_scalar(
                     "train/grad_scale", cur_grad_scale, params.batch_idx_train
                 )
-                if return_sample:
+                if "returned_sample" in stats_g:
+                    speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"]
                     tb_writer.add_audio(
                         "train/speech_hat_", speech_hat_, params.batch_idx_train, params.sampling_rate
                     )
@@ -575,7 +502,6 @@ def train_one_epoch(
                         "train/mel_", plot_feature(mel_), params.batch_idx_train, dataformats='HWC'
                     )
 
-            # if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
             if params.batch_idx_train % params.valid_interval == 0 and not params.print_diagnostics:
                 logging.info("Computing validation loss")
                 valid_info, (speech_hat, speech) = compute_validation_loss(
@@ -615,14 +541,14 @@ def compute_validation_loss(
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
     rank: int = 0,
-) -> MetricsTracker:
+) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]:
     """Run the validation process."""
     model.eval()
     device = model.device if isinstance(model, DDP) else next(model.parameters()).device
 
     # used to summary the stats over iterations
     tot_loss = MetricsTracker()
-    return_sample = None
+    returned_sample = None
 
     with torch.no_grad():
         for batch_idx, batch in enumerate(valid_dl):
@@ -667,12 +593,14 @@ def compute_validation_loss(
             # infer for first batch:
             if batch_idx == 0 and rank == 0:
                 inner_model = model.module if isinstance(model, DDP) else model
-                audio_pred, _, duration = inner_model.inference(text=tokens[0, :tokens_lens[0].item()])
+                audio_pred, _, duration = inner_model.inference(
+                    text=tokens[0, :tokens_lens[0].item()]
+                )
                 audio_pred = audio_pred.data.cpu().numpy()
                 audio_len_pred = (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item()
                 assert audio_len_pred == len(audio_pred), (audio_len_pred, len(audio_pred))
                 audio_gt = audio[0, :audio_lens[0].item()].data.cpu().numpy()
-                return_sample = (audio_pred, audio_gt)
+                returned_sample = (audio_pred, audio_gt)
 
     if world_size > 1:
         tot_loss.reduce(device)
@@ -682,7 +610,7 @@ def compute_validation_loss(
         params.best_valid_epoch = params.cur_epoch
         params.best_valid_loss = loss_value
 
-    return tot_loss, return_sample
+    return tot_loss, returned_sample
 
 
 def scan_pessimistic_batches_for_oom(
@@ -805,18 +733,10 @@ def run(rank, world_size, args):
         model = DDP(model, device_ids=[rank], find_unused_parameters=True)
 
     optimizer_g = torch.optim.AdamW(
-        generator.parameters(),
-        lr=params.lr,
-        betas=(0.8, 0.99),
-        eps=1e-9,
-        # weight_decay=0,
+        generator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9
     )
     optimizer_d = torch.optim.AdamW(
-        discriminator.parameters(),
-        lr=params.lr,
-        betas=(0.8, 0.99),
-        eps=1e-9,
-        # weight_decay=0,
+        discriminator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9
    )
 
     scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875)
@@ -852,16 +772,8 @@ def run(rank, world_size, args):
 
     train_cuts = ljspeech.train_cuts()
-
-    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
-        # We only load the sampler's state dict when it loads a checkpoint
-        # saved in the middle of an epoch
-        sampler_state_dict = checkpoints["sampler"]
-    else:
-        sampler_state_dict = None
-
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
 
         # You should use ../local/display_manifest_statistics.py to get
         # an utterance duration distribution for your dataset to select
         # the threshold
@@ -870,13 +782,10 @@ def run(rank, world_size, args):
             #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
             # )
             return False
 
         return True
 
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
-    train_dl = ljspeech.train_dataloaders(
-        train_cuts, sampler_state_dict=sampler_state_dict
-    )
+    train_dl = ljspeech.train_dataloaders(train_cuts)
 
     valid_cuts = ljspeech.valid_cuts()
     valid_dl = ljspeech.valid_dataloaders(valid_cuts)
@@ -902,11 +811,11 @@ def run(rank, world_size, args):
         fix_random_seed(params.seed + epoch - 1)
         train_dl.sampler.set_epoch(epoch - 1)
 
+        params.cur_epoch = epoch
+
         if tb_writer is not None:
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
-        params.cur_epoch = epoch
-
         train_one_epoch(
             params=params,
             model=model,
@@ -927,6 +836,7 @@ def run(rank, world_size, args):
             diagnostic.print_diagnostics()
             break
 
+        if epoch % params.save_every_n == 0:
             filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
             save_checkpoint(
                 filename=filename,
@@ -1,4 +1,5 @@
-# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/transform.py
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/transform.py
+
 """Flow-related transformation.
 
 This code is derived from https://github.com/bayesiains/nflows.
@@ -1,32 +1,35 @@
-# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/utils/get_random_segments.py
-
-# Copyright 2021 Tomoki Hayashi
-# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
-
-"""Function to get random segments."""
-
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 from typing import Any, Dict, List, Optional, Tuple, Union
 import collections
 import logging
-import re
-import warnings
 
-import numpy as np
 import torch
 import torch.nn as nn
 import torch.distributed as dist
 from lhotse.dataset.sampling.base import CutSampler
 from pathlib import Path
-from phonemizer import phonemize
-from symbols import symbol_table
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.nn.utils.rnn import pad_sequence
 from torch.optim import Optimizer
 from torch.utils.tensorboard import SummaryWriter
-from unidecode import unidecode
 
 
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/utils/get_random_segments.py
 def get_random_segments(
     x: torch.Tensor,
     x_lengths: torch.Tensor,
@@ -55,6 +58,7 @@ def get_random_segments(
     return segments, start_idxs
 
 
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/utils/get_random_segments.py
 def get_segments(
     x: torch.Tensor,
     start_idxs: torch.Tensor,
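get_random_segments draws one window start per utterance (respecting each sequence length) and get_segments gathers the fixed-size windows; this is what lets the GAN losses run on short waveform slices instead of full utterances. A minimal illustration of the indexing contract (a sketch, not the library code itself):

import torch

x = torch.arange(20, dtype=torch.float32).view(2, 1, 10)  # (B, C, T)
segment_size = 4
start_idxs = torch.tensor([0, 5])  # one start index per batch item
segments = torch.stack(
    [x[i, :, s:s + segment_size] for i, s in enumerate(start_idxs)]
)
print(segments.shape)  # torch.Size([2, 1, 4])
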
@@ -78,195 +82,41 @@ def get_segments(
     return segments
 
 
-# https://github.com/espnet/espnet/blob/master/espnet2/torch_utils/device_funcs.py
-def force_gatherable(data, device):
-    """Change object to gatherable in torch.nn.DataParallel recursively
-
-    The difference from to_device() is changing to torch.Tensor if float or int
-    value is found.
-
-    The restriction to the returned value in DataParallel:
-        The object must be
-        - torch.cuda.Tensor
-        - 1 or more dimension. 0-dimension-tensor sends warning.
-        or a list, tuple, dict.
-
-    """
-    if isinstance(data, dict):
-        return {k: force_gatherable(v, device) for k, v in data.items()}
-    # DataParallel can't handle NamedTuple well
-    elif isinstance(data, tuple) and type(data) is not tuple:
-        return type(data)(*[force_gatherable(o, device) for o in data])
-    elif isinstance(data, (list, tuple, set)):
-        return type(data)(force_gatherable(v, device) for v in data)
-    elif isinstance(data, np.ndarray):
-        return force_gatherable(torch.from_numpy(data), device)
-    elif isinstance(data, torch.Tensor):
-        if data.dim() == 0:
-            # To 1-dim array
-            data = data[None]
-        return data.to(device)
-    elif isinstance(data, float):
-        return torch.tensor([data], dtype=torch.float, device=device)
-    elif isinstance(data, int):
-        return torch.tensor([data], dtype=torch.long, device=device)
-    elif data is None:
-        return None
-    else:
-        warnings.warn(f"{type(data)} may not be gatherable by DataParallel")
-        return data
-
-
-# The following codes are based on https://github.com/jaywalnut310/vits
-
-# Regular expression matching whitespace:
-_whitespace_re = re.compile(r'\s+')
-
-# List of (regular expression, replacement) pairs for abbreviations:
-_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
-    ('mrs', 'misess'),
-    ('mr', 'mister'),
-    ('dr', 'doctor'),
-    ('st', 'saint'),
-    ('co', 'company'),
-    ('jr', 'junior'),
-    ('maj', 'major'),
-    ('gen', 'general'),
-    ('drs', 'doctors'),
-    ('rev', 'reverend'),
-    ('lt', 'lieutenant'),
-    ('hon', 'honorable'),
-    ('sgt', 'sergeant'),
-    ('capt', 'captain'),
-    ('esq', 'esquire'),
-    ('ltd', 'limited'),
-    ('col', 'colonel'),
-    ('ft', 'fort'),
-]]
-
-
-def expand_abbreviations(text):
-    for regex, replacement in _abbreviations:
-        text = re.sub(regex, replacement, text)
-    return text
-
-
-def lowercase(text):
-    return text.lower()
-
-
-def collapse_whitespace(text):
-    return re.sub(_whitespace_re, ' ', text)
-
-
-def convert_to_ascii(text):
-    return unidecode(text)
-
-
-def text_clean(text):
-    '''Pipeline for English text, including abbreviation expansion. + punctuation + stress.
-
-    Returns:
-      A string of phonemes.
-    '''
-    text = convert_to_ascii(text)
-    text = lowercase(text)
-    text = expand_abbreviations(text)
-    phonemes = phonemize(
-        text,
-        language='en-us',
-        backend='espeak',
-        strip=True,
-        preserve_punctuation=True,
-        with_stress=True,
-    )
-    phonemes = collapse_whitespace(phonemes)
-    return phonemes
-
-
-# Mappings from symbol to numeric ID and vice versa:
-symbol_to_id = {s: i for i, s in enumerate(symbol_table)}
-id_to_symbol = {i: s for i, s in enumerate(symbol_table)}
-
-
-# def text_to_sequence(text: str) -> List[int]:
-#     '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
-#     '''
-#     cleaned_text = text_clean(text)
-#     sequence = [symbol_to_id[symbol] for symbol in cleaned_text]
-#     return sequence
-#
-#
-# def sequence_to_text(sequence: List[int]) -> str:
-#     '''Converts a sequence of IDs back to a string'''
-#     result = ''.join(id_to_symbol[symbol_id] for symbol_id in sequence)
-#     return result
+# from https://github.com/jaywalnut310/vits/blob/main/commons.py
 
 
 def intersperse(sequence, item=0):
     result = [item] * (len(sequence) * 2 + 1)
     result[1::2] = sequence
     return result
 
 
-def prepare_token_batch(
-    texts: List[str],
-    phonemes: Optional[List[str]] = None,
-    intersperse_blank: bool = True,
-    blank_id: int = 0,
-    pad_id: int = 0,
-) -> torch.Tensor:
-    """Convert a list of text strings into a batch of symbol tokens with padding.
-    Args:
-      texts: list of text strings
-      intersperse_blank: whether to intersperse blank tokens in the converted token sequence.
-      blank_id: index of blank token
-      pad_id: padding index
-    """
-    if phonemes is None:
-        # normalize text
-        normalized_texts = []
-        for text in texts:
-            text = convert_to_ascii(text)
-            text = lowercase(text)
-            text = expand_abbreviations(text)
-            normalized_texts.append(text)
-
-        # convert to phonemes
-        phonemes = phonemize(
-            normalized_texts,
-            language='en-us',
-            backend='espeak',
-            strip=True,
-            preserve_punctuation=True,
-            with_stress=True,
-        )
-        phonemes = [collapse_whitespace(sequence) for sequence in phonemes]
-
-    # convert to symbol ids
-    lengths = []
-    sequences = []
-    skip = False
-    for idx, sequence in enumerate(phonemes):
-        try:
-            sequence = [symbol_to_id[symbol] for symbol in sequence]
-        except Exception:
-            # print(texts[idx])
-            # print(normalized_texts[idx])
-            print(phonemes[idx])
-            skip = True
-        if intersperse_blank:
-            sequence = intersperse(sequence, blank_id)
-        try:
-            sequences.append(torch.tensor(sequence, dtype=torch.int64))
-        except Exception:
-            print(sequence)
-            skip = True
-        lengths.append(len(sequence))
-
-    sequences = pad_sequence(sequences, batch_first=True, padding_value=pad_id)
-    lengths = torch.tensor(lengths, dtype=torch.int64)
-    return sequences, lengths, skip
+# from https://github.com/jaywalnut310/vits/blob/main/utils.py
+MATPLOTLIB_FLAG = False
+
+
+def plot_feature(spectrogram):
+    global MATPLOTLIB_FLAG
+    if not MATPLOTLIB_FLAG:
+        import matplotlib
+        matplotlib.use("Agg")
+        MATPLOTLIB_FLAG = True
+        mpl_logger = logging.getLogger('matplotlib')
+        mpl_logger.setLevel(logging.WARNING)
+    import matplotlib.pylab as plt
+    import numpy as np
+
+    fig, ax = plt.subplots(figsize=(10, 2))
+    im = ax.imshow(spectrogram, aspect="auto", origin="lower",
+                   interpolation='none')
+    plt.colorbar(im, ax=ax)
+    plt.xlabel("Frames")
+    plt.ylabel("Channels")
+    plt.tight_layout()
+
+    fig.canvas.draw()
+    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
+    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+    plt.close()
+    return data
 
 
 class MetricsTracker(collections.defaultdict):
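One portability note on the plot_feature added above: np.fromstring is deprecated in favour of np.frombuffer for binary data, so on newer NumPy versions the canvas readback can be written as below (a sketch only; same output shape):

import numpy as np

# inside plot_feature, after fig.canvas.draw():
# data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
# data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
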
@@ -413,106 +263,3 @@ def save_checkpoint(
         checkpoint[k] = v
 
     torch.save(checkpoint, filename)
-
-
-def save_checkpoint_with_global_batch_idx(
-    out_dir: Path,
-    global_batch_idx: int,
-    model: Union[nn.Module, DDP],
-    params: Optional[Dict[str, Any]] = None,
-    optimizer_g: Optional[Optimizer] = None,
-    optimizer_d: Optional[Optimizer] = None,
-    scheduler_g: Optional[LRSchedulerType] = None,
-    scheduler_d: Optional[LRSchedulerType] = None,
-    scaler: Optional[GradScaler] = None,
-    sampler: Optional[CutSampler] = None,
-    rank: int = 0,
-):
-    """Save training info after processing given number of batches.
-
-    Args:
-      out_dir:
-        The directory to save the checkpoint.
-      global_batch_idx:
-        The number of batches processed so far from the very start of the
-        training. The saved checkpoint will have the following filename:
-            f'out_dir / checkpoint-{global_batch_idx}.pt'
-      model:
-        The neural network model whose `state_dict` will be saved in the
-        checkpoint.
-      params:
-        A dict of training configurations to be saved.
-      optimizer_g:
-        The optimizer for generator used in the training.
-        Its `state_dict` will be saved.
-      optimizer_d:
-        The optimizer for discriminator used in the training.
-        Its `state_dict` will be saved.
-      scheduler_g:
-        The learning rate scheduler for generator used in the training.
-        Its `state_dict` will be saved.
-      scheduler_d:
-        The learning rate scheduler for discriminator used in the training.
-        Its `state_dict` will be saved.
-      scaler:
-        The scaler used for mix precision training. Its `state_dict` will
-        be saved.
-      sampler:
-        The sampler used in the training dataset.
-      rank:
-        The rank ID used in DDP training of the current node. Set it to 0
-        if DDP is not used.
-    """
-    out_dir = Path(out_dir)
-    out_dir.mkdir(parents=True, exist_ok=True)
-    filename = out_dir / f"checkpoint-{global_batch_idx}.pt"
-    save_checkpoint(
-        filename=filename,
-        model=model,
-        params=params,
-        optimizer_g=optimizer_g,
-        optimizer_d=optimizer_d,
-        scheduler_g=scheduler_g,
-        scheduler_d=scheduler_d,
-        scaler=scaler,
-        sampler=sampler,
-        rank=rank,
-    )
-
-
-# def plot_feature(feature):
-#     """
-#     Display the feature matrix as an image. Requires matplotlib to be installed.
-#     """
-#     import matplotlib.pyplot as plt
-#
-#     feature = np.flip(feature.transpose(1, 0), 0)
-#     return plt.matshow(feature)
-
-MATPLOTLIB_FLAG = False
-
-
-def plot_feature(spectrogram):
-    global MATPLOTLIB_FLAG
-    if not MATPLOTLIB_FLAG:
-        import matplotlib
-        matplotlib.use("Agg")
-        MATPLOTLIB_FLAG = True
-        mpl_logger = logging.getLogger('matplotlib')
-        mpl_logger.setLevel(logging.WARNING)
-    import matplotlib.pylab as plt
-    import numpy as np
-
-    fig, ax = plt.subplots(figsize=(10, 2))
-    im = ax.imshow(spectrogram, aspect="auto", origin="lower",
-                   interpolation='none')
-    plt.colorbar(im, ax=ax)
-    plt.xlabel("Frames")
-    plt.ylabel("Channels")
-    plt.tight_layout()
-
-    fig.canvas.draw()
-    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
-    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
-    plt.close()
-    return data
@@ -1,11 +1,11 @@
-# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/vits.py
+# based on https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/vits.py
 
 # Copyright 2021 Tomoki Hayashi
 # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
 
 """VITS module for GAN-TTS task."""
 
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -247,7 +247,7 @@ class VITS(nn.Module):
         spembs: Optional[torch.Tensor] = None,
         lids: Optional[torch.Tensor] = None,
         forward_generator: bool = True,
-    ) -> Dict[str, Any]:
+    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
         """Perform generator forward.
 
         Args:
@@ -263,12 +263,8 @@ class VITS(nn.Module):
             forward_generator (bool): Whether to forward generator.
 
         Returns:
-            Dict[str, Any]:
-                - loss (Tensor): Loss scalar tensor.
-                - stats (Dict[str, float]): Statistics to be monitored.
-                - weight (Tensor): Weight tensor to summarize losses.
-                - optim_idx (int): Optimizer index (0 for G and 1 for D).
-
+            - loss (Tensor): Loss scalar tensor.
+            - stats (Dict[str, float]): Statistics to be monitored.
         """
         if forward_generator:
             return self._forward_generator(
@@ -308,7 +304,7 @@ class VITS(nn.Module):
         sids: Optional[torch.Tensor] = None,
         spembs: Optional[torch.Tensor] = None,
         lids: Optional[torch.Tensor] = None,
-    ) -> Dict[str, Any]:
+    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
         """Perform generator forward.
 
         Args:
@@ -323,12 +319,8 @@ class VITS(nn.Module):
             lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
 
         Returns:
-            Dict[str, Any]:
-                * loss (Tensor): Loss scalar tensor.
-                * stats (Dict[str, float]): Statistics to be monitored.
-                * weight (Tensor): Weight tensor to summarize losses.
-                * optim_idx (int): Optimizer index (0 for G and 1 for D).
-
+            * loss (Tensor): Loss scalar tensor.
+            * stats (Dict[str, float]): Statistics to be monitored.
         """
         # setup
         feats = feats.transpose(1, 2)
@@ -399,7 +391,7 @@ class VITS(nn.Module):
         )
 
         if return_sample:
-            stats["return_sample"] = (
+            stats["returned_sample"] = (
                 speech_hat_[0].data.cpu().numpy(),
                 speech_[0].data.cpu().numpy(),
                 mel_hat_[0].data.cpu().numpy(),
@@ -423,7 +415,7 @@ class VITS(nn.Module):
         sids: Optional[torch.Tensor] = None,
         spembs: Optional[torch.Tensor] = None,
         lids: Optional[torch.Tensor] = None,
-    ) -> Dict[str, Any]:
+    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
         """Perform discriminator forward.
 
         Args:
@@ -438,12 +430,8 @@ class VITS(nn.Module):
             lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
 
         Returns:
-            Dict[str, Any]:
-                * loss (Tensor): Loss scalar tensor.
-                * stats (Dict[str, float]): Statistics to be monitored.
-                * weight (Tensor): Weight tensor to summarize losses.
-                * optim_idx (int): Optimizer index (0 for G and 1 for D).
-
+            * loss (Tensor): Loss scalar tensor.
+            * stats (Dict[str, float]): Statistics to be monitored.
         """
         # setup
         feats = feats.transpose(1, 2)
@@ -511,8 +499,8 @@ class VITS(nn.Module):
         alpha: float = 1.0,
         max_len: Optional[int] = None,
         use_teacher_forcing: bool = False,
-    ) -> Dict[str, torch.Tensor]:
-        """Run inference.
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Run inference for single sample.
 
         Args:
             text (Tensor): Input text index tensor (T_text,).
@@ -528,11 +516,9 @@ class VITS(nn.Module):
             use_teacher_forcing (bool): Whether to use teacher forcing.
 
         Returns:
-            Dict[str, Tensor]:
-                * wav (Tensor): Generated waveform tensor (T_wav,).
-                * att_w (Tensor): Monotonic attention weight tensor (T_feats, T_text).
-                * duration (Tensor): Predicted duration tensor (T_text,).
-
+            * wav (Tensor): Generated waveform tensor (T_wav,).
+            * att_w (Tensor): Monotonic attention weight tensor (T_feats, T_text).
+            * duration (Tensor): Predicted duration tensor (T_text,).
         """
         # setup
         text = text[None]
@@ -593,8 +579,8 @@ class VITS(nn.Module):
         alpha: float = 1.0,
         max_len: Optional[int] = None,
         use_teacher_forcing: bool = False,
-    ) -> Dict[str, torch.Tensor]:
-        """Run inference.
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Run inference for one batch.
 
         Args:
             text (Tensor): Input text index tensor (B, T_text).
@@ -605,11 +591,9 @@ class VITS(nn.Module):
             max_len (Optional[int]): Maximum length.
 
         Returns:
-            Dict[str, Tensor]:
-                * wav (Tensor): Generated waveform tensor (B, T_wav).
-                * att_w (Tensor): Monotonic attention weight tensor (B, T_feats, T_text).
-                * duration (Tensor): Predicted duration tensor (B, T_text).
-
+            * wav (Tensor): Generated waveform tensor (B, T_wav).
+            * att_w (Tensor): Monotonic attention weight tensor (B, T_feats, T_text).
+            * duration (Tensor): Predicted duration tensor (B, T_text).
         """
         # inference
         wav, att_w, dur = self.generator.inference(
@@ -1,4 +1,4 @@
-# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/wavenet/wavenet.py
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/wavenet/wavenet.py
 
 # Copyright 2021 Tomoki Hayashi
 # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)