minor updates

This commit is contained in:
jinzr 2023-11-29 16:07:44 +08:00
parent 431048a1c7
commit 617721dfc3
19 changed files with 588 additions and 84 deletions

View File

@ -17,19 +17,15 @@
"""
This file reads the texts in given manifest and generate the file that maps tokens to IDs.
This file reads the texts in the given manifest and generates the file that maps tokens to IDs.
"""
import argparse
import logging
from collections import Counter
from pathlib import Path
from typing import Dict
import g2p_en
import tacotron_cleaner.cleaners
from lhotse import load_manifest
from tqdm import tqdm
def get_args():
@ -74,35 +70,24 @@ def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
def get_token2id(manifest_file: Path) -> Dict[str, int]:
"""Return a dict that maps token to IDs."""
extra_tokens = {
"<blk>": 0, # blank
"<sos/eos>": 1, # sos and eos symbols.
"<unk>": 2, # OOV
}
cut_set = load_manifest(manifest_file)
g2p = g2p_en.G2p()
counter = Counter()
extra_tokens = [
"<blk>", # 0 for blank
"<sos/eos>", # 1 for sos and eos symbols.
"<unk>", # 2 for OOV
]
all_tokens = set()
for cut in tqdm(cut_set):
cut_set = load_manifest(manifest_file)
for cut in cut_set:
# Each cut only contains one supervision
assert len(cut.supervisions) == 1, len(cut.supervisions)
text = cut.supervisions[0].text
# Text normalization
text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
# Convert to phonemes
tokens = g2p(text)
for t in tokens:
counter[t] += 1
for t in cut.tokens:
all_tokens.add(t)
# Sort by the number of occurrences in descending order
tokens_and_counts = sorted(counter.items(), key=lambda x: -x[1])
all_tokens = extra_tokens + list(all_tokens)
for token, idx in extra_tokens.items():
tokens_and_counts.insert(idx, (token, None))
token2id: Dict[str, int] = {
token: i for i, (token, count) in enumerate(tokens_and_counts)
}
token2id: Dict[str, int] = {token: i for i, token in enumerate(all_tokens)}
return token2id
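
For reference, a minimal usage sketch of the updated get_token2id (illustrative only, not part of the diff; it assumes write_mapping from this script writes one "token id" pair per line, and uses the manifest/token paths from prepare.sh below):

# Illustrative only: build and save the token table.
token2id = get_token2id(Path("data/spectrogram/vctk_cuts_train.jsonl.gz"))
# The first three entries come from extra_tokens:
# {"<blk>": 0, "<sos/eos>": 1, "<unk>": 2, ...}
write_mapping("data/tokens.txt", token2id)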

View File

@ -0,0 +1,60 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao,
# Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file reads the texts in the given manifest and saves the new cuts with phoneme tokens.
"""
import logging
from pathlib import Path
import g2p_en
import tacotron_cleaner.cleaners
from lhotse import CutSet, load_manifest
def prepare_tokens_vctk():
output_dir = Path("data/spectrogram")
prefix = "vctk"
suffix = "jsonl.gz"
partition = "all"
cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
g2p = g2p_en.G2p()
new_cuts = []
for cut in cut_set:
# Each cut only contains one supervision
assert len(cut.supervisions) == 1, len(cut.supervisions)
text = cut.supervisions[0].normalized_text
# Text normalization
text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
# Convert to phonemes
cut.tokens = g2p(text)
new_cuts.append(cut)
new_cut_set = CutSet.from_cuts(new_cuts)
new_cut_set.to_file(output_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
prepare_tokens_vctk()
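
As a sanity check, a small sketch of the phoneme conversion performed per cut above (assuming the usual g2p_en behaviour; the example output is illustrative):

# Illustrative only: normalized text -> phoneme tokens, as done for each cut above.
import g2p_en
import tacotron_cleaner.cleaners

g2p = g2p_en.G2p()
text = tacotron_cleaner.cleaners.custom_english_cleaners("Hello world.")
tokens = g2p(text)
# tokens is a list of ARPAbet symbols plus spaces/punctuation,
# e.g. ['HH', 'AH0', 'L', 'OW1', ' ', 'W', 'ER1', 'L', 'D', '.']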

View File

@ -66,7 +66,17 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Split the VCTK cuts into train, valid and test sets"
log "Stage 3: Prepare phoneme tokens for VCTK"
if [ ! -e data/spectrogram/.vctk_with_token.done ]; then
./local/prepare_tokens_vctk.py
mv data/spectrogram/vctk_cuts_with_tokens_all.jsonl.gz \
data/spectrogram/vctk_cuts_all.jsonl.gz
touch data/spectrogram/.vctk_with_token.done
fi
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Split the VCTK cuts into train, valid and test sets"
if [ ! -e data/spectrogram/.vctk_split.done ]; then
lhotse subset --last 600 \
data/spectrogram/vctk_cuts_all.jsonl.gz \
@ -88,8 +98,12 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
fi
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Generate token file"
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Generate token file"
# We assume you have installed g2p_en and espnet_tts_frontend.
# If not, please install them with:
# - g2p_en: `pip install g2p_en`, refer to https://github.com/Kyubyong/g2p
# - espnet_tts_frontend: `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
if [ ! -e data/tokens.txt ]; then
./local/prepare_token_file.py \
--manifest-file data/spectrogram/vctk_cuts_train.jsonl.gz \

View File

@ -14,6 +14,7 @@ from typing import Optional
import torch
import torch.nn.functional as F
from flow import (
ConvFlow,
DilatedDepthSeparableConv,

267
egs/vctk/TTS/vits/export-onnx.py Executable file
View File

@ -0,0 +1,267 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script exports a VITS model from PyTorch to ONNX.
Export the model to ONNX:
./vits/export-onnx.py \
--epoch 1000 \
--exp-dir vits/exp \
--tokens data/tokens.txt
It will generate two files inside vits/exp:
- vits-epoch-1000.onnx
- vits-epoch-1000.int8.onnx (quantized model)
See ./test_onnx.py for how to use the exported ONNX models.
"""
import argparse
import logging
from pathlib import Path
from typing import Dict, Tuple
import onnx
import torch
import torch.nn as nn
from onnxruntime.quantization import QuantType, quantize_dynamic
from tokenizer import Tokenizer
from train import get_model, get_params
from icefall.checkpoint import load_checkpoint
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=1000,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="vits/exp",
help="The experiment dir",
)
parser.add_argument(
"--tokens",
type=str,
default="data/tokens.txt",
help="""Path to vocabulary.""",
)
return parser
def add_meta_data(filename: str, meta_data: Dict[str, str]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = value
onnx.save(model, filename)
class OnnxModel(nn.Module):
"""A wrapper for VITS generator."""
def __init__(self, model: nn.Module):
"""
Args:
model:
A VITS generator.
"""
super().__init__()
self.model = model
def forward(
self,
tokens: torch.Tensor,
tokens_lens: torch.Tensor,
noise_scale: float = 0.667,
noise_scale_dur: float = 0.8,
alpha: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Please see the help information of VITS.inference_batch
Args:
tokens:
Input text token indexes (1, T_text)
tokens_lens:
Number of tokens of shape (1,)
noise_scale (float):
Noise scale parameter for flow.
noise_scale_dur (float):
Noise scale parameter for duration predictor.
alpha (float):
Alpha parameter to control the speed of generated speech.
Returns:
Return a tuple containing:
- audio, generated waveform tensor, (B, T_wav)
"""
audio, _, _ = self.model.inference(
text=tokens,
text_lengths=tokens_lens,
noise_scale=noise_scale,
noise_scale_dur=noise_scale_dur,
alpha=alpha,
)
return audio
def export_model_onnx(
model: nn.Module,
model_filename: str,
opset_version: int = 11,
) -> None:
"""Export the given generator model to ONNX format.
The exported model has five inputs:
- tokens, a tensor of shape (1, T_text); dtype is torch.int64
- tokens_lens, a tensor of shape (1,); dtype is torch.int64
- noise_scale, noise_scale_dur and alpha, scalar float32 tensors
and it has one output:
- audio, a tensor of shape (1, T'); dtype is torch.float32
Args:
model:
The VITS generator.
model_filename:
The filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
tokens = torch.randint(low=0, high=79, size=(1, 13), dtype=torch.int64)
tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)
noise_scale = torch.tensor([1], dtype=torch.float32)
noise_scale_dur = torch.tensor([1], dtype=torch.float32)
alpha = torch.tensor([1], dtype=torch.float32)
torch.onnx.export(
model,
(tokens, tokens_lens, noise_scale, noise_scale_dur, alpha),
model_filename,
verbose=False,
opset_version=opset_version,
input_names=[
"tokens",
"tokens_lens",
"noise_scale",
"noise_scale_dur",
"alpha",
],
output_names=["audio"],
dynamic_axes={
"tokens": {0: "N", 1: "T"},
"tokens_lens": {0: "N"},
"audio": {0: "N", 1: "T"},
},
)
meta_data = {
"model_type": "VITS",
"version": "1",
"model_author": "k2-fsa",
"comment": "VITS generator",
}
logging.info(f"meta_data: {meta_data}")
add_meta_data(filename=model_filename, meta_data=meta_data)
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
tokenizer = Tokenizer(params.tokens)
params.blank_id = tokenizer.blank_id
params.oov_id = tokenizer.oov_id
params.vocab_size = tokenizer.vocab_size
logging.info(params)
logging.info("About to create model")
model = get_model(params)
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
model = model.generator
model.to("cpu")
model.eval()
model = OnnxModel(model=model)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"generator parameters: {num_param}")
suffix = f"epoch-{params.epoch}"
opset_version = 13
logging.info("Exporting encoder")
model_filename = params.exp_dir / f"vits-{suffix}.onnx"
export_model_onnx(
model,
model_filename,
opset_version=opset_version,
)
logging.info(f"Exported generator to {model_filename}")
# Generate int8 quantization models
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
logging.info("Generate int8 quantization models")
model_filename_int8 = params.exp_dir / f"vits-{suffix}.int8.onnx"
quantize_dynamic(
model_input=model_filename,
model_output=model_filename_int8,
weight_type=QuantType.QUInt8,
)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
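
After export, the generated models can be inspected quickly with onnxruntime; a short sketch (file names follow the defaults above):

# Sketch: verify the exported model's metadata and I/O signature.
import onnxruntime as ort

sess = ort.InferenceSession(
    "vits/exp/vits-epoch-1000.onnx",
    providers=["CPUExecutionProvider"],
)
print(sess.get_modelmeta().custom_metadata_map)
# e.g. {'model_type': 'VITS', 'version': '1', 'model_author': 'k2-fsa', 'comment': 'VITS generator'}
print([i.name for i in sess.get_inputs()])   # ['tokens', 'tokens_lens', 'noise_scale', 'noise_scale_dur', 'alpha']
print([o.name for o in sess.get_outputs()])  # ['audio']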

View File

@ -13,6 +13,7 @@ import math
from typing import Optional, Tuple, Union
import torch
from transform import piecewise_rational_quadratic_transform

View File

@ -16,6 +16,9 @@ from typing import List, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from icefall.utils import make_pad_mask
from duration_predictor import StochasticDurationPredictor
from hifigan import HiFiGANGenerator
from posterior_encoder import PosteriorEncoder
@ -23,8 +26,6 @@ from residual_coupling import ResidualAffineCouplingBlock
from text_encoder import TextEncoder
from utils import get_random_segments
from icefall.utils import make_pad_mask
class VITSGenerator(torch.nn.Module):
"""Generator module in VITS, `Conditional Variational Autoencoder
@ -402,6 +403,7 @@ class VITSGenerator(torch.nn.Module):
"""
# encoder
x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths)
x_mask = x_mask.to(x.dtype)
g = None
if self.spks is not None:
# (B, global_channels, 1)
@ -479,6 +481,7 @@ class VITSGenerator(torch.nn.Module):
dur = torch.ceil(w)
y_lengths = torch.clamp_min(torch.sum(dur, [1, 2]), 1).long()
y_mask = (~make_pad_mask(y_lengths)).unsqueeze(1).to(text.device)
y_mask = y_mask.to(x.dtype)
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
attn = self._generate_path(dur, attn_mask)

View File

@ -38,7 +38,7 @@ import torch.nn as nn
import torchaudio
from tokenizer import Tokenizer
from train import get_model, get_params
from tts_datamodule import LJSpeechTtsDataModule
from tts_datamodule import VctkTtsDataModule
from icefall.checkpoint import load_checkpoint
from icefall.utils import AttributeDict, setup_logger
@ -94,6 +94,7 @@ def infer_dataset(
tokenizer:
Used to convert text to phonemes.
"""
# Background worker that saves audio files to disk.
def _save_worker(
batch_size: int,
@ -127,10 +128,10 @@ def infer_dataset(
futures = []
with ThreadPoolExecutor(max_workers=1) as executor:
for batch_idx, batch in enumerate(dl):
batch_size = len(batch["text"])
batch_size = len(batch["tokens"])
text = batch["text"]
tokens = tokenizer.texts_to_token_ids(text)
tokens = batch["tokens"]
tokens = tokenizer.tokens_to_token_ids(tokens)
tokens = k2.RaggedTensor(tokens)
row_splits = tokens.shape.row_splits(1)
tokens_lens = row_splits[1:] - row_splits[:-1]
@ -180,7 +181,7 @@ def infer_dataset(
@torch.no_grad()
def main():
parser = get_parser()
LJSpeechTtsDataModule.add_arguments(parser)
VctkTtsDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
@ -224,7 +225,7 @@ def main():
# we need cut ids to display recognition results.
args.return_cuts = True
ljspeech = LJSpeechTtsDataModule(args)
ljspeech = VctkTtsDataModule(args)
test_cuts = ljspeech.test_cuts()
test_dl = ljspeech.test_dataloaders(test_cuts)
@ -236,6 +237,7 @@ def main():
tokenizer=tokenizer,
)
logging.info(f"Wav files are saved to {params.save_wav_dir}")
logging.info("Done!")

View File

@ -14,6 +14,7 @@ from typing import List, Tuple, Union
import torch
import torch.distributions as D
import torch.nn.functional as F
from lhotse.features.kaldi import Wav2LogFilterBank

View File

@ -12,9 +12,9 @@ This code is based on https://github.com/jaywalnut310/vits.
from typing import Optional, Tuple
import torch
from wavenet import Conv1d, WaveNet
from icefall.utils import make_pad_mask
from wavenet import WaveNet, Conv1d
class PosteriorEncoder(torch.nn.Module):

View File

@ -12,6 +12,7 @@ This code is based on https://github.com/jaywalnut310/vits.
from typing import Optional, Tuple, Union
import torch
from flow import FlipFlow
from wavenet import WaveNet

123
egs/vctk/TTS/vits/test_onnx.py Executable file
View File

@ -0,0 +1,123 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script tests the ONNX model exported by vits/export-onnx.py.
Use the onnx model to generate a wav:
./vits/test_onnx.py \
--model-filename vits/exp/vits-epoch-1000.onnx \
--tokens data/tokens.txt
"""
import argparse
import logging
import onnxruntime as ort
import torch
import torchaudio
from tokenizer import Tokenizer
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--model-filename",
type=str,
required=True,
help="Path to the onnx model.",
)
parser.add_argument(
"--tokens",
type=str,
default="data/tokens.txt",
help="""Path to vocabulary.""",
)
return parser
class OnnxModel:
def __init__(self, model_filename: str):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 4
self.session_opts = session_opts
self.model = ort.InferenceSession(
model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
logging.info(f"{self.model.get_modelmeta().custom_metadata_map}")
def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Tensor:
"""
Args:
tokens:
A 2-D int64 tensor of shape (1, T)
tokens_lens:
A 1-D int64 tensor of shape (1,) giving the number of tokens
Returns:
A tensor of shape (1, T')
"""
noise_scale = torch.tensor([0.667], dtype=torch.float32)
noise_scale_dur = torch.tensor([0.8], dtype=torch.float32)
alpha = torch.tensor([1.0], dtype=torch.float32)
out = self.model.run(
[
self.model.get_outputs()[0].name,
],
{
self.model.get_inputs()[0].name: tokens.numpy(),
self.model.get_inputs()[1].name: tokens_lens.numpy(),
self.model.get_inputs()[2].name: noise_scale.numpy(),
self.model.get_inputs()[3].name: noise_scale_dur.numpy(),
self.model.get_inputs()[4].name: alpha.numpy(),
},
)[0]
return torch.from_numpy(out)
def main():
args = get_parser().parse_args()
tokenizer = Tokenizer(args.tokens)
logging.info("About to create onnx model")
model = OnnxModel(args.model_filename)
text = "I went there to see the land, the people and how their system works, end quote."
tokens = tokenizer.texts_to_token_ids([text])
tokens = torch.tensor(tokens) # (1, T)
tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1,)
audio = model(tokens, tokens_lens) # (1, T')
torchaudio.save(str("test_onnx.wav"), audio, sample_rate=22050)
logging.info("Saved to test_onnx.wav")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -30,7 +30,7 @@ from typing import Optional, Tuple
import torch
from torch import Tensor, nn
from icefall.utils import make_pad_mask
from icefall.utils import is_jit_tracing, make_pad_mask
class TextEncoder(torch.nn.Module):
@ -442,8 +442,20 @@ class RelPositionMultiheadAttention(nn.Module):
"""
(batch_size, num_heads, seq_len, n) = x.shape
if not is_jit_tracing():
assert n == 2 * seq_len - 1, f"{n} == 2 * {seq_len} - 1"
if is_jit_tracing():
rows = torch.arange(start=seq_len - 1, end=-1, step=-1)
cols = torch.arange(seq_len)
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
indexes = rows + cols
x = x.reshape(-1, n)
x = torch.gather(x, dim=1, index=indexes)
x = x.reshape(batch_size, num_heads, seq_len, seq_len)
return x
else:
# Note: TorchScript requires explicit arg for stride()
batch_stride = x.stride(0)
head_stride = x.stride(1)

View File

@ -77,3 +77,32 @@ class Tokenizer(object):
token_ids_list.append(token_ids)
return token_ids_list
def tokens_to_token_ids(
self, tokens_list: List[str], intersperse_blank: bool = True
):
"""
Args:
tokens_list:
A list of token list, each corresponding to one utterance.
intersperse_blank:
Whether to intersperse blanks in the token sequence.
Returns:
Return a list of token-id lists [utterance][token_id]
"""
token_ids_list = []
for tokens in tokens_list:
token_ids = []
for t in tokens:
if t in self.token2id:
token_ids.append(self.token2id[t])
else:
token_ids.append(self.oov_id)
if intersperse_blank:
token_ids = intersperse(token_ids, self.blank_id)
token_ids_list.append(token_ids)
return token_ids_list
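
A hedged usage sketch of the new method (the phoneme list is illustrative; the exact blank positions depend on the intersperse helper used by this class):

# Illustrative only: map pre-computed phoneme tokens to ids.
tokenizer = Tokenizer("data/tokens.txt")
token_ids = tokenizer.tokens_to_token_ids([["HH", "AH0", "L", "OW1"]])
# With intersperse_blank=True, each list is interleaved with tokenizer.blank_id,
# e.g. [blank, id(HH), blank, id(AH0), blank, id(L), blank, id(OW1), blank].
# Unknown tokens are mapped to tokenizer.oov_id.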

View File

@ -34,7 +34,7 @@ from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter
from tts_datamodule import LJSpeechTtsDataModule
from tts_datamodule import VctkTtsDataModule
from utils import MetricsTracker, plot_feature, save_checkpoint
from vits import VITS
@ -294,10 +294,9 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
features = batch["features"].to(device)
audio_lens = batch["audio_lens"].to(device)
features_lens = batch["features_lens"].to(device)
text = batch["text"]
speakers = batch["speakers"]
tokens = batch["tokens"]
tokens = tokenizer.texts_to_token_ids(text)
tokens = tokenizer.tokens_to_token_ids(tokens)
tokens = k2.RaggedTensor(tokens)
row_splits = tokens.shape.row_splits(1)
tokens_lens = row_splits[1:] - row_splits[:-1]
@ -306,7 +305,7 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
# a tensor of shape (B, T)
tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
return audio, audio_lens, features, features_lens, tokens, tokens_lens, speakers
return audio, audio_lens, features, features_lens, tokens, tokens_lens
def train_one_epoch(
@ -384,16 +383,10 @@ def train_one_epoch(
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["text"])
(
audio,
audio_lens,
features,
features_lens,
tokens,
tokens_lens,
speakers,
) = prepare_input(batch, tokenizer, device)
batch_size = len(batch["tokens"])
audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
batch, tokenizer, device
)
loss_info = MetricsTracker()
loss_info["samples"] = batch_size
@ -582,7 +575,7 @@ def compute_validation_loss(
with torch.no_grad():
for batch_idx, batch in enumerate(valid_dl):
batch_size = len(batch["text"])
batch_size = len(batch["tokens"])
(
audio,
audio_lens,
@ -810,7 +803,7 @@ def run(rank, world_size, args):
if params.inf_check:
register_inf_check_hooks(model)
ljspeech = LJSpeechTtsDataModule(args)
ljspeech = VctkTtsDataModule(args)
train_cuts = ljspeech.train_cuts()
@ -914,7 +907,7 @@ def run(rank, world_size, args):
def main():
parser = get_parser()
LJSpeechTtsDataModule.add_arguments(parser)
VctkTtsDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)

View File

@ -52,7 +52,7 @@ class _SeedWorkers:
fix_random_seed(self.seed + worker_id)
class LJSpeechTtsDataModule:
class VctkTtsDataModule:
"""
DataModule for tts experiments.
It assumes there is always one train and valid dataloader,
@ -168,7 +168,8 @@ class LJSpeechTtsDataModule:
"""
logging.info("About to create train dataset")
train = SpeechSynthesisDataset(
return_tokens=False,
return_text=False,
return_tokens=True,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
@ -182,7 +183,8 @@ class LJSpeechTtsDataModule:
use_fft_mag=True,
)
train = SpeechSynthesisDataset(
return_tokens=False,
return_text=False,
return_tokens=True,
feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
return_cuts=self.args.return_cuts,
)
@ -236,13 +238,15 @@ class LJSpeechTtsDataModule:
use_fft_mag=True,
)
validate = SpeechSynthesisDataset(
return_tokens=False,
return_text=False,
return_tokens=True,
feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
return_cuts=self.args.return_cuts,
)
else:
validate = SpeechSynthesisDataset(
return_tokens=False,
return_text=False,
return_tokens=True,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
@ -273,13 +277,15 @@ class LJSpeechTtsDataModule:
use_fft_mag=True,
)
test = SpeechSynthesisDataset(
return_tokens=False,
return_text=False,
return_tokens=True,
feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
return_cuts=self.args.return_cuts,
)
else:
test = SpeechSynthesisDataset(
return_tokens=False,
return_text=False,
return_tokens=True,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
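
With return_text=False and return_tokens=True, the batches yielded by these dataloaders carry the pre-computed phoneme tokens instead of raw text; a hedged sketch of how downstream code (see train.py and infer.py above) consumes them:

# Illustrative only: expected batch fields after the switch to tokens.
for batch in test_dl:  # test_dl from VctkTtsDataModule.test_dataloaders(...)
    tokens = batch["tokens"]          # per-utterance phoneme-token lists
    features = batch["features"]      # (B, T_feat, F) spectrogram features
    audio_lens = batch["audio_lens"]  # (B,)
    break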

View File

@ -14,15 +14,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple, Union
import collections
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.distributed as dist
from lhotse.dataset.sampling.base import CutSampler
from pathlib import Path
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer

View File

@ -9,7 +9,8 @@ from typing import Any, Dict, Optional, Tuple
import torch
import torch.nn as nn
from generator import VITSGenerator
from torch.cuda.amp import autocast
from hifigan import (
HiFiGANMultiPeriodDiscriminator,
HiFiGANMultiScaleDiscriminator,
@ -24,8 +25,9 @@ from loss import (
KLDivergenceLoss,
MelSpectrogramLoss,
)
from torch.cuda.amp import autocast
from utils import get_segments
from generator import VITSGenerator
AVAILABLE_GENERATERS = {
"vits_generator": VITSGenerator,
@ -570,6 +572,7 @@ class VITS(nn.Module):
self,
text: torch.Tensor,
text_lengths: torch.Tensor,
sids: Optional[torch.Tensor] = None,
durations: Optional[torch.Tensor] = None,
noise_scale: float = 0.667,
noise_scale_dur: float = 0.8,
@ -582,6 +585,7 @@ class VITS(nn.Module):
Args:
text (Tensor): Input text index tensor (B, T_text).
text_lengths (Tensor): Input text index tensor (B,).
sids (Tensor): Speaker index tensor (B,).
noise_scale (float): Noise scale value for flow.
noise_scale_dur (float): Noise scale value for duration predictor.
alpha (float): Alpha parameter to control the speed of generated speech.
@ -596,6 +600,7 @@ class VITS(nn.Module):
wav, att_w, dur = self.generator.inference(
text=text,
text_lengths=text_lengths,
sids=sids,
noise_scale=noise_scale,
noise_scale_dur=noise_scale_dur,
alpha=alpha,

View File

@ -9,8 +9,9 @@ This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.
"""
import logging
import math
import logging
from typing import Optional, Tuple
import torch