mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 06:34:20 +00:00
minor updates
This commit is contained in:
parent
431048a1c7
commit
617721dfc3
@ -17,19 +17,15 @@
|
|||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
This file reads the texts in given manifest and generate the file that maps tokens to IDs.
|
This file reads the texts in given manifest and generates the file that maps tokens to IDs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
from collections import Counter
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
import g2p_en
|
|
||||||
import tacotron_cleaner.cleaners
|
|
||||||
from lhotse import load_manifest
|
from lhotse import load_manifest
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
@ -74,35 +70,24 @@ def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
|
|||||||
|
|
||||||
def get_token2id(manifest_file: Path) -> Dict[str, int]:
|
def get_token2id(manifest_file: Path) -> Dict[str, int]:
|
||||||
"""Return a dict that maps token to IDs."""
|
"""Return a dict that maps token to IDs."""
|
||||||
extra_tokens = {
|
extra_tokens = [
|
||||||
"<blk>": 0, # blank
|
"<blk>", # 0 for blank
|
||||||
"<sos/eos>": 1, # sos and eos symbols.
|
"<sos/eos>", # 1 for sos and eos symbols.
|
||||||
"<unk>": 2, # OOV
|
"<unk>", # 2 for OOV
|
||||||
}
|
]
|
||||||
cut_set = load_manifest(manifest_file)
|
all_tokens = set()
|
||||||
g2p = g2p_en.G2p()
|
|
||||||
counter = Counter()
|
|
||||||
|
|
||||||
for cut in tqdm(cut_set):
|
cut_set = load_manifest(manifest_file)
|
||||||
|
|
||||||
|
for cut in cut_set:
|
||||||
# Each cut only contain one supervision
|
# Each cut only contain one supervision
|
||||||
assert len(cut.supervisions) == 1, len(cut.supervisions)
|
assert len(cut.supervisions) == 1, len(cut.supervisions)
|
||||||
text = cut.supervisions[0].text
|
for t in cut.tokens:
|
||||||
# Text normalization
|
all_tokens.add(t)
|
||||||
text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
|
|
||||||
# Convert to phonemes
|
|
||||||
tokens = g2p(text)
|
|
||||||
for t in tokens:
|
|
||||||
counter[t] += 1
|
|
||||||
|
|
||||||
# Sort by the number of occurrences in descending order
|
all_tokens = extra_tokens + list(all_tokens)
|
||||||
tokens_and_counts = sorted(counter.items(), key=lambda x: -x[1])
|
|
||||||
|
|
||||||
for token, idx in extra_tokens.items():
|
token2id: Dict[str, int] = {token: i for i, token in enumerate(all_tokens)}
|
||||||
tokens_and_counts.insert(idx, (token, None))
|
|
||||||
|
|
||||||
token2id: Dict[str, int] = {
|
|
||||||
token: i for i, (token, count) in enumerate(tokens_and_counts)
|
|
||||||
}
|
|
||||||
return token2id
|
return token2id
|
||||||
|
|
||||||
|
|
||||||
|
60
egs/vctk/TTS/local/prepare_tokens_vctk.py
Normal file
60
egs/vctk/TTS/local/prepare_tokens_vctk.py
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao,
|
||||||
|
# Zengrui Jin,)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
This file reads the texts in given manifest and save the new cuts with phoneme tokens.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import g2p_en
|
||||||
|
import tacotron_cleaner.cleaners
|
||||||
|
from lhotse import CutSet, load_manifest
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_tokens_vctk():
|
||||||
|
output_dir = Path("data/spectrogram")
|
||||||
|
prefix = "vctk"
|
||||||
|
suffix = "jsonl.gz"
|
||||||
|
partition = "all"
|
||||||
|
|
||||||
|
cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
|
||||||
|
g2p = g2p_en.G2p()
|
||||||
|
|
||||||
|
new_cuts = []
|
||||||
|
for cut in cut_set:
|
||||||
|
# Each cut only contains one supervision
|
||||||
|
assert len(cut.supervisions) == 1, len(cut.supervisions)
|
||||||
|
text = cut.supervisions[0].normalized_text
|
||||||
|
# Text normalization
|
||||||
|
text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
|
||||||
|
# Convert to phonemes
|
||||||
|
cut.tokens = g2p(text)
|
||||||
|
new_cuts.append(cut)
|
||||||
|
|
||||||
|
new_cut_set = CutSet.from_cuts(new_cuts)
|
||||||
|
new_cut_set.to_file(output_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
|
||||||
|
prepare_tokens_vctk()
|
@ -66,7 +66,17 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||||
log "Stage 3: Split the VCTK cuts into train, valid and test sets"
|
log "Stage 3: Prepare phoneme tokens for VCTK"
|
||||||
|
if [ ! -e data/spectrogram/.vctk_with_token.done ]; then
|
||||||
|
./local/prepare_tokens_vctk.py
|
||||||
|
mv data/spectrogram/vctk_cuts_with_tokens_all.jsonl.gz \
|
||||||
|
data/spectrogram/vctk_cuts_all.jsonl.gz
|
||||||
|
touch data/spectrogram/.vctk_with_token.done
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||||
|
log "Stage 4: Split the VCTK cuts into train, valid and test sets"
|
||||||
if [ ! -e data/spectrogram/.vctk_split.done ]; then
|
if [ ! -e data/spectrogram/.vctk_split.done ]; then
|
||||||
lhotse subset --last 600 \
|
lhotse subset --last 600 \
|
||||||
data/spectrogram/vctk_cuts_all.jsonl.gz \
|
data/spectrogram/vctk_cuts_all.jsonl.gz \
|
||||||
@ -88,8 +98,12 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
|||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||||
log "Stage 4: Generate token file"
|
log "Stage 5: Generate token file"
|
||||||
|
# We assume you have installed g2p_en and espnet_tts_frontend.
|
||||||
|
# If not, please install them with:
|
||||||
|
# - g2p_en: `pip install g2p_en`, refer to https://github.com/Kyubyong/g2p
|
||||||
|
# - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
|
||||||
if [ ! -e data/tokens.txt ]; then
|
if [ ! -e data/tokens.txt ]; then
|
||||||
./local/prepare_token_file.py \
|
./local/prepare_token_file.py \
|
||||||
--manifest-file data/spectrogram/vctk_cuts_train.jsonl.gz \
|
--manifest-file data/spectrogram/vctk_cuts_train.jsonl.gz \
|
||||||
|
@ -14,6 +14,7 @@ from typing import Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from flow import (
|
from flow import (
|
||||||
ConvFlow,
|
ConvFlow,
|
||||||
DilatedDepthSeparableConv,
|
DilatedDepthSeparableConv,
|
||||||
|
267
egs/vctk/TTS/vits/export-onnx.py
Executable file
267
egs/vctk/TTS/vits/export-onnx.py
Executable file
@ -0,0 +1,267 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
#
|
||||||
|
# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
This script exports a VITS model from PyTorch to ONNX.
|
||||||
|
|
||||||
|
Export the model to ONNX:
|
||||||
|
./vits/export-onnx.py \
|
||||||
|
--epoch 1000 \
|
||||||
|
--exp-dir vits/exp \
|
||||||
|
--tokens data/tokens.txt
|
||||||
|
|
||||||
|
It will generate two files inside vits/exp:
|
||||||
|
- vits-epoch-1000.onnx
|
||||||
|
- vits-epoch-1000.int8.onnx (quantizated model)
|
||||||
|
|
||||||
|
See ./test_onnx.py for how to use the exported ONNX models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Tuple
|
||||||
|
|
||||||
|
import onnx
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||||
|
from tokenizer import Tokenizer
|
||||||
|
from train import get_model, get_params
|
||||||
|
|
||||||
|
from icefall.checkpoint import load_checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch",
|
||||||
|
type=int,
|
||||||
|
default=1000,
|
||||||
|
help="""It specifies the checkpoint to use for decoding.
|
||||||
|
Note: Epoch counts from 1.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--exp-dir",
|
||||||
|
type=str,
|
||||||
|
default="vits/exp",
|
||||||
|
help="The experiment dir",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--tokens",
|
||||||
|
type=str,
|
||||||
|
default="data/tokens.txt",
|
||||||
|
help="""Path to vocabulary.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def add_meta_data(filename: str, meta_data: Dict[str, str]):
|
||||||
|
"""Add meta data to an ONNX model. It is changed in-place.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename:
|
||||||
|
Filename of the ONNX model to be changed.
|
||||||
|
meta_data:
|
||||||
|
Key-value pairs.
|
||||||
|
"""
|
||||||
|
model = onnx.load(filename)
|
||||||
|
for key, value in meta_data.items():
|
||||||
|
meta = model.metadata_props.add()
|
||||||
|
meta.key = key
|
||||||
|
meta.value = value
|
||||||
|
|
||||||
|
onnx.save(model, filename)
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxModel(nn.Module):
|
||||||
|
"""A wrapper for VITS generator."""
|
||||||
|
|
||||||
|
def __init__(self, model: nn.Module):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
model:
|
||||||
|
A VITS generator.
|
||||||
|
frame_shift:
|
||||||
|
The frame shift in samples.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
tokens: torch.Tensor,
|
||||||
|
tokens_lens: torch.Tensor,
|
||||||
|
noise_scale: float = 0.667,
|
||||||
|
noise_scale_dur: float = 0.8,
|
||||||
|
alpha: float = 1.0,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Please see the help information of VITS.inference_batch
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tokens:
|
||||||
|
Input text token indexes (1, T_text)
|
||||||
|
tokens_lens:
|
||||||
|
Number of tokens of shape (1,)
|
||||||
|
noise_scale (float):
|
||||||
|
Noise scale parameter for flow.
|
||||||
|
noise_scale_dur (float):
|
||||||
|
Noise scale parameter for duration predictor.
|
||||||
|
alpha (float):
|
||||||
|
Alpha parameter to control the speed of generated speech.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Return a tuple containing:
|
||||||
|
- audio, generated wavform tensor, (B, T_wav)
|
||||||
|
"""
|
||||||
|
audio, _, _ = self.model.inference(
|
||||||
|
text=tokens,
|
||||||
|
text_lengths=tokens_lens,
|
||||||
|
noise_scale=noise_scale,
|
||||||
|
noise_scale_dur=noise_scale_dur,
|
||||||
|
alpha=alpha,
|
||||||
|
)
|
||||||
|
return audio
|
||||||
|
|
||||||
|
|
||||||
|
def export_model_onnx(
|
||||||
|
model: nn.Module,
|
||||||
|
model_filename: str,
|
||||||
|
opset_version: int = 11,
|
||||||
|
) -> None:
|
||||||
|
"""Export the given generator model to ONNX format.
|
||||||
|
The exported model has one input:
|
||||||
|
|
||||||
|
- tokens, a tensor of shape (1, T_text); dtype is torch.int64
|
||||||
|
|
||||||
|
and it has one output:
|
||||||
|
|
||||||
|
- audio, a tensor of shape (1, T'); dtype is torch.float32
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model:
|
||||||
|
The VITS generator.
|
||||||
|
model_filename:
|
||||||
|
The filename to save the exported ONNX model.
|
||||||
|
opset_version:
|
||||||
|
The opset version to use.
|
||||||
|
"""
|
||||||
|
tokens = torch.randint(low=0, high=79, size=(1, 13), dtype=torch.int64)
|
||||||
|
tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)
|
||||||
|
noise_scale = torch.tensor([1], dtype=torch.float32)
|
||||||
|
noise_scale_dur = torch.tensor([1], dtype=torch.float32)
|
||||||
|
alpha = torch.tensor([1], dtype=torch.float32)
|
||||||
|
|
||||||
|
torch.onnx.export(
|
||||||
|
model,
|
||||||
|
(tokens, tokens_lens, noise_scale, noise_scale_dur, alpha),
|
||||||
|
model_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=[
|
||||||
|
"tokens",
|
||||||
|
"tokens_lens",
|
||||||
|
"noise_scale",
|
||||||
|
"noise_scale_dur",
|
||||||
|
"alpha",
|
||||||
|
],
|
||||||
|
output_names=["audio"],
|
||||||
|
dynamic_axes={
|
||||||
|
"tokens": {0: "N", 1: "T"},
|
||||||
|
"tokens_lens": {0: "N"},
|
||||||
|
"audio": {0: "N", 1: "T"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
meta_data = {
|
||||||
|
"model_type": "VITS",
|
||||||
|
"version": "1",
|
||||||
|
"model_author": "k2-fsa",
|
||||||
|
"comment": "VITS generator",
|
||||||
|
}
|
||||||
|
logging.info(f"meta_data: {meta_data}")
|
||||||
|
|
||||||
|
add_meta_data(filename=model_filename, meta_data=meta_data)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
args = get_parser().parse_args()
|
||||||
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
params.update(vars(args))
|
||||||
|
|
||||||
|
tokenizer = Tokenizer(params.tokens)
|
||||||
|
params.blank_id = tokenizer.blank_id
|
||||||
|
params.oov_id = tokenizer.oov_id
|
||||||
|
params.vocab_size = tokenizer.vocab_size
|
||||||
|
|
||||||
|
logging.info(params)
|
||||||
|
|
||||||
|
logging.info("About to create model")
|
||||||
|
model = get_model(params)
|
||||||
|
|
||||||
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
|
|
||||||
|
model = model.generator
|
||||||
|
model.to("cpu")
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
model = OnnxModel(model=model)
|
||||||
|
|
||||||
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
|
logging.info(f"generator parameters: {num_param}")
|
||||||
|
|
||||||
|
suffix = f"epoch-{params.epoch}"
|
||||||
|
|
||||||
|
opset_version = 13
|
||||||
|
|
||||||
|
logging.info("Exporting encoder")
|
||||||
|
model_filename = params.exp_dir / f"vits-{suffix}.onnx"
|
||||||
|
export_model_onnx(
|
||||||
|
model,
|
||||||
|
model_filename,
|
||||||
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
logging.info(f"Exported generator to {model_filename}")
|
||||||
|
|
||||||
|
# Generate int8 quantization models
|
||||||
|
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
|
||||||
|
|
||||||
|
logging.info("Generate int8 quantization models")
|
||||||
|
|
||||||
|
model_filename_int8 = params.exp_dir / f"vits-{suffix}.int8.onnx"
|
||||||
|
quantize_dynamic(
|
||||||
|
model_input=model_filename,
|
||||||
|
model_output=model_filename_int8,
|
||||||
|
weight_type=QuantType.QUInt8,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
main()
|
@ -13,6 +13,7 @@ import math
|
|||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transform import piecewise_rational_quadratic_transform
|
from transform import piecewise_rational_quadratic_transform
|
||||||
|
|
||||||
|
|
||||||
|
@ -16,6 +16,9 @@ from typing import List, Optional, Tuple
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from icefall.utils import make_pad_mask
|
||||||
|
|
||||||
from duration_predictor import StochasticDurationPredictor
|
from duration_predictor import StochasticDurationPredictor
|
||||||
from hifigan import HiFiGANGenerator
|
from hifigan import HiFiGANGenerator
|
||||||
from posterior_encoder import PosteriorEncoder
|
from posterior_encoder import PosteriorEncoder
|
||||||
@ -23,8 +26,6 @@ from residual_coupling import ResidualAffineCouplingBlock
|
|||||||
from text_encoder import TextEncoder
|
from text_encoder import TextEncoder
|
||||||
from utils import get_random_segments
|
from utils import get_random_segments
|
||||||
|
|
||||||
from icefall.utils import make_pad_mask
|
|
||||||
|
|
||||||
|
|
||||||
class VITSGenerator(torch.nn.Module):
|
class VITSGenerator(torch.nn.Module):
|
||||||
"""Generator module in VITS, `Conditional Variational Autoencoder
|
"""Generator module in VITS, `Conditional Variational Autoencoder
|
||||||
@ -402,6 +403,7 @@ class VITSGenerator(torch.nn.Module):
|
|||||||
"""
|
"""
|
||||||
# encoder
|
# encoder
|
||||||
x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths)
|
x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths)
|
||||||
|
x_mask = x_mask.to(x.dtype)
|
||||||
g = None
|
g = None
|
||||||
if self.spks is not None:
|
if self.spks is not None:
|
||||||
# (B, global_channels, 1)
|
# (B, global_channels, 1)
|
||||||
@ -479,6 +481,7 @@ class VITSGenerator(torch.nn.Module):
|
|||||||
dur = torch.ceil(w)
|
dur = torch.ceil(w)
|
||||||
y_lengths = torch.clamp_min(torch.sum(dur, [1, 2]), 1).long()
|
y_lengths = torch.clamp_min(torch.sum(dur, [1, 2]), 1).long()
|
||||||
y_mask = (~make_pad_mask(y_lengths)).unsqueeze(1).to(text.device)
|
y_mask = (~make_pad_mask(y_lengths)).unsqueeze(1).to(text.device)
|
||||||
|
y_mask = y_mask.to(x.dtype)
|
||||||
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
|
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
|
||||||
attn = self._generate_path(dur, attn_mask)
|
attn = self._generate_path(dur, attn_mask)
|
||||||
|
|
||||||
|
@ -38,7 +38,7 @@ import torch.nn as nn
|
|||||||
import torchaudio
|
import torchaudio
|
||||||
from tokenizer import Tokenizer
|
from tokenizer import Tokenizer
|
||||||
from train import get_model, get_params
|
from train import get_model, get_params
|
||||||
from tts_datamodule import LJSpeechTtsDataModule
|
from tts_datamodule import VctkTtsDataModule
|
||||||
|
|
||||||
from icefall.checkpoint import load_checkpoint
|
from icefall.checkpoint import load_checkpoint
|
||||||
from icefall.utils import AttributeDict, setup_logger
|
from icefall.utils import AttributeDict, setup_logger
|
||||||
@ -94,6 +94,7 @@ def infer_dataset(
|
|||||||
tokenizer:
|
tokenizer:
|
||||||
Used to convert text to phonemes.
|
Used to convert text to phonemes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Background worker save audios to disk.
|
# Background worker save audios to disk.
|
||||||
def _save_worker(
|
def _save_worker(
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
@ -127,10 +128,10 @@ def infer_dataset(
|
|||||||
futures = []
|
futures = []
|
||||||
with ThreadPoolExecutor(max_workers=1) as executor:
|
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||||
for batch_idx, batch in enumerate(dl):
|
for batch_idx, batch in enumerate(dl):
|
||||||
batch_size = len(batch["text"])
|
batch_size = len(batch["tokens"])
|
||||||
|
|
||||||
text = batch["text"]
|
tokens = batch["tokens"]
|
||||||
tokens = tokenizer.texts_to_token_ids(text)
|
tokens = tokenizer.tokens_to_token_ids(tokens)
|
||||||
tokens = k2.RaggedTensor(tokens)
|
tokens = k2.RaggedTensor(tokens)
|
||||||
row_splits = tokens.shape.row_splits(1)
|
row_splits = tokens.shape.row_splits(1)
|
||||||
tokens_lens = row_splits[1:] - row_splits[:-1]
|
tokens_lens = row_splits[1:] - row_splits[:-1]
|
||||||
@ -180,7 +181,7 @@ def infer_dataset(
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def main():
|
def main():
|
||||||
parser = get_parser()
|
parser = get_parser()
|
||||||
LJSpeechTtsDataModule.add_arguments(parser)
|
VctkTtsDataModule.add_arguments(parser)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
@ -224,7 +225,7 @@ def main():
|
|||||||
|
|
||||||
# we need cut ids to display recognition results.
|
# we need cut ids to display recognition results.
|
||||||
args.return_cuts = True
|
args.return_cuts = True
|
||||||
ljspeech = LJSpeechTtsDataModule(args)
|
ljspeech = VctkTtsDataModule(args)
|
||||||
|
|
||||||
test_cuts = ljspeech.test_cuts()
|
test_cuts = ljspeech.test_cuts()
|
||||||
test_dl = ljspeech.test_dataloaders(test_cuts)
|
test_dl = ljspeech.test_dataloaders(test_cuts)
|
||||||
@ -236,6 +237,7 @@ def main():
|
|||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logging.info(f"Wav files are saved to {params.save_wav_dir}")
|
||||||
logging.info("Done!")
|
logging.info("Done!")
|
||||||
|
|
||||||
|
|
||||||
|
@ -14,6 +14,7 @@ from typing import List, Tuple, Union
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributions as D
|
import torch.distributions as D
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from lhotse.features.kaldi import Wav2LogFilterBank
|
from lhotse.features.kaldi import Wav2LogFilterBank
|
||||||
|
|
||||||
|
|
||||||
|
@ -12,9 +12,9 @@ This code is based on https://github.com/jaywalnut310/vits.
|
|||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from wavenet import Conv1d, WaveNet
|
|
||||||
|
|
||||||
from icefall.utils import make_pad_mask
|
from icefall.utils import make_pad_mask
|
||||||
|
from wavenet import WaveNet, Conv1d
|
||||||
|
|
||||||
|
|
||||||
class PosteriorEncoder(torch.nn.Module):
|
class PosteriorEncoder(torch.nn.Module):
|
||||||
|
@ -12,6 +12,7 @@ This code is based on https://github.com/jaywalnut310/vits.
|
|||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from flow import FlipFlow
|
from flow import FlipFlow
|
||||||
from wavenet import WaveNet
|
from wavenet import WaveNet
|
||||||
|
|
||||||
|
123
egs/vctk/TTS/vits/test_onnx.py
Executable file
123
egs/vctk/TTS/vits/test_onnx.py
Executable file
@ -0,0 +1,123 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
#
|
||||||
|
# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
This script is used to test the exported onnx model by vits/export-onnx.py
|
||||||
|
|
||||||
|
Use the onnx model to generate a wav:
|
||||||
|
./vits/test_onnx.py \
|
||||||
|
--model-filename vits/exp/vits-epoch-1000.onnx \
|
||||||
|
--tokens data/tokens.txt
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import onnxruntime as ort
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
|
||||||
|
from tokenizer import Tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--model-filename",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the onnx model.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--tokens",
|
||||||
|
type=str,
|
||||||
|
default="data/tokens.txt",
|
||||||
|
help="""Path to vocabulary.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxModel:
|
||||||
|
def __init__(self, model_filename: str):
|
||||||
|
session_opts = ort.SessionOptions()
|
||||||
|
session_opts.inter_op_num_threads = 1
|
||||||
|
session_opts.intra_op_num_threads = 4
|
||||||
|
|
||||||
|
self.session_opts = session_opts
|
||||||
|
|
||||||
|
self.model = ort.InferenceSession(
|
||||||
|
model_filename,
|
||||||
|
sess_options=self.session_opts,
|
||||||
|
providers=["CPUExecutionProvider"],
|
||||||
|
)
|
||||||
|
logging.info(f"{self.model.get_modelmeta().custom_metadata_map}")
|
||||||
|
|
||||||
|
def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
tokens:
|
||||||
|
A 1-D tensor of shape (1, T)
|
||||||
|
Returns:
|
||||||
|
A tensor of shape (1, T')
|
||||||
|
"""
|
||||||
|
noise_scale = torch.tensor([0.667], dtype=torch.float32)
|
||||||
|
noise_scale_dur = torch.tensor([0.8], dtype=torch.float32)
|
||||||
|
alpha = torch.tensor([1.0], dtype=torch.float32)
|
||||||
|
|
||||||
|
out = self.model.run(
|
||||||
|
[
|
||||||
|
self.model.get_outputs()[0].name,
|
||||||
|
],
|
||||||
|
{
|
||||||
|
self.model.get_inputs()[0].name: tokens.numpy(),
|
||||||
|
self.model.get_inputs()[1].name: tokens_lens.numpy(),
|
||||||
|
self.model.get_inputs()[2].name: noise_scale.numpy(),
|
||||||
|
self.model.get_inputs()[3].name: noise_scale_dur.numpy(),
|
||||||
|
self.model.get_inputs()[4].name: alpha.numpy(),
|
||||||
|
},
|
||||||
|
)[0]
|
||||||
|
return torch.from_numpy(out)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = get_parser().parse_args()
|
||||||
|
|
||||||
|
tokenizer = Tokenizer(args.tokens)
|
||||||
|
|
||||||
|
logging.info("About to create onnx model")
|
||||||
|
model = OnnxModel(args.model_filename)
|
||||||
|
|
||||||
|
text = "I went there to see the land, the people and how their system works, end quote."
|
||||||
|
tokens = tokenizer.texts_to_token_ids([text])
|
||||||
|
tokens = torch.tensor(tokens) # (1, T)
|
||||||
|
tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1, T)
|
||||||
|
audio = model(tokens, tokens_lens) # (1, T')
|
||||||
|
|
||||||
|
torchaudio.save(str("test_onnx.wav"), audio, sample_rate=22050)
|
||||||
|
logging.info("Saved to test_onnx.wav")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
main()
|
@ -30,7 +30,7 @@ from typing import Optional, Tuple
|
|||||||
import torch
|
import torch
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
|
||||||
from icefall.utils import make_pad_mask
|
from icefall.utils import is_jit_tracing, make_pad_mask
|
||||||
|
|
||||||
|
|
||||||
class TextEncoder(torch.nn.Module):
|
class TextEncoder(torch.nn.Module):
|
||||||
@ -442,18 +442,30 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
"""
|
"""
|
||||||
(batch_size, num_heads, seq_len, n) = x.shape
|
(batch_size, num_heads, seq_len, n) = x.shape
|
||||||
|
|
||||||
assert n == 2 * seq_len - 1, f"{n} == 2 * {seq_len} - 1"
|
if not is_jit_tracing():
|
||||||
|
assert n == 2 * seq_len - 1, f"{n} == 2 * {seq_len} - 1"
|
||||||
|
|
||||||
# Note: TorchScript requires explicit arg for stride()
|
if is_jit_tracing():
|
||||||
batch_stride = x.stride(0)
|
rows = torch.arange(start=seq_len - 1, end=-1, step=-1)
|
||||||
head_stride = x.stride(1)
|
cols = torch.arange(seq_len)
|
||||||
time_stride = x.stride(2)
|
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
|
||||||
n_stride = x.stride(3)
|
indexes = rows + cols
|
||||||
return x.as_strided(
|
|
||||||
(batch_size, num_heads, seq_len, seq_len),
|
x = x.reshape(-1, n)
|
||||||
(batch_stride, head_stride, time_stride - n_stride, n_stride),
|
x = torch.gather(x, dim=1, index=indexes)
|
||||||
storage_offset=n_stride * (seq_len - 1),
|
x = x.reshape(batch_size, num_heads, seq_len, seq_len)
|
||||||
)
|
return x
|
||||||
|
else:
|
||||||
|
# Note: TorchScript requires explicit arg for stride()
|
||||||
|
batch_stride = x.stride(0)
|
||||||
|
head_stride = x.stride(1)
|
||||||
|
time_stride = x.stride(2)
|
||||||
|
n_stride = x.stride(3)
|
||||||
|
return x.as_strided(
|
||||||
|
(batch_size, num_heads, seq_len, seq_len),
|
||||||
|
(batch_stride, head_stride, time_stride - n_stride, n_stride),
|
||||||
|
storage_offset=n_stride * (seq_len - 1),
|
||||||
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
@ -77,3 +77,32 @@ class Tokenizer(object):
|
|||||||
token_ids_list.append(token_ids)
|
token_ids_list.append(token_ids)
|
||||||
|
|
||||||
return token_ids_list
|
return token_ids_list
|
||||||
|
|
||||||
|
def tokens_to_token_ids(
|
||||||
|
self, tokens_list: List[str], intersperse_blank: bool = True
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
tokens_list:
|
||||||
|
A list of token list, each corresponding to one utterance.
|
||||||
|
intersperse_blank:
|
||||||
|
Whether to intersperse blanks in the token sequence.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Return a list of token id list [utterance][token_id]
|
||||||
|
"""
|
||||||
|
token_ids_list = []
|
||||||
|
|
||||||
|
for tokens in tokens_list:
|
||||||
|
token_ids = []
|
||||||
|
for t in tokens:
|
||||||
|
if t in self.token2id:
|
||||||
|
token_ids.append(self.token2id[t])
|
||||||
|
else:
|
||||||
|
token_ids.append(self.oov_id)
|
||||||
|
|
||||||
|
if intersperse_blank:
|
||||||
|
token_ids = intersperse(token_ids, self.blank_id)
|
||||||
|
token_ids_list.append(token_ids)
|
||||||
|
|
||||||
|
return token_ids_list
|
||||||
|
@ -34,7 +34,7 @@ from torch.cuda.amp import GradScaler, autocast
|
|||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
from tts_datamodule import LJSpeechTtsDataModule
|
from tts_datamodule import VctkTtsDataModule
|
||||||
from utils import MetricsTracker, plot_feature, save_checkpoint
|
from utils import MetricsTracker, plot_feature, save_checkpoint
|
||||||
from vits import VITS
|
from vits import VITS
|
||||||
|
|
||||||
@ -294,10 +294,9 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
|
|||||||
features = batch["features"].to(device)
|
features = batch["features"].to(device)
|
||||||
audio_lens = batch["audio_lens"].to(device)
|
audio_lens = batch["audio_lens"].to(device)
|
||||||
features_lens = batch["features_lens"].to(device)
|
features_lens = batch["features_lens"].to(device)
|
||||||
text = batch["text"]
|
tokens = batch["tokens"]
|
||||||
speakers = batch["speakers"]
|
|
||||||
|
|
||||||
tokens = tokenizer.texts_to_token_ids(text)
|
tokens = tokenizer.tokens_to_token_ids(tokens)
|
||||||
tokens = k2.RaggedTensor(tokens)
|
tokens = k2.RaggedTensor(tokens)
|
||||||
row_splits = tokens.shape.row_splits(1)
|
row_splits = tokens.shape.row_splits(1)
|
||||||
tokens_lens = row_splits[1:] - row_splits[:-1]
|
tokens_lens = row_splits[1:] - row_splits[:-1]
|
||||||
@ -306,7 +305,7 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
|
|||||||
# a tensor of shape (B, T)
|
# a tensor of shape (B, T)
|
||||||
tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
|
tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
|
||||||
|
|
||||||
return audio, audio_lens, features, features_lens, tokens, tokens_lens, speakers
|
return audio, audio_lens, features, features_lens, tokens, tokens_lens
|
||||||
|
|
||||||
|
|
||||||
def train_one_epoch(
|
def train_one_epoch(
|
||||||
@ -384,16 +383,10 @@ def train_one_epoch(
|
|||||||
for batch_idx, batch in enumerate(train_dl):
|
for batch_idx, batch in enumerate(train_dl):
|
||||||
params.batch_idx_train += 1
|
params.batch_idx_train += 1
|
||||||
|
|
||||||
batch_size = len(batch["text"])
|
batch_size = len(batch["tokens"])
|
||||||
(
|
audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
|
||||||
audio,
|
batch, tokenizer, device
|
||||||
audio_lens,
|
)
|
||||||
features,
|
|
||||||
features_lens,
|
|
||||||
tokens,
|
|
||||||
tokens_lens,
|
|
||||||
speakers,
|
|
||||||
) = prepare_input(batch, tokenizer, device)
|
|
||||||
|
|
||||||
loss_info = MetricsTracker()
|
loss_info = MetricsTracker()
|
||||||
loss_info["samples"] = batch_size
|
loss_info["samples"] = batch_size
|
||||||
@ -582,7 +575,7 @@ def compute_validation_loss(
|
|||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for batch_idx, batch in enumerate(valid_dl):
|
for batch_idx, batch in enumerate(valid_dl):
|
||||||
batch_size = len(batch["text"])
|
batch_size = len(batch["tokens"])
|
||||||
(
|
(
|
||||||
audio,
|
audio,
|
||||||
audio_lens,
|
audio_lens,
|
||||||
@ -810,7 +803,7 @@ def run(rank, world_size, args):
|
|||||||
if params.inf_check:
|
if params.inf_check:
|
||||||
register_inf_check_hooks(model)
|
register_inf_check_hooks(model)
|
||||||
|
|
||||||
ljspeech = LJSpeechTtsDataModule(args)
|
ljspeech = VctkTtsDataModule(args)
|
||||||
|
|
||||||
train_cuts = ljspeech.train_cuts()
|
train_cuts = ljspeech.train_cuts()
|
||||||
|
|
||||||
@ -914,7 +907,7 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = get_parser()
|
parser = get_parser()
|
||||||
LJSpeechTtsDataModule.add_arguments(parser)
|
VctkTtsDataModule.add_arguments(parser)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
|
@ -52,7 +52,7 @@ class _SeedWorkers:
|
|||||||
fix_random_seed(self.seed + worker_id)
|
fix_random_seed(self.seed + worker_id)
|
||||||
|
|
||||||
|
|
||||||
class LJSpeechTtsDataModule:
|
class VctkTtsDataModule:
|
||||||
"""
|
"""
|
||||||
DataModule for tts experiments.
|
DataModule for tts experiments.
|
||||||
It assumes there is always one train and valid dataloader,
|
It assumes there is always one train and valid dataloader,
|
||||||
@ -168,7 +168,8 @@ class LJSpeechTtsDataModule:
|
|||||||
"""
|
"""
|
||||||
logging.info("About to create train dataset")
|
logging.info("About to create train dataset")
|
||||||
train = SpeechSynthesisDataset(
|
train = SpeechSynthesisDataset(
|
||||||
return_tokens=False,
|
return_text=False,
|
||||||
|
return_tokens=True,
|
||||||
feature_input_strategy=eval(self.args.input_strategy)(),
|
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
@ -182,7 +183,8 @@ class LJSpeechTtsDataModule:
|
|||||||
use_fft_mag=True,
|
use_fft_mag=True,
|
||||||
)
|
)
|
||||||
train = SpeechSynthesisDataset(
|
train = SpeechSynthesisDataset(
|
||||||
return_tokens=False,
|
return_text=False,
|
||||||
|
return_tokens=True,
|
||||||
feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
|
feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
@ -236,13 +238,15 @@ class LJSpeechTtsDataModule:
|
|||||||
use_fft_mag=True,
|
use_fft_mag=True,
|
||||||
)
|
)
|
||||||
validate = SpeechSynthesisDataset(
|
validate = SpeechSynthesisDataset(
|
||||||
return_tokens=False,
|
return_text=False,
|
||||||
|
return_tokens=True,
|
||||||
feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
|
feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
validate = SpeechSynthesisDataset(
|
validate = SpeechSynthesisDataset(
|
||||||
return_tokens=False,
|
return_text=False,
|
||||||
|
return_tokens=True,
|
||||||
feature_input_strategy=eval(self.args.input_strategy)(),
|
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
@ -273,13 +277,15 @@ class LJSpeechTtsDataModule:
|
|||||||
use_fft_mag=True,
|
use_fft_mag=True,
|
||||||
)
|
)
|
||||||
test = SpeechSynthesisDataset(
|
test = SpeechSynthesisDataset(
|
||||||
return_tokens=False,
|
return_text=False,
|
||||||
|
return_tokens=True,
|
||||||
feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
|
feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
test = SpeechSynthesisDataset(
|
test = SpeechSynthesisDataset(
|
||||||
return_tokens=False,
|
return_text=False,
|
||||||
|
return_tokens=True,
|
||||||
feature_input_strategy=eval(self.args.input_strategy)(),
|
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
|
@ -14,15 +14,15 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
import collections
|
import collections
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torch.distributed as dist
|
||||||
from lhotse.dataset.sampling.base import CutSampler
|
from lhotse.dataset.sampling.base import CutSampler
|
||||||
|
from pathlib import Path
|
||||||
from torch.cuda.amp import GradScaler
|
from torch.cuda.amp import GradScaler
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
|
@ -9,7 +9,8 @@ from typing import Any, Dict, Optional, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from generator import VITSGenerator
|
from torch.cuda.amp import autocast
|
||||||
|
|
||||||
from hifigan import (
|
from hifigan import (
|
||||||
HiFiGANMultiPeriodDiscriminator,
|
HiFiGANMultiPeriodDiscriminator,
|
||||||
HiFiGANMultiScaleDiscriminator,
|
HiFiGANMultiScaleDiscriminator,
|
||||||
@ -24,8 +25,9 @@ from loss import (
|
|||||||
KLDivergenceLoss,
|
KLDivergenceLoss,
|
||||||
MelSpectrogramLoss,
|
MelSpectrogramLoss,
|
||||||
)
|
)
|
||||||
from torch.cuda.amp import autocast
|
|
||||||
from utils import get_segments
|
from utils import get_segments
|
||||||
|
from generator import VITSGenerator
|
||||||
|
|
||||||
|
|
||||||
AVAILABLE_GENERATERS = {
|
AVAILABLE_GENERATERS = {
|
||||||
"vits_generator": VITSGenerator,
|
"vits_generator": VITSGenerator,
|
||||||
@ -570,6 +572,7 @@ class VITS(nn.Module):
|
|||||||
self,
|
self,
|
||||||
text: torch.Tensor,
|
text: torch.Tensor,
|
||||||
text_lengths: torch.Tensor,
|
text_lengths: torch.Tensor,
|
||||||
|
sids: Optional[torch.Tensor] = None,
|
||||||
durations: Optional[torch.Tensor] = None,
|
durations: Optional[torch.Tensor] = None,
|
||||||
noise_scale: float = 0.667,
|
noise_scale: float = 0.667,
|
||||||
noise_scale_dur: float = 0.8,
|
noise_scale_dur: float = 0.8,
|
||||||
@ -582,6 +585,7 @@ class VITS(nn.Module):
|
|||||||
Args:
|
Args:
|
||||||
text (Tensor): Input text index tensor (B, T_text).
|
text (Tensor): Input text index tensor (B, T_text).
|
||||||
text_lengths (Tensor): Input text index tensor (B,).
|
text_lengths (Tensor): Input text index tensor (B,).
|
||||||
|
sids (Tensor): Speaker index tensor (B,).
|
||||||
noise_scale (float): Noise scale value for flow.
|
noise_scale (float): Noise scale value for flow.
|
||||||
noise_scale_dur (float): Noise scale value for duration predictor.
|
noise_scale_dur (float): Noise scale value for duration predictor.
|
||||||
alpha (float): Alpha parameter to control the speed of generated speech.
|
alpha (float): Alpha parameter to control the speed of generated speech.
|
||||||
@ -596,6 +600,7 @@ class VITS(nn.Module):
|
|||||||
wav, att_w, dur = self.generator.inference(
|
wav, att_w, dur = self.generator.inference(
|
||||||
text=text,
|
text=text,
|
||||||
text_lengths=text_lengths,
|
text_lengths=text_lengths,
|
||||||
|
sids=sids,
|
||||||
noise_scale=noise_scale,
|
noise_scale=noise_scale,
|
||||||
noise_scale_dur=noise_scale_dur,
|
noise_scale_dur=noise_scale_dur,
|
||||||
alpha=alpha,
|
alpha=alpha,
|
||||||
|
@ -9,8 +9,9 @@ This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
|
||||||
import math
|
import math
|
||||||
|
import logging
|
||||||
|
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
Loading…
x
Reference in New Issue
Block a user