modify training script, clean codes

This commit is contained in:
yaozengwei 2023-11-05 22:47:04 +08:00
parent 8d09f8e6bf
commit 04c6ecbaa1
21 changed files with 179 additions and 1914 deletions

View File

@ -1,80 +0,0 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script splits the LJSpeech dataset cuts into three sets:
- training, 12500
- validation, 100
- test, 500
The numbers are from https://arxiv.org/pdf/2106.06103.pdf
Usage example:
python3 ./local/split_subsets.py ./data/spectrogram
"""
import argparse
import logging
import random
from pathlib import Path
from lhotse import load_manifest_lazy
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"manifest_dir",
type=Path,
default=Path("data/spectrogram"),
help="Path to the manifest file",
)
return parser.parse_args()
def main():
args = get_args()
manifest_dir = Path(args.manifest_dir)
prefix = "ljspeech"
suffix = "jsonl.gz"
# all_cuts = load_manifest_lazy(manifest_dir / f"{prefix}_cuts_all.{suffix}")
all_cuts = load_manifest_lazy(manifest_dir / f"{prefix}_cuts_all_phonemized.{suffix}")
cut_ids = list(all_cuts.ids)
random.shuffle(cut_ids)
train_cuts = all_cuts.subset(cut_ids=cut_ids[:12500])
valid_cuts = all_cuts.subset(cut_ids=cut_ids[12500:12500 + 100])
test_cuts = all_cuts.subset(cut_ids=cut_ids[12500 + 100:])
assert len(train_cuts) == 12500, f"expected 12500 cuts for training but got {len(train_cuts)}"
assert len(valid_cuts) == 100, f"expected 100 cuts for validation but got {len(valid_cuts)}"
assert len(test_cuts) == 500, f"expected 500 cuts for test but got {len(test_cuts)}"
train_cuts.to_file(manifest_dir / f"{prefix}_cuts_train.{suffix}")
valid_cuts.to_file(manifest_dir / f"{prefix}_cuts_valid.{suffix}")
test_cuts.to_file(manifest_dir / f"{prefix}_cuts_test.{suffix}")
logging.info("Splitted into three sets: training (12500), validation (100), and test (500)")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
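
For reference, the shuffle-and-slice logic above partitions LJSpeech's 13,100 clip ids exactly into the three subsets; a toy check:

import random

ids = list(range(13100))  # LJSpeech has 13,100 clips in total
random.shuffle(ids)
train, valid, test = ids[:12500], ids[12500:12600], ids[12600:]
assert (len(train), len(valid), len(test)) == (12500, 100, 500)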

View File

@ -9,8 +9,7 @@ nj=1
stage=-1
stop_stage=100
# dl_dir=$PWD/download
dl_dir=/star-data/zengwei/download/ljspeech/
dl_dir=$PWD/download
. shared/parse_options.sh || exit 1
@ -66,22 +65,6 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
fi
fi
# if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
# log "Stage 3: Phonemize the transcripts for LJSpeech"
# if [ ! -e data/spectrogram/.ljspeech_phonemized.done ]; then
# ./local/phonemize_text.py data/spectrogram
# touch data/spectrogram/.ljspeech_phonemized.done
# fi
# fi
# if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
# log "Stage 4: Split the LJSpeech cuts into three sets"
# if [ ! -e data/spectrogram/.ljspeech_split.done ]; then
# ./local/split_subsets.py data/spectrogram
# touch data/spectrogram/.ljspeech_split.done
# fi
# fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Split the LJSpeech cuts into train, valid and test sets"
if [ ! -e data/spectrogram/.ljspeech_split.done ]; then
@ -94,6 +77,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
lhotse subset --last 500 \
data/spectrogram/ljspeech_cuts_validtest.jsonl.gz \
data/spectrogram/ljspeech_cuts_test.jsonl.gz
rm data/spectrogram/ljspeech_cuts_validtest.jsonl.gz
n=$(( $(gunzip -c data/spectrogram/ljspeech_cuts_all.jsonl.gz | wc -l) - 600 ))
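
The lhotse subset calls above have a Python counterpart; a minimal sketch, assuming CutSet.subset(first=...) and subset(last=...) mirror the CLI's --first/--last flags, and that the first 100 cuts of the valid+test manifest become the validation set (the corresponding shell lines are cut off by this hunk):

from lhotse import load_manifest

all_cuts = load_manifest("data/spectrogram/ljspeech_cuts_all.jsonl.gz")
n = len(all_cuts) - 600  # everything except 100 valid + 500 test cuts
train_cuts = all_cuts.subset(first=n)
validtest_cuts = all_cuts.subset(last=600)
valid_cuts = validtest_cuts.subset(first=100)
test_cuts = validtest_cuts.subset(last=500)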

View File

@ -1,161 +0,0 @@
import math
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
return int((kernel_size*dilation - dilation)/2)
def convert_pad_shape(pad_shape):
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
def intersperse(lst, item):
result = [item] * (len(lst) * 2 + 1)
result[1::2] = lst
return result
def kl_divergence(m_p, logs_p, m_q, logs_q):
"""KL(P||Q)"""
kl = (logs_q - logs_p) - 0.5
kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. * logs_q)
return kl
def rand_gumbel(shape):
"""Sample from the Gumbel distribution, protect from overflows."""
uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
return -torch.log(-torch.log(uniform_samples))
def rand_gumbel_like(x):
g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
return g
def slice_segments(x, ids_str, segment_size=4):
ret = torch.zeros_like(x[:, :, :segment_size])
for i in range(x.size(0)):
idx_str = ids_str[i]
idx_end = idx_str + segment_size
ret[i] = x[i, :, idx_str:idx_end]
return ret
def rand_slice_segments(x, x_lengths=None, segment_size=4):
b, d, t = x.size()
if x_lengths is None:
x_lengths = t
ids_str_max = x_lengths - segment_size + 1
ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
ret = slice_segments(x, ids_str, segment_size)
return ret, ids_str
def get_timing_signal_1d(
length, channels, min_timescale=1.0, max_timescale=1.0e4):
position = torch.arange(length, dtype=torch.float)
num_timescales = channels // 2
log_timescale_increment = (
math.log(float(max_timescale) / float(min_timescale)) /
(num_timescales - 1))
inv_timescales = min_timescale * torch.exp(
torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment)
scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
signal = F.pad(signal, [0, 0, 0, channels % 2])
signal = signal.view(1, channels, length)
return signal
def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
b, channels, length = x.size()
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
return x + signal.to(dtype=x.dtype, device=x.device)
def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
b, channels, length = x.size()
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
def subsequent_mask(length):
mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
return mask
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
n_channels_int = n_channels[0]
in_act = input_a + input_b
t_act = torch.tanh(in_act[:, :n_channels_int, :])
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
acts = t_act * s_act
return acts
def convert_pad_shape(pad_shape):
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
def shift_1d(x):
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
return x
def sequence_mask(length, max_length=None):
if max_length is None:
max_length = length.max()
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
return x.unsqueeze(0) < length.unsqueeze(1)
def generate_path(duration, mask):
"""
duration: [b, 1, t_x]
mask: [b, 1, t_y, t_x]
"""
device = duration.device
b, _, t_y, t_x = mask.shape
cum_duration = torch.cumsum(duration, -1)
cum_duration_flat = cum_duration.view(b * t_x)
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
path = path.view(b, t_x, t_y)
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
path = path.unsqueeze(1).transpose(2,3) * mask
return path
def clip_grad_value_(parameters, clip_value, norm_type=2):
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters))
norm_type = float(norm_type)
if clip_value is not None:
clip_value = float(clip_value)
total_norm = 0
for p in parameters:
param_norm = p.grad.data.norm(norm_type)
total_norm += param_norm.item() ** norm_type
if clip_value is not None:
p.grad.data.clamp_(min=-clip_value, max=clip_value)
total_norm = total_norm ** (1. / norm_type)
return total_norm
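
Two of the helpers deleted above are easy to illustrate on toy inputs (definitions copied verbatim from the file):

import torch

def intersperse(lst, item):
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result

print(intersperse([7, 8, 9], 0))  # [0, 7, 0, 8, 0, 9, 0]

def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)

print(sequence_mask(torch.tensor([2, 4])))
# tensor([[ True,  True, False, False],
#         [ True,  True,  True,  True]])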

View File

@ -1,4 +1,4 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/duration_predictor.py
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/duration_predictor.py
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

View File

@ -1,416 +0,0 @@
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple
import librosa
import numpy as np
import torch
from torch import nn
from icefall.utils import make_pad_mask
# From https://github.com/espnet/espnet/blob/master/espnet2/layers/stft.py
class Stft(nn.Module):
def __init__(
self,
n_fft: int = 512,
win_length: int = None,
hop_length: int = 128,
window: Optional[str] = "hann",
center: bool = True,
normalized: bool = False,
onesided: bool = True,
):
super().__init__()
self.n_fft = n_fft
if win_length is None:
self.win_length = n_fft
else:
self.win_length = win_length
self.hop_length = hop_length
self.center = center
self.normalized = normalized
self.onesided = onesided
if window is not None and not hasattr(torch, f"{window}_window"):
raise ValueError(f"{window} window is not implemented")
self.window = window
def extra_repr(self):
return (
f"n_fft={self.n_fft}, "
f"win_length={self.win_length}, "
f"hop_length={self.hop_length}, "
f"center={self.center}, "
f"normalized={self.normalized}, "
f"onesided={self.onesided}"
)
def forward(
self, input: torch.Tensor, ilens: torch.Tensor = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""STFT forward function.
Args:
input: (Batch, Nsamples) or (Batch, Nsample, Channels)
ilens: (Batch)
Returns:
output: (Batch, Frames, Freq, 2) or (Batch, Frames, Channels, Freq, 2)
"""
bs = input.size(0)
if input.dim() == 3:
multi_channel = True
# input: (Batch, Nsample, Channels) -> (Batch * Channels, Nsample)
input = input.transpose(1, 2).reshape(-1, input.size(1))
else:
multi_channel = False
# NOTE(kamo):
# The default behaviour of torch.stft is compatible with librosa.stft
# about padding and scaling.
# Note that it's different from scipy.signal.stft
# output: (Batch, Freq, Frames, 2=real_imag)
# or (Batch, Channel, Freq, Frames, 2=real_imag)
if self.window is not None:
window_func = getattr(torch, f"{self.window}_window")
window = window_func(
self.win_length, dtype=input.dtype, device=input.device
)
else:
window = None
# For the compatibility of ARM devices, which do not support
# torch.stft() due to the lack of MKL (on older pytorch versions),
# there is an alternative replacement implementation with librosa.
# Note: pytorch >= 1.10.0 now has native support for FFT and STFT
# on all cpu targets including ARM.
if input.is_cuda or torch.backends.mkl.is_available():
stft_kwargs = dict(
n_fft=self.n_fft,
win_length=self.win_length,
hop_length=self.hop_length,
center=self.center,
window=window,
normalized=self.normalized,
onesided=self.onesided,
)
stft_kwargs["return_complex"] = True
output = torch.stft(input, **stft_kwargs)
output = torch.view_as_real(output)
else:
if self.training:
raise NotImplementedError(
"stft is implemented with librosa on this device, which does not "
"support the training mode."
)
# use stft_kwargs to flexibly control different PyTorch versions' kwargs
# note: librosa does not support a win_length that is < n_fft
# but the window can be manually padded (see below).
stft_kwargs = dict(
n_fft=self.n_fft,
win_length=self.n_fft,
hop_length=self.hop_length,
center=self.center,
window=window,
pad_mode="reflect",
)
if window is not None:
# pad the given window to n_fft
n_pad_left = (self.n_fft - window.shape[0]) // 2
n_pad_right = self.n_fft - window.shape[0] - n_pad_left
stft_kwargs["window"] = torch.cat(
[torch.zeros(n_pad_left), window, torch.zeros(n_pad_right)], 0
).numpy()
else:
win_length = (
self.win_length if self.win_length is not None else self.n_fft
)
stft_kwargs["window"] = torch.ones(win_length)
output = []
# iterate over instances in a batch
for instance in input:
stft = librosa.stft(instance.numpy(), **stft_kwargs)
output.append(torch.tensor(np.stack([stft.real, stft.imag], -1)))
output = torch.stack(output, 0)
if not self.onesided:
len_conj = self.n_fft - output.shape[1]
conj = output[:, 1 : 1 + len_conj].flip(1)
conj[:, :, :, -1].data *= -1
output = torch.cat([output, conj], 1)
if self.normalized:
output = output * (stft_kwargs["window"].shape[0] ** (-0.5))
# output: (Batch, Freq, Frames, 2=real_imag)
# -> (Batch, Frames, Freq, 2=real_imag)
output = output.transpose(1, 2)
if multi_channel:
# output: (Batch * Channel, Frames, Freq, 2=real_imag)
# -> (Batch, Frame, Channel, Freq, 2=real_imag)
output = output.view(bs, -1, output.size(1), output.size(2), 2).transpose(
1, 2
)
if ilens is not None:
if self.center:
pad = self.n_fft // 2
ilens = ilens + 2 * pad
olens = (
torch.div(ilens - self.n_fft, self.hop_length, rounding_mode="trunc")
+ 1
)
output.masked_fill_(make_pad_mask(olens), 0.0)
else:
olens = None
return output, olens
# From https://github.com/espnet/espnet/blob/master/espnet2/tts/feats_extract/linear_spectrogram.py
class LinearSpectrogram(nn.Module):
"""Linear amplitude spectrogram.
Stft -> amplitude-spec
"""
def __init__(
self,
n_fft: int = 1024,
win_length: int = None,
hop_length: int = 256,
window: Optional[str] = "hann",
center: bool = True,
normalized: bool = False,
onesided: bool = True,
):
super().__init__()
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
self.window = window
self.stft = Stft(
n_fft=n_fft,
win_length=win_length,
hop_length=hop_length,
window=window,
center=center,
normalized=normalized,
onesided=onesided,
)
self.n_fft = n_fft
def output_size(self) -> int:
return self.n_fft // 2 + 1
def get_parameters(self) -> Dict[str, Any]:
"""Return the parameters required by Vocoder."""
return dict(
n_fft=self.n_fft,
n_shift=self.hop_length,
win_length=self.win_length,
window=self.window,
)
def forward(
self, input: torch.Tensor, input_lengths: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Stft: time -> time-freq
input_stft, feats_lens = self.stft(input, input_lengths)
assert input_stft.dim() >= 4, input_stft.shape
# "2" refers to the real/imag parts of Complex
assert input_stft.shape[-1] == 2, input_stft.shape
# STFT -> Power spectrum -> Amp spectrum
# input_stft: (..., F, 2) -> (..., F)
input_power = input_stft[..., 0] ** 2 + input_stft[..., 1] ** 2
input_amp = torch.sqrt(torch.clamp(input_power, min=1.0e-10))
return input_amp, feats_lens
# From https://github.com/espnet/espnet/blob/master/espnet2/layers/log_mel.py
class LogMel(nn.Module):
"""Convert STFT to fbank feats
The arguments are the same as librosa.filters.mel
Args:
fs: number > 0 [scalar] sampling rate of the incoming signal
n_fft: int > 0 [scalar] number of FFT components
n_mels: int > 0 [scalar] number of Mel bands to generate
fmin: float >= 0 [scalar] lowest frequency (in Hz)
fmax: float >= 0 [scalar] highest frequency (in Hz).
If `None`, use `fmax = fs / 2.0`
htk: use HTK formula instead of Slaney
"""
def __init__(
self,
fs: int = 16000,
n_fft: int = 512,
n_mels: int = 80,
fmin: float = None,
fmax: float = None,
htk: bool = False,
log_base: float = None,
):
super().__init__()
fmin = 0 if fmin is None else fmin
fmax = fs / 2 if fmax is None else fmax
_mel_options = dict(
sr=fs,
n_fft=n_fft,
n_mels=n_mels,
fmin=fmin,
fmax=fmax,
htk=htk,
)
self.mel_options = _mel_options
self.log_base = log_base
# Note(kamo): The mel matrix of librosa is different from kaldi.
melmat = librosa.filters.mel(**_mel_options)
# melmat: (D2, D1) -> (D1, D2)
self.register_buffer("melmat", torch.from_numpy(melmat.T).float())
def extra_repr(self):
return ", ".join(f"{k}={v}" for k, v in self.mel_options.items())
def forward(
self,
feat: torch.Tensor,
ilens: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
mel_feat = torch.matmul(feat, self.melmat)
mel_feat = torch.clamp(mel_feat, min=1e-10)
if self.log_base is None:
logmel_feat = mel_feat.log()
elif self.log_base == 2.0:
logmel_feat = mel_feat.log2()
elif self.log_base == 10.0:
logmel_feat = mel_feat.log10()
else:
logmel_feat = mel_feat.log() / torch.log(self.log_base)
# Zero padding
if ilens is not None:
logmel_feat = logmel_feat.masked_fill(make_pad_mask(ilens), 0.0)
else:
ilens = feat.new_full(
[feat.size(0)], fill_value=feat.size(1), dtype=torch.long
)
return logmel_feat, ilens
# From https://github.com/espnet/espnet/blob/master/espnet2/tts/feats_extract/log_mel_fbank.py
class LogMelFbank(nn.Module):
"""Conventional frontend structure for TTS.
Stft -> amplitude-spec -> Log-Mel-Fbank
"""
def __init__(
self,
fs: int = 16000,
n_fft: int = 1024,
win_length: int = None,
hop_length: int = 256,
window: Optional[str] = "hann",
center: bool = True,
normalized: bool = False,
onesided: bool = True,
n_mels: int = 80,
fmin: Optional[int] = 80,
fmax: Optional[int] = 7600,
htk: bool = False,
log_base: Optional[float] = 10.0,
):
super().__init__()
self.fs = fs
self.n_mels = n_mels
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
self.window = window
self.fmin = fmin
self.fmax = fmax
self.stft = Stft(
n_fft=n_fft,
win_length=win_length,
hop_length=hop_length,
window=window,
center=center,
normalized=normalized,
onesided=onesided,
)
self.logmel = LogMel(
fs=fs,
n_fft=n_fft,
n_mels=n_mels,
fmin=fmin,
fmax=fmax,
htk=htk,
log_base=log_base,
)
def output_size(self) -> int:
return self.n_mels
def get_parameters(self) -> Dict[str, Any]:
"""Return the parameters required by Vocoder"""
return dict(
fs=self.fs,
n_fft=self.n_fft,
n_shift=self.hop_length,
window=self.window,
n_mels=self.n_mels,
win_length=self.win_length,
fmin=self.fmin,
fmax=self.fmax,
)
def forward(
self, input: torch.Tensor, input_lengths: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Domain-conversion: e.g. Stft: time -> time-freq
input_stft, feats_lens = self.stft(input, input_lengths)
assert input_stft.dim() >= 4, input_stft.shape
# "2" refers to the real/imag parts of Complex
assert input_stft.shape[-1] == 2, input_stft.shape
# NOTE(kamo): We use different definition for log-spec between TTS and ASR
# TTS: log_10(abs(stft))
# ASR: log_e(power(stft))
# input_stft: (..., F, 2) -> (..., F)
input_power = input_stft[..., 0] ** 2 + input_stft[..., 1] ** 2
input_amp = torch.sqrt(torch.clamp(input_power, min=1.0e-10))
input_feats, _ = self.logmel(input_amp, feats_lens)
return input_feats, feats_lens
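
The power-to-amplitude computation used by LinearSpectrogram and LogMelFbank can be reproduced directly with torch.stft; a minimal sketch (kept in torch's native (batch, freq, frames) layout, whereas the Stft module above transposes to (batch, frames, freq)):

import torch

wav = torch.randn(1, 22050)  # one second of audio at 22.05 kHz
spec = torch.stft(
    wav, n_fft=1024, hop_length=256, win_length=1024,
    window=torch.hann_window(1024), center=True, return_complex=True,
)
stft = torch.view_as_real(spec)  # (1, 513, 87, 2): real/imag in the last dim
power = stft[..., 0] ** 2 + stft[..., 1] ** 2
amp = torch.sqrt(torch.clamp(power, min=1.0e-10))
print(amp.shape)  # torch.Size([1, 513, 87]); 513 == 1024 // 2 + 1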

View File

@ -1,4 +1,5 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/flow.py
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/flow.py
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

View File

@ -1,4 +1,4 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/generator.py
# based on https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/generator.py
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

View File

@ -1,4 +1,4 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/hifigan.py
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/hifigan.py
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

View File

@ -1,7 +1,6 @@
#!/usr/bin/env python3
#
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao)
# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -17,118 +16,34 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script performs model inference on the test set.
Usage:
(1) greedy search
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
./vits/infer.py \
--epoch 1000 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method greedy_search
(2) beam search (not recommended)
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search (one best)
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(5) fast beam search (nbest)
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(6) fast beam search (nbest oracle WER)
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(7) fast beam search (with LG)
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_LG \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
--max-duration 500
"""
import argparse
import logging
import math
import os
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from typing import List
import k2
import torch
import torch.nn as nn
import torchaudio
from train import get_model, get_params, prepare_input
from train import get_model, get_params
from tokenizer import Tokenizer
from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
make_pad_mask,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
from icefall.checkpoint import load_checkpoint
from icefall.utils import AttributeDict, setup_logger
from tts_datamodule import LJSpeechTtsDataModule
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
@ -138,35 +53,16 @@ def get_parser():
parser.add_argument(
"--epoch",
type=int,
default=30,
default=1000,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
default="vits/exp",
help="The experiment dir",
)
@ -174,7 +70,7 @@ def get_parser():
"--tokens",
type=str,
default="data/tokens.txt",
help="""Path to tokens.txt.""",
help="""Path to vocabulary.""",
)
return parser
@ -185,8 +81,9 @@ def infer_dataset(
params: AttributeDict,
model: nn.Module,
tokenizer: Tokenizer,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
) -> None:
"""Decode dataset.
The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`.
Args:
dl:
@ -195,20 +92,8 @@ def infer_dataset(
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
tokenizer:
Used to convert text to phonemes.
"""
# Background worker that saves audios to disk.
def _save_worker(
@ -233,7 +118,7 @@ def infer_dataset(
device = next(model.parameters()).device
num_cuts = 0
log_interval = 10
log_interval = 5
try:
num_batches = len(dl)
@ -242,7 +127,6 @@ def infer_dataset(
futures = []
with ThreadPoolExecutor(max_workers=1) as executor:
# We only want one background worker so that serialization is deterministic.
for batch_idx, batch in enumerate(dl):
batch_size = len(batch["text"])
@ -253,7 +137,7 @@ def infer_dataset(
tokens_lens = row_splits[1:] - row_splits[:-1]
tokens = tokens.to(device)
tokens_lens = tokens_lens.to(device)
# a tensor of shape (B, T)
# tensor of shape (B, T)
tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
audio = batch["audio"]
@ -265,9 +149,6 @@ def infer_dataset(
# convert to samples
audio_lens_pred = (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist()
# import pdb
# pdb.set_trace()
futures.append(
executor.submit(
_save_worker, batch_size, cut_ids, audio, audio_pred, audio_lens, audio_lens_pred
@ -295,10 +176,7 @@ def main():
params = get_params()
params.update(vars(args))
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
params.suffix = f"epoch-{params.epoch}"
params.res_dir = params.exp_dir / "infer" / params.suffix
params.save_wav_dir = params.res_dir / "wav"
@ -322,40 +200,16 @@ def main():
logging.info("About to create model")
model = get_model(params)
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
num_param_g = sum([p.numel() for p in model.generator.parameters()])
logging.info(f"Number of parameters in generator: {num_param_g}")
num_param_d = sum([p.numel() for p in model.discriminator.parameters()])
logging.info(f"Number of parameters in discriminator: {num_param_d}")
logging.info(f"Total number of parameters: {num_param_g + num_param_d}")
# we need cut ids to name the saved audios.
args.return_cuts = True
@ -371,17 +225,8 @@ def main():
tokenizer=tokenizer,
)
# save_results(
# params=params,
# test_set_name=test_set,
# results_dict=results_dict,
# )
logging.info("Done!")
# torch.set_num_threads(1)
# torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()
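
The frames-to-samples conversion above (token durations summed, then multiplied by the hop size) is easy to trace on toy numbers, assuming frame_shift=256 from get_params:

import torch

frame_shift = 256  # hop size in samples (see get_params in train.py)
durations = torch.tensor([[1.0, 3.0, 1.0]])  # (B, T_text) frames per token
audio_lens_pred = (durations.sum(1) * frame_shift).to(dtype=torch.int64).tolist()
print(audio_lens_pred)  # [1280] samples, i.e. 5 frames * 256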

View File

@ -1,4 +1,4 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/loss.py
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/loss.py
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
@ -9,7 +9,7 @@ This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.
"""
from typing import List, Optional, Tuple, Union
from typing import List, Tuple, Union
import torch
import torch.distributions as D
@ -266,7 +266,7 @@ class MelSpectrogramLoss(torch.nn.Module):
return mel_loss
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/loss.py
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/loss.py
"""VITS-related loss modules.

View File

@ -1,534 +0,0 @@
import copy
import math
import torch
from torch import nn
from torch.nn import functional as F
import commons
import modules
import attentions
import monotonic_align
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from commons import init_weights, get_padding
class StochasticDurationPredictor(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0):
super().__init__()
filter_channels = in_channels # NOTE: this line needs to be removed in a future version.
self.in_channels = in_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.n_flows = n_flows
self.gin_channels = gin_channels
self.log_flow = modules.Log()
self.flows = nn.ModuleList()
self.flows.append(modules.ElementwiseAffine(2))
for i in range(n_flows):
self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
self.flows.append(modules.Flip())
self.post_pre = nn.Conv1d(1, filter_channels, 1)
self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
self.post_flows = nn.ModuleList()
self.post_flows.append(modules.ElementwiseAffine(2))
for i in range(4):
self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
self.post_flows.append(modules.Flip())
self.pre = nn.Conv1d(in_channels, filter_channels, 1)
self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
x = torch.detach(x)
x = self.pre(x)
if g is not None:
g = torch.detach(g)
x = x + self.cond(g)
x = self.convs(x, x_mask)
x = self.proj(x) * x_mask
if not reverse:
flows = self.flows
assert w is not None
logdet_tot_q = 0
h_w = self.post_pre(w)
h_w = self.post_convs(h_w, x_mask)
h_w = self.post_proj(h_w) * x_mask
e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
z_q = e_q
for flow in self.post_flows:
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
logdet_tot_q += logdet_q
z_u, z1 = torch.split(z_q, [1, 1], 1)
u = torch.sigmoid(z_u) * x_mask
z0 = (w - u) * x_mask
logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1,2])
logq = torch.sum(-0.5 * (math.log(2*math.pi) + (e_q**2)) * x_mask, [1,2]) - logdet_tot_q
logdet_tot = 0
z0, logdet = self.log_flow(z0, x_mask)
logdet_tot += logdet
z = torch.cat([z0, z1], 1)
for flow in flows:
z, logdet = flow(z, x_mask, g=x, reverse=reverse)
logdet_tot = logdet_tot + logdet
nll = torch.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - logdet_tot
return nll + logq # [b]
else:
flows = list(reversed(self.flows))
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
for flow in flows:
z = flow(z, x_mask, g=x, reverse=reverse)
z0, z1 = torch.split(z, [1, 1], 1)
logw = z0
return logw
class DurationPredictor(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
super().__init__()
self.in_channels = in_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.gin_channels = gin_channels
self.drop = nn.Dropout(p_dropout)
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size//2)
self.norm_1 = modules.LayerNorm(filter_channels)
self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size//2)
self.norm_2 = modules.LayerNorm(filter_channels)
self.proj = nn.Conv1d(filter_channels, 1, 1)
if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, in_channels, 1)
def forward(self, x, x_mask, g=None):
x = torch.detach(x)
if g is not None:
g = torch.detach(g)
x = x + self.cond(g)
x = self.conv_1(x * x_mask)
x = torch.relu(x)
x = self.norm_1(x)
x = self.drop(x)
x = self.conv_2(x * x_mask)
x = torch.relu(x)
x = self.norm_2(x)
x = self.drop(x)
x = self.proj(x * x_mask)
return x * x_mask
class TextEncoder(nn.Module):
def __init__(self,
n_vocab,
out_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout):
super().__init__()
self.n_vocab = n_vocab
self.out_channels = out_channels
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.emb = nn.Embedding(n_vocab, hidden_channels)
nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
self.encoder = attentions.Encoder(
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout)
self.proj= nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, x, x_lengths):
x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
x = torch.transpose(x, 1, -1) # [b, h, t]
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.encoder(x * x_mask, x_mask)
stats = self.proj(x) * x_mask
m, logs = torch.split(stats, self.out_channels, dim=1)
return x, m, logs, x_mask
class ResidualCouplingBlock(nn.Module):
def __init__(self,
channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
n_flows=4,
gin_channels=0):
super().__init__()
self.channels = channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.n_flows = n_flows
self.gin_channels = gin_channels
self.flows = nn.ModuleList()
for i in range(n_flows):
self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True))
self.flows.append(modules.Flip())
def forward(self, x, x_mask, g=None, reverse=False):
if not reverse:
for flow in self.flows:
x, _ = flow(x, x_mask, g=g, reverse=reverse)
else:
for flow in reversed(self.flows):
x = flow(x, x_mask, g=g, reverse=reverse)
return x
class PosteriorEncoder(nn.Module):
def __init__(self,
in_channels,
out_channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=0):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.gin_channels = gin_channels
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, x, x_lengths, g=None):
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.pre(x) * x_mask
x = self.enc(x, x_mask, g=g)
stats = self.proj(x) * x_mask
m, logs = torch.split(stats, self.out_channels, dim=1)
z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
return z, m, logs, x_mask
class Generator(torch.nn.Module):
def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
super(Generator, self).__init__()
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
self.ups.append(weight_norm(
ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
k, u, padding=(k-u)//2)))
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel//(2**(i+1))
for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
self.resblocks.append(resblock(ch, k, d))
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
self.ups.apply(init_weights)
if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
def forward(self, x, g=None):
x = self.conv_pre(x)
if g is not None:
x = x + self.cond(g)
for i in range(self.num_upsamples):
x = F.leaky_relu(x, modules.LRELU_SLOPE)
x = self.ups[i](x)
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i*self.num_kernels+j](x)
else:
xs += self.resblocks[i*self.num_kernels+j](x)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.conv_post(x)
x = torch.tanh(x)
return x
def remove_weight_norm(self):
print('Removing weight norm...')
for l in self.ups:
remove_weight_norm(l)
for l in self.resblocks:
l.remove_weight_norm()
class DiscriminatorP(torch.nn.Module):
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
super(DiscriminatorP, self).__init__()
self.period = period
self.use_spectral_norm = use_spectral_norm
norm_f = weight_norm if not use_spectral_norm else spectral_norm
self.convs = nn.ModuleList([
norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
])
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
def forward(self, x):
fmap = []
# 1d to 2d
b, c, t = x.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
x = F.pad(x, (0, n_pad), "reflect")
t = t + n_pad
x = x.view(b, c, t // self.period, self.period)
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, modules.LRELU_SLOPE)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class DiscriminatorS(torch.nn.Module):
def __init__(self, use_spectral_norm=False):
super(DiscriminatorS, self).__init__()
norm_f = weight_norm if not use_spectral_norm else spectral_norm
self.convs = nn.ModuleList([
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
])
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
def forward(self, x):
fmap = []
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, modules.LRELU_SLOPE)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class MultiPeriodDiscriminator(torch.nn.Module):
def __init__(self, use_spectral_norm=False):
super(MultiPeriodDiscriminator, self).__init__()
periods = [2,3,5,7,11]
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
self.discriminators = nn.ModuleList(discs)
def forward(self, y, y_hat):
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for i, d in enumerate(self.discriminators):
y_d_r, fmap_r = d(y)
y_d_g, fmap_g = d(y_hat)
y_d_rs.append(y_d_r)
y_d_gs.append(y_d_g)
fmap_rs.append(fmap_r)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class SynthesizerTrn(nn.Module):
"""
Synthesizer for Training
"""
def __init__(self,
n_vocab,
spec_channels,
segment_size,
inter_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
n_speakers=0,
gin_channels=0,
use_sdp=True,
**kwargs):
super().__init__()
self.n_vocab = n_vocab
self.spec_channels = spec_channels
self.inter_channels = inter_channels
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.resblock = resblock
self.resblock_kernel_sizes = resblock_kernel_sizes
self.resblock_dilation_sizes = resblock_dilation_sizes
self.upsample_rates = upsample_rates
self.upsample_initial_channel = upsample_initial_channel
self.upsample_kernel_sizes = upsample_kernel_sizes
self.segment_size = segment_size
self.n_speakers = n_speakers
self.gin_channels = gin_channels
self.use_sdp = use_sdp
self.enc_p = TextEncoder(n_vocab,
inter_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout)
self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
if use_sdp:
self.dp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels)
else:
self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels)
if n_speakers > 1:
self.emb_g = nn.Embedding(n_speakers, gin_channels)
def forward(self, x, x_lengths, y, y_lengths, sid=None):
x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
if self.n_speakers > 0:
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
else:
g = None
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
z_p = self.flow(z, y_mask, g=g)
with torch.no_grad():
# negative cross-entropy
s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t]
neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True) # [b, 1, t_s]
neg_cent2 = torch.matmul(-0.5 * (z_p ** 2).transpose(1, 2), s_p_sq_r) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
neg_cent3 = torch.matmul(z_p.transpose(1, 2), (m_p * s_p_sq_r)) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
neg_cent4 = torch.sum(-0.5 * (m_p ** 2) * s_p_sq_r, [1], keepdim=True) # [b, 1, t_s]
neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
attn = monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach()
w = attn.sum(2)
if self.use_sdp:
l_length = self.dp(x, x_mask, w, g=g)
l_length = l_length / torch.sum(x_mask)
else:
logw_ = torch.log(w + 1e-6) * x_mask
logw = self.dp(x, x_mask, g=g)
l_length = torch.sum((logw - logw_)**2, [1,2]) / torch.sum(x_mask) # for averaging
# expand prior
m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)
o = self.dec(z_slice, g=g)
return o, l_length, attn, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None):
x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
if self.n_speakers > 0:
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
else:
g = None
if self.use_sdp:
logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w)
else:
logw = self.dp(x, x_mask, g=g)
w = torch.exp(logw) * x_mask * length_scale
w_ceil = torch.ceil(w)
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype)
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
attn = commons.generate_path(w_ceil, attn_mask)
m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=g, reverse=True)
o = self.dec((z * y_mask)[:,:,:max_len], g=g)
return o, attn, y_mask, (z, z_p, m_p, logs_p)
def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
assert self.n_speakers > 0, "n_speakers has to be larger than 0."
g_src = self.emb_g(sid_src).unsqueeze(-1)
g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src)
z_p = self.flow(z, y_mask, g=g_src)
z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
o_hat = self.dec(z_hat * y_mask, g=g_tgt)
return o_hat, y_mask, (z, z_p, z_hat)
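
The duration logic in SynthesizerTrn.infer above is worth tracing on toy numbers: predicted log-durations are exponentiated, scaled, and ceiled to integer frame counts, whose sum gives the output length:

import torch

logw = torch.tensor([[[0.0, 0.7, -0.2]]])  # (b, 1, t_x) predicted log-durations
x_mask = torch.ones(1, 1, 3)
w = torch.exp(logw) * x_mask * 1.0          # length_scale = 1.0
w_ceil = torch.ceil(w)                       # per-token frame counts: [[1., 3., 1.]]
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
print(y_lengths)  # tensor([5]) output frames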

View File

@ -1,4 +1,4 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/posterior_encoder.py
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/posterior_encoder.py
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

View File

@ -1,4 +1,4 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/residual_coupling.py
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/residual_coupling.py
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

View File

@ -1,17 +0,0 @@
# https://github.com/jaywalnut310/vits/blob/main/text/symbols.py
""" from https://github.com/keithito/tacotron """
'''
Defines the set of symbols used in text input to the model.
'''
_pad = '_'
_punctuation = ';:,.!?¡¿—…"«»“” '
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'"
# Export all symbols:
symbol_table = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
# Special symbol ids
SPACE_ID = symbol_table.index(" ")
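
For reference, the deleted symbol table (superseded in this commit by data/tokens.txt) maps each symbol to its index by position; a trimmed reconstruction without the IPA block:

_pad = '_'
_punctuation = ';:,.!?¡¿—…"«»“” '
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
symbol_table = [_pad] + list(_punctuation) + list(_letters)  # IPA letters omitted here
symbol2id = {s: i for i, s in enumerate(symbol_table)}
print(symbol2id['_'], symbol_table.index(' '))  # 0 and SPACE_ID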

View File

@ -20,6 +20,7 @@
This code is based on
- https://github.com/jaywalnut310/vits
- https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/text_encoder.py
- https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/transducer_stateless/conformer.py
"""
import copy
@ -67,6 +68,7 @@ class TextEncoder(torch.nn.Module):
self.emb = torch.nn.Embedding(vocabs, d_model)
torch.nn.init.normal_(self.emb.weight, 0.0, d_model**-0.5)
# We use conformer as text encoder
self.encoder = Transformer(
d_model=d_model,
num_heads=num_heads,

View File

@ -18,7 +18,6 @@ from typing import Dict, List
import g2p_en
import tacotron_cleaner.cleaners
from utils import intersperse
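
A minimal sketch of the text-to-phoneme pipeline these imports suggest, assuming the usual espnet-style combination of tacotron cleaners followed by g2p_en (the exact calls in tokenizer.py may differ):

import g2p_en
import tacotron_cleaner.cleaners

text = tacotron_cleaner.cleaners.custom_english_cleaners("Hello, world!")
g2p = g2p_en.G2p()
phonemes = g2p(text)
print(phonemes)  # e.g. ['HH', 'AH0', 'L', 'OW1', ' ', 'W', 'ER1', 'L', 'D', '!']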

View File

@ -1,9 +1,5 @@
#!/usr/bin/env python3
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang,
# Mingshuang Luo,
# Zengwei Yao,
# Daniel Povey)
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -22,9 +18,10 @@
import argparse
import logging
import numpy as np
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional, Tuple, Union
import k2
import torch
@ -36,26 +33,17 @@ from torch.optim import Optimizer
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from tts_datamodule import LJSpeechTtsDataModule
from icefall import diagnostics
from icefall.checkpoint import load_checkpoint, remove_checkpoints
from icefall.checkpoint import load_checkpoint
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
setup_logger,
str2bool,
)
from icefall.utils import AttributeDict, setup_logger, str2bool
from tokenizer import Tokenizer
from utils import (
MetricsTracker,
plot_feature,
save_checkpoint,
save_checkpoint_with_global_batch_idx,
)
from tts_datamodule import LJSpeechTtsDataModule
from utils import MetricsTracker, plot_feature, save_checkpoint
from vits import VITS
LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
@ -90,7 +78,7 @@ def get_parser():
parser.add_argument(
"--num-epochs",
type=int,
default=30,
default=1000,
help="Number of epochs to train.",
)
@ -104,15 +92,6 @@ def get_parser():
""",
)
parser.add_argument(
"--start-batch",
type=int,
default=0,
help="""If positive, --start-epoch is ignored and
it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
""",
)
parser.add_argument(
"--exp-dir",
type=str,
@ -127,7 +106,7 @@ def get_parser():
"--tokens",
type=str,
default="data/tokens.txt",
help="""Path to tokens.txt.""",
help="""Path to vocabulary.""",
)
parser.add_argument(
@ -158,24 +137,11 @@ def get_parser():
parser.add_argument(
"--save-every-n",
type=int,
default=4000,
help="""Save checkpoint after processing this number of batches"
default=20,
help="""Save checkpoint after processing this number of epochs"
periodically. We save checkpoint to exp-dir/ whenever
params.batch_idx_train % save_every_n == 0. The checkpoint filename
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
end of each epoch where `xxx` is the epoch number counting from 1.
""",
)
parser.add_argument(
"--keep-last-k",
type=int,
default=30,
help="""Only keep this number of checkpoints on disk.
For instance, if it is 3, there are only 3 checkpoints
in the exp-dir with filenames `checkpoint-xxx.pt`.
It does not affect checkpoints with name `epoch-xxx.pt`.
params.cur_epoch % save_every_n == 0. The checkpoint filename
has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'
""",
)
@ -218,8 +184,6 @@ def get_params() -> AttributeDict:
- log_interval: Print training loss if batch_idx % log_interval is 0
- reset_interval: Reset statistics if batch_idx % reset_interval is 0
- valid_interval: Run validation if batch_idx % valid_interval is 0
- feature_dim: The model input dim. It has to match the one used
@ -242,18 +206,14 @@ def get_params() -> AttributeDict:
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": -1, # 0
"log_interval": 10,
"draw_interval": 500,
# "reset_interval": 200,
"log_interval": 50,
"valid_interval": 200,
"env_info": get_env_info(),
"sampling_rate": 22050,
"frame_shift": 256,
"frame_length": 1024,
"feature_dim": 513, # 1024 // 2 + 1, 1024 is fft_length
"mel_loss_params": {
"n_mels": 80,
},
"n_mels": 80,
"lambda_adv": 1.0, # loss scaling coefficient for adversarial loss
"lambda_mel": 45.0, # loss scaling coefficient for Mel loss
"lambda_feat_match": 2.0, # loss scaling coefficient for feat match loss
@ -270,9 +230,7 @@ def load_checkpoint_if_available(
) -> Optional[Dict[str, Any]]:
"""Load checkpoint from file.
If params.start_batch is positive, it will load the checkpoint from
`params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
params.start_epoch is larger than 1, it will load the checkpoint from
If params.start_epoch is larger than 1, it will load the checkpoint from
`params.start_epoch - 1`.
Apart from loading state dict for `model` and `optimizer` it also updates
@ -287,9 +245,7 @@ def load_checkpoint_if_available(
Returns:
Return a dict containing previously saved training info.
"""
if params.start_batch > 0:
filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
elif params.start_epoch > 1:
if params.start_epoch > 1:
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
else:
return None
@ -308,19 +264,15 @@ def load_checkpoint_if_available(
for k in keys:
params[k] = saved_params[k]
if params.start_batch > 0:
if "cur_epoch" in saved_params:
params["start_epoch"] = saved_params["cur_epoch"]
return saved_params
def get_model(params: AttributeDict) -> nn.Module:
mel_loss_params = params.mel_loss_params
mel_loss_params.update(
frame_length=params.frame_length,
frame_shift=params.frame_shift,
)
mel_loss_params = {
"n_mels": params.n_mels,
"frame_length": params.frame_length,
"frame_shift": params.frame_shift,
}
model = VITS(
vocab_size=params.vocab_size,
feature_dim=params.feature_dim,
@ -381,18 +333,22 @@ def train_one_epoch(
It is returned by :func:`get_params`.
model:
The model for training.
optimizer:
The optimizer we are using.
scheduler:
The learning rate scheduler, we call step() every step.
tokenizer:
Used to convert text to phonemes.
optimizer_g:
The optimizer for generator.
optimizer_d:
The optimizer for discriminator.
scheduler_g:
The learning rate scheduler for generator, we call step() every epoch.
scheduler_d:
The learning rate scheduler for discriminator, we call step() every epoch.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
scaler:
The scaler used for mixed precision training.
model_avg:
The stored model averaged from the start of training.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
@ -404,7 +360,7 @@ def train_one_epoch(
model.train()
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
# used to summarize the stats over iterations
# used to summarize the stats over iterations in one epoch
tot_loss = MetricsTracker()
saved_bad_model = False
@ -433,7 +389,6 @@ def train_one_epoch(
loss_info = MetricsTracker()
loss_info['samples'] = batch_size
return_sample = params.batch_idx_train % params.log_interval == 0
try:
with autocast(enabled=params.use_fp16):
# forward discriminator
@ -463,13 +418,11 @@ def train_one_epoch(
speech=audio,
speech_lengths=audio_lens,
forward_generator=True,
return_sample=return_sample,
return_sample=params.batch_idx_train % params.log_interval == 0,
)
for k, v in stats_g.items():
if "return_sample" not in k:
if "returned_sample" not in k:
loss_info[k] = v * batch_size
if return_sample:
speech_hat_, speech_, mel_hat_, mel_ = stats_g["return_sample"]
# update generator
optimizer_g.zero_grad()
scaler.scale(loss_g).backward()
@ -477,7 +430,6 @@ def train_one_epoch(
scaler.update()
# summary stats
# tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
tot_loss = tot_loss + loss_info
except: # noqa
save_bad_model()
@ -486,37 +438,12 @@ def train_one_epoch(
if params.print_diagnostics and batch_idx == 5:
return
if (
params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0
):
save_checkpoint_with_global_batch_idx(
out_dir=params.exp_dir,
global_batch_idx=params.batch_idx_train,
model=model,
params=params,
optimizer_g=optimizer_g,
optimizer_d=optimizer_d,
scheduler_g=scheduler_g,
scheduler_d=scheduler_d,
sampler=train_dl.sampler,
scaler=scaler,
rank=rank,
)
remove_checkpoints(
out_dir=params.exp_dir,
topk=params.keep_last_k,
rank=rank,
)
# if batch_idx % 100 == 0 and params.use_fp16:
if params.batch_idx_train % 100 == 0 and params.use_fp16:
# If the grad scale was less than 1, try increasing it. The _growth_interval
# of the grad scaler is configurable, but we can't configure it to have different
# behavior depending on the current grad scale.
cur_grad_scale = scaler._scale.item()
# if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0):
scaler.update(cur_grad_scale * 2.0)
if cur_grad_scale < 0.01:
@ -530,7 +457,6 @@ def train_one_epoch(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
# if batch_idx % params.log_interval == 0:
if params.batch_idx_train % params.log_interval == 0:
cur_lr_g = max(scheduler_g.get_last_lr())
cur_lr_d = max(scheduler_d.get_last_lr())
@ -561,7 +487,8 @@ def train_one_epoch(
tb_writer.add_scalar(
"train/grad_scale", cur_grad_scale, params.batch_idx_train
)
if return_sample:
if "returned_sample" in stats_g:
speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"]
tb_writer.add_audio(
"train/speech_hat_", speech_hat_, params.batch_idx_train, params.sampling_rate
)
@ -575,7 +502,6 @@ def train_one_epoch(
"train/mel_", plot_feature(mel_), params.batch_idx_train, dataformats='HWC'
)
# if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
if params.batch_idx_train % params.valid_interval == 0 and not params.print_diagnostics:
logging.info("Computing validation loss")
valid_info, (speech_hat, speech) = compute_validation_loss(
@ -615,14 +541,14 @@ def compute_validation_loss(
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
rank: int = 0,
) -> MetricsTracker:
) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]:
"""Run the validation process."""
model.eval()
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
# used to summarize the stats over iterations
tot_loss = MetricsTracker()
return_sample = None
returned_sample = None
with torch.no_grad():
for batch_idx, batch in enumerate(valid_dl):
@ -667,12 +593,14 @@ def compute_validation_loss(
# infer for first batch:
if batch_idx == 0 and rank == 0:
inner_model = model.module if isinstance(model, DDP) else model
audio_pred, _, duration = inner_model.inference(text=tokens[0, :tokens_lens[0].item()])
audio_pred, _, duration = inner_model.inference(
text=tokens[0, :tokens_lens[0].item()]
)
audio_pred = audio_pred.data.cpu().numpy()
audio_len_pred = (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item()
assert audio_len_pred == len(audio_pred), (audio_len_pred, len(audio_pred))
audio_gt = audio[0, :audio_lens[0].item()].data.cpu().numpy()
return_sample = (audio_pred, audio_gt)
returned_sample = (audio_pred, audio_gt)
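The assert above checks that the generated waveform length equals the total predicted duration (in frames) times the hop size. A worked sketch with assumed values; the frame_shift of 256 samples is illustrative, not taken from this recipe's config:

import torch

frame_shift = 256  # assumed hop size in samples per spectrogram frame
duration = torch.tensor([3, 5, 2])  # hypothetical per-token frame counts
audio_len_pred = int((duration.sum() * frame_shift).item())
assert audio_len_pred == 2560  # 10 frames * 256 samples per frame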
if world_size > 1:
tot_loss.reduce(device)
@ -682,7 +610,7 @@ def compute_validation_loss(
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
return tot_loss, return_sample
return tot_loss, returned_sample
def scan_pessimistic_batches_for_oom(
@ -805,18 +733,10 @@ def run(rank, world_size, args):
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
optimizer_g = torch.optim.AdamW(
generator.parameters(),
lr=params.lr,
betas=(0.8, 0.99),
eps=1e-9,
# weight_decay=0,
generator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9
)
optimizer_d = torch.optim.AdamW(
discriminator.parameters(),
lr=params.lr,
betas=(0.8, 0.99),
eps=1e-9,
# weight_decay=0,
discriminator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9
)
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875)
@ -852,16 +772,8 @@ def run(rank, world_size, args):
train_cuts = ljspeech.train_cuts()
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
# We only load the sampler's state dict when it loads a checkpoint
# saved in the middle of an epoch
sampler_state_dict = checkpoints["sampler"]
else:
sampler_state_dict = None
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
@ -870,13 +782,10 @@ def run(rank, world_size, args):
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
# )
return False
return True
train_cuts = train_cuts.filter(remove_short_and_long_utt)
train_dl = ljspeech.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict
)
train_dl = ljspeech.train_dataloaders(train_cuts)
valid_cuts = ljspeech.valid_cuts()
valid_dl = ljspeech.valid_dataloaders(valid_cuts)
@ -902,11 +811,11 @@ def run(rank, world_size, args):
fix_random_seed(params.seed + epoch - 1)
train_dl.sampler.set_epoch(epoch - 1)
params.cur_epoch = epoch
if tb_writer is not None:
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
params.cur_epoch = epoch
train_one_epoch(
params=params,
model=model,
@ -927,27 +836,28 @@ def run(rank, world_size, args):
diagnostic.print_diagnostics()
break
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
save_checkpoint(
filename=filename,
params=params,
model=model,
optimizer_g=optimizer_g,
optimizer_d=optimizer_d,
scheduler_g=scheduler_g,
scheduler_d=scheduler_d,
sampler=train_dl.sampler,
scaler=scaler,
rank=rank,
)
if rank == 0:
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
if epoch % params.save_every_n == 0:
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
save_checkpoint(
filename=filename,
params=params,
model=model,
optimizer_g=optimizer_g,
optimizer_d=optimizer_d,
scheduler_g=scheduler_g,
scheduler_d=scheduler_d,
sampler=train_dl.sampler,
scaler=scaler,
rank=rank,
)
if rank == 0:
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
# step per epoch
scheduler_g.step()
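Because the schedulers are stepped once per epoch, gamma=0.999875 gives lr_N = lr_0 * 0.999875**N after N epochs, i.e. roughly 0.88 * lr_0 after 1000 epochs. A minimal sketch; the base lr of 2.0e-4 is illustrative:

import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=2.0e-4, betas=(0.8, 0.99), eps=1e-9)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999875)

for epoch in range(1, 3):
    # ... train_one_epoch(...) would run here ...
    scheduler.step()  # decay once per epoch, as in the loop above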

View File

@ -1,4 +1,5 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/transform.py
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/transform.py
"""Flow-related transformation.
This code is derived from https://github.com/bayesiains/nflows.

View File

@ -1,32 +1,35 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/utils/get_random_segments.py
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Function to get random segments."""
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple, Union
import collections
import logging
import re
import warnings
import numpy as np
import torch
import torch.nn as nn
import torch.distributed as dist
from lhotse.dataset.sampling.base import CutSampler
from pathlib import Path
from phonemizer import phonemize
from symbols import symbol_table
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils.rnn import pad_sequence
from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter
from unidecode import unidecode
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/utils/get_random_segments.py
def get_random_segments(
x: torch.Tensor,
x_lengths: torch.Tensor,
@ -55,6 +58,7 @@ def get_random_segments(
return segments, start_idxs
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/utils/get_random_segments.py
def get_segments(
x: torch.Tensor,
start_idxs: torch.Tensor,
@ -78,195 +82,41 @@ def get_segments(
return segments
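VITS trains its waveform-level discriminator on short random windows rather than on full utterances, which is what these two helpers provide. A minimal sketch of the idea (not the espnet implementation truncated above); it assumes the padded length T is at least segment_size:

import torch

def random_segments_sketch(x: torch.Tensor, x_lengths: torch.Tensor, segment_size: int):
    """x: (B, C, T) features; returns (B, C, segment_size) crops and start indices."""
    max_starts = (x_lengths - segment_size).clamp(min=0)
    start_idxs = (torch.rand(x.size(0)) * (max_starts + 1)).to(dtype=torch.long)
    segments = torch.stack(
        [x[i, :, s : s + segment_size] for i, s in enumerate(start_idxs.tolist())]
    )
    return segments, start_idxs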
# https://github.com/espnet/espnet/blob/master/espnet2/torch_utils/device_funcs.py
def force_gatherable(data, device):
"""Change object to gatherable in torch.nn.DataParallel recursively
The difference from to_device() is changing to torch.Tensor if float or int
value is found.
The restriction to the returned value in DataParallel:
The object must be
- torch.cuda.Tensor
- 1 or more dimension. 0-dimension-tensor sends warning.
or a list, tuple, dict.
"""
if isinstance(data, dict):
return {k: force_gatherable(v, device) for k, v in data.items()}
# DataParallel can't handle NamedTuple well
elif isinstance(data, tuple) and type(data) is not tuple:
return type(data)(*[force_gatherable(o, device) for o in data])
elif isinstance(data, (list, tuple, set)):
return type(data)(force_gatherable(v, device) for v in data)
elif isinstance(data, np.ndarray):
return force_gatherable(torch.from_numpy(data), device)
elif isinstance(data, torch.Tensor):
if data.dim() == 0:
# To 1-dim array
data = data[None]
return data.to(device)
elif isinstance(data, float):
return torch.tensor([data], dtype=torch.float, device=device)
elif isinstance(data, int):
return torch.tensor([data], dtype=torch.long, device=device)
elif data is None:
return None
else:
warnings.warn(f"{type(data)} may not be gatherable by DataParallel")
return data
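A usage sketch for the helper above; the loss values are illustrative:

# assumes force_gatherable() from above is in scope
stats = {"generator_loss": 1.25, "discriminator_loss": 0.75, "batch_size": 16}
gathered = force_gatherable(stats, device="cpu")
# Every float/int leaf becomes a 1-dim tensor on the target device, e.g.
# gathered["generator_loss"] == tensor([1.2500])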
# The following code is based on https://github.com/jaywalnut310/vits
# Regular expression matching whitespace:
_whitespace_re = re.compile(r'\s+')
# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
('mrs', 'misess'),
('mr', 'mister'),
('dr', 'doctor'),
('st', 'saint'),
('co', 'company'),
('jr', 'junior'),
('maj', 'major'),
('gen', 'general'),
('drs', 'doctors'),
('rev', 'reverend'),
('lt', 'lieutenant'),
('hon', 'honorable'),
('sgt', 'sergeant'),
('capt', 'captain'),
('esq', 'esquire'),
('ltd', 'limited'),
('col', 'colonel'),
('ft', 'fort'),
]]
def expand_abbreviations(text):
for regex, replacement in _abbreviations:
text = re.sub(regex, replacement, text)
return text
def lowercase(text):
return text.lower()
def collapse_whitespace(text):
return re.sub(_whitespace_re, ' ', text)
def convert_to_ascii(text):
return unidecode(text)
def text_clean(text):
'''Cleaning pipeline for English text: ASCII conversion, lowercasing, and abbreviation expansion, followed by phonemization with punctuation and stress preserved.
Returns:
A string of phonemes.
'''
text = convert_to_ascii(text)
text = lowercase(text)
text = expand_abbreviations(text)
phonemes = phonemize(
text,
language='en-us',
backend='espeak',
strip=True,
preserve_punctuation=True,
with_stress=True,
)
phonemes = collapse_whitespace(phonemes)
return phonemes
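A usage sketch of the cleaning steps above. The phoneme string produced by text_clean() depends on the installed espeak backend, so only the pre-phonemization stages are asserted here; the sentence is made up:

# assumes the helpers above are in scope
text = "Dr. Brown lives on St. John Ave."
text = convert_to_ascii(text)
text = lowercase(text)
text = expand_abbreviations(text)
assert text == "doctor brown lives on saint john ave."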
# Mappings from symbol to numeric ID and vice versa:
symbol_to_id = {s: i for i, s in enumerate(symbol_table)}
id_to_symbol = {i: s for i, s in enumerate(symbol_table)}
# def text_to_sequence(text: str) -> List[int]:
# '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
# '''
# cleaned_text = text_clean(text)
# sequence = [symbol_to_id[symbol] for symbol in cleaned_text]
# return sequence
#
#
# def sequence_to_text(sequence: List[int]) -> str:
# '''Converts a sequence of IDs back to a string'''
# result = ''.join(id_to_symbol[symbol_id] for symbol_id in sequence)
# return result
# from https://github.com/jaywalnut310/vits/blob/main/commons.py
def intersperse(sequence, item=0):
result = [item] * (len(sequence) * 2 + 1)
result[1::2] = sequence
return result
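Interspersing surrounds every token with the blank id (0 by default), giving the VITS alignment search room between phonemes. For example:

# assumes intersperse() from above is in scope
tokens = [5, 3, 7]
assert intersperse(tokens) == [0, 5, 0, 3, 0, 7, 0]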
def prepare_token_batch(
texts: List[str],
phonemes: Optional[List[str]] = None,
intersperse_blank: bool = True,
blank_id: int = 0,
pad_id: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, bool]:
"""Convert a list of text strings into a batch of padded symbol-token sequences.
Args:
texts: list of text strings
phonemes: optional list of pre-computed phoneme strings; if None, the texts are normalized and phonemized here
intersperse_blank: whether to intersperse blank tokens in the converted token sequences
blank_id: index of the blank token
pad_id: padding index
"""
if phonemes is None:
# normalize text
normalized_texts = []
for text in texts:
text = convert_to_ascii(text)
text = lowercase(text)
text = expand_abbreviations(text)
normalized_texts.append(text)
# from https://github.com/jaywalnut310/vits/blob/main/utils.py
MATPLOTLIB_FLAG = False
# convert to phonemes
phonemes = phonemize(
normalized_texts,
language='en-us',
backend='espeak',
strip=True,
preserve_punctuation=True,
with_stress=True,
)
phonemes = [collapse_whitespace(sequence) for sequence in phonemes]
# convert to symbol ids
lengths = []
sequences = []
skip = False
for idx, sequence in enumerate(phonemes):
try:
sequence = [symbol_to_id[symbol] for symbol in sequence]
except Exception:
# print(texts[idx])
# print(normalized_texts[idx])
print(phonemes[idx])
skip = True
if intersperse_blank:
sequence = intersperse(sequence, blank_id)
try:
sequences.append(torch.tensor(sequence, dtype=torch.int64))
except Exception:
print(sequence)
skip = True
lengths.append(len(sequence))
def plot_feature(spectrogram):
global MATPLOTLIB_FLAG
if not MATPLOTLIB_FLAG:
import matplotlib
matplotlib.use("Agg")
MATPLOTLIB_FLAG = True
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(logging.WARNING)
import matplotlib.pylab as plt
import numpy as np
sequences = pad_sequence(sequences, batch_first=True, padding_value=pad_id)
lengths = torch.tensor(lengths, dtype=torch.int64)
return sequences, lengths, skip
fig, ax = plt.subplots(figsize=(10, 2))
im = ax.imshow(spectrogram, aspect="auto", origin="lower",
interpolation='none')
plt.colorbar(im, ax=ax)
plt.xlabel("Frames")
plt.ylabel("Channels")
plt.tight_layout()
fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close()
return data
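plot_feature() returns an HWC uint8 image array, matching the dataformats='HWC' calls in the training loop above. A usage sketch; the log dir and spectrogram are illustrative:

import numpy as np
from torch.utils.tensorboard import SummaryWriter

# assumes plot_feature() from above is in scope
writer = SummaryWriter("exp/tb")  # hypothetical log dir
mel = np.random.rand(80, 200)     # hypothetical (channels, frames) spectrogram
writer.add_image("train/mel_", plot_feature(mel), global_step=0, dataformats="HWC")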
class MetricsTracker(collections.defaultdict):
@ -413,106 +263,3 @@ def save_checkpoint(
checkpoint[k] = v
torch.save(checkpoint, filename)
def save_checkpoint_with_global_batch_idx(
out_dir: Path,
global_batch_idx: int,
model: Union[nn.Module, DDP],
params: Optional[Dict[str, Any]] = None,
optimizer_g: Optional[Optimizer] = None,
optimizer_d: Optional[Optimizer] = None,
scheduler_g: Optional[LRSchedulerType] = None,
scheduler_d: Optional[LRSchedulerType] = None,
scaler: Optional[GradScaler] = None,
sampler: Optional[CutSampler] = None,
rank: int = 0,
):
"""Save training info after processing given number of batches.
Args:
out_dir:
The directory to save the checkpoint.
global_batch_idx:
The number of batches processed so far from the very start of the
training. The saved checkpoint will have the following filename:
f'out_dir / checkpoint-{global_batch_idx}.pt'
model:
The neural network model whose `state_dict` will be saved in the
checkpoint.
params:
A dict of training configurations to be saved.
optimizer_g:
The optimizer for generator used in the training.
Its `state_dict` will be saved.
optimizer_d:
The optimizer for discriminator used in the training.
Its `state_dict` will be saved.
scheduler_g:
The learning rate scheduler for generator used in the training.
Its `state_dict` will be saved.
scheduler_d:
The learning rate scheduler for discriminator used in the training.
Its `state_dict` will be saved.
scaler:
The scaler used for mixed precision training. Its `state_dict` will
be saved.
sampler:
The sampler used in the training dataset.
rank:
The rank ID used in DDP training of the current node. Set it to 0
if DDP is not used.
"""
out_dir = Path(out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
filename = out_dir / f"checkpoint-{global_batch_idx}.pt"
save_checkpoint(
filename=filename,
model=model,
params=params,
optimizer_g=optimizer_g,
optimizer_d=optimizer_d,
scheduler_g=scheduler_g,
scheduler_d=scheduler_d,
scaler=scaler,
sampler=sampler,
rank=rank,
)
# def plot_feature(feature):
# """
# Display the feature matrix as an image. Requires matplotlib to be installed.
# """
# import matplotlib.pyplot as plt
#
# feature = np.flip(feature.transpose(1, 0), 0)
# return plt.matshow(feature)
MATPLOTLIB_FLAG = False
def plot_feature(spectrogram):
global MATPLOTLIB_FLAG
if not MATPLOTLIB_FLAG:
import matplotlib
matplotlib.use("Agg")
MATPLOTLIB_FLAG = True
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(logging.WARNING)
import matplotlib.pylab as plt
import numpy as np
fig, ax = plt.subplots(figsize=(10, 2))
im = ax.imshow(spectrogram, aspect="auto", origin="lower",
interpolation='none')
plt.colorbar(im, ax=ax)
plt.xlabel("Frames")
plt.ylabel("Channels")
plt.tight_layout()
fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close()
return data

View File

@ -1,11 +1,11 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/vits.py
# based on https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/vits.py
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""VITS module for GAN-TTS task."""
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Tuple
import torch
import torch.nn as nn
@ -247,7 +247,7 @@ class VITS(nn.Module):
spembs: Optional[torch.Tensor] = None,
lids: Optional[torch.Tensor] = None,
forward_generator: bool = True,
) -> Dict[str, Any]:
) -> Tuple[torch.Tensor, Dict[str, Any]]:
"""Perform generator forward.
Args:
@ -263,12 +263,8 @@ class VITS(nn.Module):
forward_generator (bool): Whether to forward generator.
Returns:
Tuple[Tensor, Dict[str, Any]]:
- loss (Tensor): Loss scalar tensor.
- stats (Dict[str, float]): Statistics to be monitored.
- weight (Tensor): Weight tensor to summarize losses.
- optim_idx (int): Optimizer index (0 for G and 1 for D).
- loss (Tensor): Loss scalar tensor.
- stats (Dict[str, float]): Statistics to be monitored.
"""
if forward_generator:
return self._forward_generator(
@ -308,7 +304,7 @@ class VITS(nn.Module):
sids: Optional[torch.Tensor] = None,
spembs: Optional[torch.Tensor] = None,
lids: Optional[torch.Tensor] = None,
) -> Dict[str, Any]:
) -> Tuple[torch.Tensor, Dict[str, Any]]:
"""Perform generator forward.
Args:
@ -323,12 +319,8 @@ class VITS(nn.Module):
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
Returns:
Tuple[Tensor, Dict[str, Any]]:
* loss (Tensor): Loss scalar tensor.
* stats (Dict[str, float]): Statistics to be monitored.
* weight (Tensor): Weight tensor to summarize losses.
* optim_idx (int): Optimizer index (0 for G and 1 for D).
* loss (Tensor): Loss scalar tensor.
* stats (Dict[str, float]): Statistics to be monitored.
"""
# setup
feats = feats.transpose(1, 2)
@ -399,7 +391,7 @@ class VITS(nn.Module):
)
if return_sample:
stats["return_sample"] = (
stats["returned_sample"] = (
speech_hat_[0].data.cpu().numpy(),
speech_[0].data.cpu().numpy(),
mel_hat_[0].data.cpu().numpy(),
@ -423,7 +415,7 @@ class VITS(nn.Module):
sids: Optional[torch.Tensor] = None,
spembs: Optional[torch.Tensor] = None,
lids: Optional[torch.Tensor] = None,
) -> Dict[str, Any]:
) -> Tuple[torch.Tensor, Dict[str, Any]]:
"""Perform discriminator forward.
Args:
@ -438,12 +430,8 @@ class VITS(nn.Module):
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
Returns:
Tuple[Tensor, Dict[str, Any]]:
* loss (Tensor): Loss scalar tensor.
* stats (Dict[str, float]): Statistics to be monitored.
* weight (Tensor): Weight tensor to summarize losses.
* optim_idx (int): Optimizer index (0 for G and 1 for D).
* loss (Tensor): Loss scalar tensor.
* stats (Dict[str, float]): Statistics to be monitored.
"""
# setup
feats = feats.transpose(1, 2)
@ -511,8 +499,8 @@ class VITS(nn.Module):
alpha: float = 1.0,
max_len: Optional[int] = None,
use_teacher_forcing: bool = False,
) -> Dict[str, torch.Tensor]:
"""Run inference.
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Run inference for single sample.
Args:
text (Tensor): Input text index tensor (T_text,).
@ -528,11 +516,9 @@ class VITS(nn.Module):
use_teacher_forcing (bool): Whether to use teacher forcing.
Returns:
Tuple[Tensor, Tensor, Tensor]:
* wav (Tensor): Generated waveform tensor (T_wav,).
* att_w (Tensor): Monotonic attention weight tensor (T_feats, T_text).
* duration (Tensor): Predicted duration tensor (T_text,).
* wav (Tensor): Generated waveform tensor (T_wav,).
* att_w (Tensor): Monotonic attention weight tensor (T_feats, T_text).
* duration (Tensor): Predicted duration tensor (T_text,).
"""
# setup
text = text[None]
@ -593,8 +579,8 @@ class VITS(nn.Module):
alpha: float = 1.0,
max_len: Optional[int] = None,
use_teacher_forcing: bool = False,
) -> Dict[str, torch.Tensor]:
"""Run inference.
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Run inference for one batch.
Args:
text (Tensor): Input text index tensor (B, T_text).
@ -605,11 +591,9 @@ class VITS(nn.Module):
max_len (Optional[int]): Maximum length.
Returns:
Tuple[Tensor, Tensor, Tensor]:
* wav (Tensor): Generated waveform tensor (B, T_wav).
* att_w (Tensor): Monotonic attention weight tensor (B, T_feats, T_text).
* duration (Tensor): Predicted duration tensor (B, T_text).
* wav (Tensor): Generated waveform tensor (B, T_wav).
* att_w (Tensor): Monotonic attention weight tensor (B, T_feats, T_text).
* duration (Tensor): Predicted duration tensor (B, T_text).
"""
# inference
wav, att_w, dur = self.generator.inference(

View File

@ -1,4 +1,4 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/wavenet/wavenet.py
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/wavenet/wavenet.py
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)