Performed end-to-end testing on the matcha recipe (#1797)

* minor fixes to the `ljspeech/matcha` recipe
zr_jin, 2024-12-08 03:18:15 +08:00, committed by GitHub
parent 6e6b022e41
commit 1c4dd464a0
25 changed files with 485 additions and 350 deletions

@@ -56,7 +56,7 @@ function infer() {
curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1
./matcha/inference.py \
./matcha/infer.py \
--epoch 1 \
--exp-dir ./matcha/exp \
--tokens data/tokens.txt \

@@ -131,12 +131,12 @@ To inference, use:
wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1
./matcha/inference \
./matcha/synth.py \
--exp-dir ./matcha/exp-new-3 \
--epoch 4000 \
--tokens ./data/tokens.txt \
--vocoder ./generator_v1 \
--input-text "how are you doing?"
--input-text "how are you doing?" \
--output-wav ./generated.wav
```

@@ -0,0 +1 @@
../matcha/audio.py

@@ -27,102 +27,17 @@ The generated fbank features are saved in data/fbank.
import argparse
import logging
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Union
import numpy as np
import torch
from fbank import MatchaFbank, MatchaFbankConfig
from lhotse import CutSet, LilcomChunkyWriter, load_manifest
from lhotse.audio import RecordingSet
from lhotse.features.base import FeatureExtractor, register_extractor
from lhotse.supervision import SupervisionSet
from lhotse.utils import Seconds, compute_num_frames
from matcha.audio import mel_spectrogram
from icefall.utils import get_executor
@dataclass
class MyFbankConfig:
n_fft: int
n_mels: int
sampling_rate: int
hop_length: int
win_length: int
f_min: float
f_max: float
@register_extractor
class MyFbank(FeatureExtractor):
name = "MyFbank"
config_type = MyFbankConfig
def __init__(self, config):
super().__init__(config=config)
@property
def device(self) -> Union[str, torch.device]:
return self.config.device
def feature_dim(self, sampling_rate: int) -> int:
return self.config.n_mels
def extract(
self,
samples: np.ndarray,
sampling_rate: int,
) -> torch.Tensor:
# Check for sampling rate compatibility.
expected_sr = self.config.sampling_rate
assert sampling_rate == expected_sr, (
f"Mismatched sampling rate: extractor expects {expected_sr}, "
f"got {sampling_rate}"
)
samples = torch.from_numpy(samples)
assert samples.ndim == 2, samples.shape
assert samples.shape[0] == 1, samples.shape
mel = (
mel_spectrogram(
samples,
self.config.n_fft,
self.config.n_mels,
self.config.sampling_rate,
self.config.hop_length,
self.config.win_length,
self.config.f_min,
self.config.f_max,
center=False,
)
.squeeze()
.t()
)
assert mel.ndim == 2, mel.shape
assert mel.shape[1] == self.config.n_mels, mel.shape
num_frames = compute_num_frames(
samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate
)
if mel.shape[0] > num_frames:
mel = mel[:num_frames]
elif mel.shape[0] < num_frames:
mel = mel.unsqueeze(0)
mel = torch.nn.functional.pad(
mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate"
).squeeze(0)
return mel.numpy()
@property
def frame_shift(self) -> Seconds:
return self.config.hop_length / self.config.sampling_rate
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -149,7 +64,7 @@ def compute_fbank_ljspeech(num_jobs: int):
logging.info(f"num_jobs: {num_jobs}")
logging.info(f"src_dir: {src_dir}")
logging.info(f"output_dir: {output_dir}")
config = MyFbankConfig(
config = MatchaFbankConfig(
n_fft=1024,
n_mels=80,
sampling_rate=22050,
@@ -170,7 +85,7 @@ def compute_fbank_ljspeech(num_jobs: int):
src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet
)
extractor = MyFbank(config)
extractor = MatchaFbank(config)
with get_executor() as ex: # Initialize the executor only once.
cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"

@@ -0,0 +1 @@
../matcha/fbank.py

@@ -1 +0,0 @@
../local/compute_fbank_ljspeech.py

@@ -7,7 +7,7 @@ from typing import Any, Dict
import onnx
import torch
from inference import load_vocoder
from infer import load_vocoder
def add_meta_data(filename: str, meta_data: Dict[str, Any]):

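The hunk above shows only the signature of `add_meta_data`. For context, here is a minimal sketch of the usual pattern for attaching metadata to an ONNX model; the body below is an assumption for illustration, not the code from this commit.

```python
import onnx


def add_meta_data(filename: str, meta_data: dict):
    """Attach key/value pairs to an ONNX model's metadata_props."""
    model = onnx.load(filename)
    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = str(value)
    onnx.save(model, filename)
```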
@@ -0,0 +1,88 @@
from dataclasses import dataclass
from typing import Union
import numpy as np
import torch
from audio import mel_spectrogram
from lhotse.features.base import FeatureExtractor, register_extractor
from lhotse.utils import Seconds, compute_num_frames
@dataclass
class MatchaFbankConfig:
n_fft: int
n_mels: int
sampling_rate: int
hop_length: int
win_length: int
f_min: float
f_max: float
@register_extractor
class MatchaFbank(FeatureExtractor):
name = "MatchaFbank"
config_type = MatchaFbankConfig
def __init__(self, config):
super().__init__(config=config)
@property
def device(self) -> Union[str, torch.device]:
return self.config.device
def feature_dim(self, sampling_rate: int) -> int:
return self.config.n_mels
def extract(
self,
samples: np.ndarray,
sampling_rate: int,
) -> torch.Tensor:
# Check for sampling rate compatibility.
expected_sr = self.config.sampling_rate
assert sampling_rate == expected_sr, (
f"Mismatched sampling rate: extractor expects {expected_sr}, "
f"got {sampling_rate}"
)
samples = torch.from_numpy(samples)
assert samples.ndim == 2, samples.shape
assert samples.shape[0] == 1, samples.shape
mel = (
mel_spectrogram(
samples,
self.config.n_fft,
self.config.n_mels,
self.config.sampling_rate,
self.config.hop_length,
self.config.win_length,
self.config.f_min,
self.config.f_max,
center=False,
)
.squeeze()
.t()
)
assert mel.ndim == 2, mel.shape
assert mel.shape[1] == self.config.n_mels, mel.shape
num_frames = compute_num_frames(
samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate
)
if mel.shape[0] > num_frames:
mel = mel[:num_frames]
elif mel.shape[0] < num_frames:
mel = mel.unsqueeze(0)
mel = torch.nn.functional.pad(
mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate"
).squeeze(0)
return mel.numpy()
@property
def frame_shift(self) -> Seconds:
return self.config.hop_length / self.config.sampling_rate

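For reference, a minimal usage sketch of the extractor above. The hop/win/f_min values mirror the training config shown in the train.py hunk below, while f_max=8000 is an assumption (the corresponding hunk is truncated):

```python
import numpy as np

from fbank import MatchaFbank, MatchaFbankConfig

config = MatchaFbankConfig(
    n_fft=1024,
    n_mels=80,
    sampling_rate=22050,
    hop_length=256,
    win_length=1024,
    f_min=0,
    f_max=8000,  # assumed; not visible in this diff
)
extractor = MatchaFbank(config)

# extract() expects a (1, num_samples) float array at the configured rate.
samples = np.zeros((1, 22050), dtype=np.float32)  # one second of silence
mel = extractor.extract(samples, sampling_rate=22050)
print(mel.shape)  # (num_frames, 80)
```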
egs/ljspeech/TTS/matcha/infer.py (new executable file)

@@ -0,0 +1,328 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
import argparse
import datetime as dt
import json
import logging
from pathlib import Path
import soundfile as sf
import torch
import torch.nn as nn
from hifigan.config import v1, v2, v3
from hifigan.denoiser import Denoiser
from hifigan.models import Generator as HiFiGAN
from tokenizer import Tokenizer
from train import get_model, get_params
from tts_datamodule import LJSpeechTtsDataModule
from icefall.checkpoint import load_checkpoint
from icefall.utils import AttributeDict, setup_logger
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=4000,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
""",
)
parser.add_argument(
"--exp-dir",
type=Path,
default="matcha/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--vocoder",
type=Path,
default="./generator_v1",
help="Path to the vocoder",
)
parser.add_argument(
"--tokens",
type=Path,
default="data/tokens.txt",
)
parser.add_argument(
"--cmvn",
type=str,
default="data/fbank/cmvn.json",
help="""Path to vocabulary.""",
)
# The following arguments are used for inference on single text
parser.add_argument(
"--input-text",
type=str,
required=False,
help="The text to generate speech for",
)
parser.add_argument(
"--output-wav",
type=str,
required=False,
help="The filename of the wave to save the generated speech",
)
parser.add_argument(
"--sampling-rate",
type=int,
default=22050,
help="The sampling rate of the generated speech (default: 22050 for LJSpeech)",
)
return parser
def load_vocoder(checkpoint_path: Path) -> nn.Module:
checkpoint_path = str(checkpoint_path)
if checkpoint_path.endswith("v1"):
h = AttributeDict(v1)
elif checkpoint_path.endswith("v2"):
h = AttributeDict(v2)
elif checkpoint_path.endswith("v3"):
h = AttributeDict(v3)
else:
raise ValueError(f"supports only v1, v2, and v3, given {checkpoint_path}")
hifigan = HiFiGAN(h).to("cpu")
hifigan.load_state_dict(
torch.load(checkpoint_path, map_location="cpu")["generator"]
)
_ = hifigan.eval()
hifigan.remove_weight_norm()
return hifigan
def to_waveform(
mel: torch.Tensor, vocoder: nn.Module, denoiser: nn.Module
) -> torch.Tensor:
audio = vocoder(mel).clamp(-1, 1)
audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze()
return audio.squeeze()
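# Editor's note: hedged usage sketch for the two helpers above (the path
# and mel shape are assumptions). load_vocoder() selects its HiFiGAN
# config purely from the checkpoint filename suffix, so the file must end
# in v1, v2, or v3.
#
#   mel = torch.randn(1, 80, 100)  # placeholder; real mels come from model.synthesise
#   vocoder = load_vocoder(Path("./generator_v1"))
#   denoiser = Denoiser(vocoder, mode="zeros")
#   wav = to_waveform(mel, vocoder, denoiser)  # 1-D float tensor on CPU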
def process_text(text: str, tokenizer: Tokenizer, device: str = "cpu") -> dict:
x = tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True)
x = torch.tensor(x, dtype=torch.long, device=device)
x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=device)
return {"x_orig": text, "x": x, "x_lengths": x_lengths}
def synthesize(
model: nn.Module,
tokenizer: Tokenizer,
n_timesteps: int,
text: str,
length_scale: float,
temperature: float,
device: str = "cpu",
spks=None,
) -> dict:
text_processed = process_text(text=text, tokenizer=tokenizer, device=device)
start_t = dt.datetime.now()
output = model.synthesise(
text_processed["x"],
text_processed["x_lengths"],
n_timesteps=n_timesteps,
temperature=temperature,
spks=spks,
length_scale=length_scale,
)
# merge everything to one dict
output.update({"start_t": start_t, **text_processed})
return output
def infer_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
vocoder: nn.Module,
denoiser: nn.Module,
tokenizer: Tokenizer,
) -> None:
"""Decode dataset.
The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
tokenizer:
Used to convert text to phonemes.
"""
device = next(model.parameters()).device
num_cuts = 0
log_interval = 5
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
for batch_idx, batch in enumerate(dl):
batch_size = len(batch["tokens"])
texts = [c.supervisions[0].normalized_text for c in batch["cut"]]
audio = batch["audio"]
audio_lens = batch["audio_lens"].tolist()
cut_ids = [cut.id for cut in batch["cut"]]
for i in range(batch_size):
output = synthesize(
model=model,
tokenizer=tokenizer,
n_timesteps=params.n_timesteps,
text=texts[i],
length_scale=params.length_scale,
temperature=params.temperature,
device=device,
)
output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
sf.write(
file=params.save_wav_dir / f"{cut_ids[i]}_pred.wav",
data=output["waveform"],
samplerate=params.data_args.sampling_rate,
subtype="PCM_16",
)
sf.write(
file=params.save_wav_dir / f"{cut_ids[i]}_gt.wav",
data=audio[i].numpy(),
samplerate=params.data_args.sampling_rate,
subtype="PCM_16",
)
num_cuts += batch_size
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
@torch.inference_mode()
def main():
parser = get_parser()
LJSpeechTtsDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
params.suffix = f"epoch-{params.epoch}"
params.res_dir = params.exp_dir / "infer" / params.suffix
params.save_wav_dir = params.res_dir / "wav"
params.save_wav_dir.mkdir(parents=True, exist_ok=True)
setup_logger(f"{params.res_dir}/log-infer-{params.suffix}")
logging.info("Infer started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
tokenizer = Tokenizer(params.tokens)
params.blank_id = tokenizer.pad_id
params.vocab_size = tokenizer.vocab_size
params.model_args.n_vocab = params.vocab_size
with open(params.cmvn) as f:
stats = json.load(f)
params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
params.data_args.data_statistics.mel_std = stats["fbank_std"]
params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
params.model_args.data_statistics.mel_std = stats["fbank_std"]
# Number of ODE Solver steps
params.n_timesteps = 2
# Changes to the speaking rate
params.length_scale = 1.0
# Sampling temperature
params.temperature = 0.667
logging.info(params)
logging.info("About to create model")
model = get_model(params)
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
model.to(device)
model.eval()
# we need cut ids to organize tts results.
args.return_cuts = True
ljspeech = LJSpeechTtsDataModule(args)
test_cuts = ljspeech.test_cuts()
test_dl = ljspeech.test_dataloaders(test_cuts)
if not Path(params.vocoder).is_file():
raise ValueError(f"{params.vocoder} does not exist")
vocoder = load_vocoder(params.vocoder)
vocoder.to(device)
denoiser = Denoiser(vocoder, mode="zeros")
denoiser.to(device)
if params.input_text is not None and params.output_wav is not None:
logging.info("Synthesizing a single text")
output = synthesize(
model=model,
tokenizer=tokenizer,
n_timesteps=params.n_timesteps,
text=params.input_text,
length_scale=params.length_scale,
temperature=params.temperature,
device=device,
)
output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
sf.write(
file=params.output_wav,
data=output["waveform"],
samplerate=params.sampling_rate,
subtype="PCM_16",
)
else:
logging.info("Decoding the test set")
infer_dataset(
dl=test_dl,
params=params,
model=model,
vocoder=vocoder,
denoiser=denoiser,
tokenizer=tokenizer,
)
if __name__ == "__main__":
main()

@@ -1,199 +0,0 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
import argparse
import datetime as dt
import json
import logging
from pathlib import Path
import soundfile as sf
import torch
from matcha.hifigan.config import v1, v2, v3
from matcha.hifigan.denoiser import Denoiser
from matcha.hifigan.models import Generator as HiFiGAN
from tokenizer import Tokenizer
from train import get_model, get_params
from icefall.checkpoint import load_checkpoint
from icefall.utils import AttributeDict
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=4000,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
""",
)
parser.add_argument(
"--exp-dir",
type=Path,
default="matcha/exp-new-3",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--vocoder",
type=Path,
default="./generator_v1",
help="Path to the vocoder",
)
parser.add_argument(
"--tokens",
type=Path,
default="data/tokens.txt",
)
parser.add_argument(
"--cmvn",
type=str,
default="data/fbank/cmvn.json",
help="""Path to vocabulary.""",
)
parser.add_argument(
"--input-text",
type=str,
required=True,
help="The text to generate speech for",
)
parser.add_argument(
"--output-wav",
type=str,
required=True,
help="The filename of the wave to save the generated speech",
)
return parser
def load_vocoder(checkpoint_path):
checkpoint_path = str(checkpoint_path)
if checkpoint_path.endswith("v1"):
h = AttributeDict(v1)
elif checkpoint_path.endswith("v2"):
h = AttributeDict(v2)
elif checkpoint_path.endswith("v3"):
h = AttributeDict(v3)
else:
raise ValueError(f"supports only v1, v2, and v3, given {checkpoint_path}")
hifigan = HiFiGAN(h).to("cpu")
hifigan.load_state_dict(
torch.load(checkpoint_path, map_location="cpu")["generator"]
)
_ = hifigan.eval()
hifigan.remove_weight_norm()
return hifigan
def to_waveform(mel, vocoder, denoiser):
audio = vocoder(mel).clamp(-1, 1)
audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze()
return audio.cpu().squeeze()
def process_text(text: str, tokenizer):
x = tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True)
x = torch.tensor(x, dtype=torch.long)
x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device="cpu")
return {"x_orig": text, "x": x, "x_lengths": x_lengths}
def synthesise(
model, tokenizer, n_timesteps, text, length_scale, temperature, spks=None
):
text_processed = process_text(text, tokenizer)
start_t = dt.datetime.now()
output = model.synthesise(
text_processed["x"],
text_processed["x_lengths"],
n_timesteps=n_timesteps,
temperature=temperature,
spks=spks,
length_scale=length_scale,
)
# merge everything to one dict
output.update({"start_t": start_t, **text_processed})
return output
@torch.inference_mode()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
tokenizer = Tokenizer(params.tokens)
params.blank_id = tokenizer.pad_id
params.vocab_size = tokenizer.vocab_size
params.model_args.n_vocab = params.vocab_size
with open(params.cmvn) as f:
stats = json.load(f)
params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
params.data_args.data_statistics.mel_std = stats["fbank_std"]
params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
params.model_args.data_statistics.mel_std = stats["fbank_std"]
logging.info(params)
logging.info("About to create model")
model = get_model(params)
if not Path(f"{params.exp_dir}/epoch-{params.epoch}.pt").is_file():
raise ValueError("{params.exp_dir}/epoch-{params.epoch}.pt does not exist")
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
model.eval()
if not Path(params.vocoder).is_file():
raise ValueError(f"{params.vocoder} does not exist")
vocoder = load_vocoder(params.vocoder)
denoiser = Denoiser(vocoder, mode="zeros")
# Number of ODE Solver steps
n_timesteps = 2
# Changes to the speaking rate
length_scale = 1.0
# Sampling temperature
temperature = 0.667
output = synthesise(
model=model,
tokenizer=tokenizer,
n_timesteps=n_timesteps,
text=params.input_text,
length_scale=length_scale,
temperature=temperature,
)
output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
sf.write(params.output_wav, output["waveform"], 22050, "PCM_16")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
main()

@@ -7,7 +7,7 @@ import torch.nn.functional as F
from conformer import ConformerBlock
from diffusers.models.activations import get_activation
from einops import pack, rearrange, repeat
from matcha.models.components.transformer import BasicTransformerBlock
from models.components.transformer import BasicTransformerBlock
class SinusoidalPosEmb(torch.nn.Module):

@@ -2,7 +2,7 @@ from abc import ABC
import torch
import torch.nn.functional as F
from matcha.models.components.decoder import Decoder
from models.components.decoder import Decoder
class BASECFM(torch.nn.Module, ABC):

@@ -5,7 +5,7 @@ import math
import torch
import torch.nn as nn
from einops import rearrange
from matcha.model import sequence_mask
from model import sequence_mask
class LayerNorm(nn.Module):

@@ -2,17 +2,17 @@ import datetime as dt
import math
import random
import matcha.monotonic_align as monotonic_align
import monotonic_align as monotonic_align
import torch
from matcha.model import (
from model import (
denormalize,
duration_loss,
fix_len_compatibility,
generate_path,
sequence_mask,
)
from matcha.models.components.flow_matching import CFM
from matcha.models.components.text_encoder import TextEncoder
from models.components.flow_matching import CFM
from models.components.text_encoder import TextEncoder
class MatchaTTS(torch.nn.Module): # 🍵

@@ -1,8 +1,7 @@
# Copied from
# https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/monotonic_align/__init__.py
import numpy as np
import torch
from matcha.monotonic_align.core import maximum_path_c
from .core import maximum_path_c
def maximum_path(value, mask):

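A hedged usage sketch of `maximum_path` follows. The (batch, text_len, mel_len) shape convention is taken from the Glow-TTS/Matcha-TTS lineage this file is copied from and should be treated as an assumption, since the function body is not shown in this diff; running it also requires the compiled `core` Cython extension built by the setup.py below.

```python
import torch

from monotonic_align import maximum_path

log_probs = torch.randn(1, 5, 20)  # (batch, text_len, mel_len) log-likelihoods
mask = torch.ones(1, 5, 20)        # 1 for valid positions, 0 for padding
path = maximum_path(log_probs, mask)  # hard 0/1 monotonic alignment, same shape
```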
@@ -1,5 +1,3 @@
# Copied from
# https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/monotonic_align/core.pyx
import numpy as np
cimport cython

@@ -1,12 +1,30 @@
# Copied from
# Modified from
# https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/monotonic_align/setup.py
from distutils.core import setup
import numpy
from Cython.Build import cythonize
from setuptools import Extension, setup
from setuptools.command.build_ext import build_ext as _build_ext
class build_ext(_build_ext):
"""Overwrite build_ext."""
def finalize_options(self):
"""Prevent numpy from thinking it is still in its setup process."""
_build_ext.finalize_options(self)
__builtins__.__NUMPY_SETUP__ = False
import numpy
self.include_dirs.append(numpy.get_include())
exts = [
Extension(
name="core",
sources=["core.pyx"],
)
]
setup(
name="monotonic_align",
ext_modules=cythonize("core.pyx"),
include_dirs=[numpy.get_include()],
ext_modules=cythonize(exts, language_level=3),
cmdclass={"build_ext": build_ext},
)

@@ -1,3 +1,4 @@
conformer==0.3.2
diffusers # developed using version ==0.25.0
librosa
einops

@@ -14,9 +14,9 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
from lhotse.utils import fix_random_seed
from matcha.model import fix_len_compatibility
from matcha.models.matcha_tts import MatchaTTS
from matcha.tokenizer import Tokenizer
from model import fix_len_compatibility
from models.matcha_tts import MatchaTTS
from tokenizer import Tokenizer
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
@@ -150,7 +150,7 @@ def _get_data_params() -> AttributeDict:
"n_spks": 1,
"n_fft": 1024,
"n_feats": 80,
"sample_rate": 22050,
"sampling_rate": 22050,
"hop_length": 256,
"win_length": 1024,
"f_min": 0,
@@ -445,11 +445,6 @@ def train_one_epoch(
saved_bad_model = False
# used to track the stats over iterations in one epoch
tot_loss = MetricsTracker()
saved_bad_model = False
def save_bad_model(suffix: str = ""):
save_checkpoint(
filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",

@@ -24,7 +24,7 @@ from pathlib import Path
from typing import Any, Dict, Optional
import torch
from compute_fbank_ljspeech import MyFbank, MyFbankConfig
from fbank import MatchaFbank, MatchaFbankConfig
from lhotse import CutSet, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
@@ -32,7 +32,6 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
DynamicBucketingSampler,
PrecomputedFeatures,
SimpleCutSampler,
SpecAugment,
SpeechSynthesisDataset,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
@@ -177,7 +176,7 @@ class LJSpeechTtsDataModule:
if self.args.on_the_fly_feats:
sampling_rate = 22050
config = MyFbankConfig(
config = MatchaFbankConfig(
n_fft=1024,
n_mels=80,
sampling_rate=sampling_rate,
@@ -189,7 +188,7 @@ class LJSpeechTtsDataModule:
train = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=OnTheFlyFeatures(MyFbank(config)),
feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
return_cuts=self.args.return_cuts,
)
@@ -238,7 +237,7 @@
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
sampling_rate = 22050
config = MyFbankConfig(
config = MatchaFbankConfig(
n_fft=1024,
n_mels=80,
sampling_rate=sampling_rate,
@@ -250,7 +249,7 @@
validate = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=OnTheFlyFeatures(MyFbank(config)),
feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
return_cuts=self.args.return_cuts,
)
else:
@@ -282,7 +281,7 @@
logging.info("About to create test dataset")
if self.args.on_the_fly_feats:
sampling_rate = 22050
config = MyFbankConfig(
config = MatchaFbankConfig(
n_fft=1024,
n_mels=80,
sampling_rate=sampling_rate,
@@ -294,7 +293,7 @@
test = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=OnTheFlyFeatures(MyFbank(config)),
feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
return_cuts=self.args.return_cuts,
)
else:

@@ -25,26 +25,16 @@ log() {
log "dl_dir: $dl_dir"
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
log "Stage -1: build monotonic_align lib"
if [ ! -d vits/monotonic_align/build ]; then
cd vits/monotonic_align
log "Stage -1: build monotonic_align lib (used by vits and matcha recipes)"
for recipe in vits matcha; do
if [ ! -d $recipe/monotonic_align/build ]; then
cd $recipe/monotonic_align
python3 setup.py build_ext --inplace
cd ../../
else
log "monotonic_align lib for vits already built"
fi
if [ ! -f ./matcha/monotonic_align/core.cpython-38-x86_64-linux-gnu.so ]; then
pushd matcha/monotonic_align
python3 setup.py build
mv -v build/lib.*/matcha/monotonic_align/core.*.so .
rm -rf build
rm core.c
ls -lh
popd
else
log "monotonic_align lib for matcha-tts already built"
log "monotonic_align lib for $recipe already built"
fi
done
fi
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then

@@ -234,7 +234,7 @@ def main():
logging.info(f"Number of parameters in discriminator: {num_param_d}")
logging.info(f"Total number of parameters: {num_param_g + num_param_d}")
# we need cut ids to display recognition results.
# we need cut ids to organize tts results.
args.return_cuts = True
ljspeech = LJSpeechTtsDataModule(args)

@@ -0,0 +1,3 @@
build
core.c
*.so

@@ -18,7 +18,6 @@
from tokenizer import Tokenizer
from train import get_model, get_params
from vits import VITS
def test_model_type(model_type):