Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-12 19:42:19 +00:00

Merge branch 'master' into master

This change is contained in commit 6ffd624df2.
.github/scripts/ljspeech/TTS/run-matcha.sh (vendored, 2 changed lines)
@@ -56,7 +56,7 @@ function infer() {

   curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1

-  ./matcha/inference.py \
+  ./matcha/infer.py \
     --epoch 1 \
     --exp-dir ./matcha/exp \
     --tokens data/tokens.txt \
.github/workflows/style_check.yml (vendored, 2 changed lines)
@@ -36,7 +36,7 @@ jobs:
     strategy:
       matrix:
         os: [ubuntu-latest]
-        python-version: [3.8]
+        python-version: [3.10.15]
       fail-fast: false

     steps:
@@ -31,15 +31,6 @@ from piper_phonemize import phonemize_espeak
 from tqdm.auto import tqdm


-def remove_punc_to_upper(text: str) -> str:
-    text = text.replace("‘", "'")
-    text = text.replace("’", "'")
-    tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'")
-    s_list = [x.upper() if x in tokens else " " for x in text]
-    s = " ".join("".join(s_list).split()).strip()
-    return s
-
-
 def prepare_tokens_libritts():
     output_dir = Path("data/spectrogram")
     prefix = "libritts"
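For reference, the dropped helper mapped text to an uppercase, punctuation-free form, so with this change the LibriTTS supervisions keep their original casing and punctuation. A quick check of the old behaviour, using the function body shown in the hunk above:

```python
def remove_punc_to_upper(text: str) -> str:
    # Body copied from the removed helper shown above.
    text = text.replace("‘", "'")
    text = text.replace("’", "'")
    tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'")
    s_list = [x.upper() if x in tokens else " " for x in text]
    return " ".join("".join(s_list).split()).strip()


print(remove_punc_to_upper("Hello, world: it’s 9 o’clock!"))
# -> HELLO WORLD IT'S 9 O'CLOCK
```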
@@ -72,7 +63,7 @@ def prepare_tokens_libritts():
             for t in tokens_list:
                 tokens.extend(t)
             cut.tokens = tokens
-            cut.supervisions[0].normalized_text = remove_punc_to_upper(text)
+            cut.supervisions[0].normalized_text = text

             new_cuts.append(cut)
@@ -84,7 +84,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
   if [ ! -f data/spectrogram/libritts_cuts_train-all-shuf.jsonl.gz ]; then
     cat <(gunzip -c data/spectrogram/libritts_cuts_train-clean-100.jsonl.gz) \
       <(gunzip -c data/spectrogram/libritts_cuts_train-clean-360.jsonl.gz) \
-      <(gunzip -c data/spectrogramlibritts_cuts_train-other-500.jsonl.gz) | \
+      <(gunzip -c data/spectrogram/libritts_cuts_train-other-500.jsonl.gz) | \
       shuf | gzip -c > data/spectrogram/libritts_cuts_train-all-shuf.jsonl.gz
   fi
@@ -131,12 +131,12 @@ To inference, use:

 wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1

-./matcha/inference \
+./matcha/infer.py \
   --exp-dir ./matcha/exp-new-3 \
   --epoch 4000 \
   --tokens ./data/tokens.txt \
   --vocoder ./generator_v1 \
-  --input-text "how are you doing?"
+  --input-text "how are you doing?" \
+  --output-wav ./generated.wav
 ```
egs/ljspeech/TTS/local/audio.py (symbolic link, 1 changed line)
@@ -0,0 +1 @@
+../matcha/audio.py
@@ -27,102 +27,17 @@ The generated fbank features are saved in data/fbank.
 import argparse
 import logging
 import os
-from dataclasses import dataclass
 from pathlib import Path
-from typing import Union

-import numpy as np
 import torch
+from fbank import MatchaFbank, MatchaFbankConfig
 from lhotse import CutSet, LilcomChunkyWriter, load_manifest
 from lhotse.audio import RecordingSet
-from lhotse.features.base import FeatureExtractor, register_extractor
 from lhotse.supervision import SupervisionSet
-from lhotse.utils import Seconds, compute_num_frames
-from matcha.audio import mel_spectrogram

 from icefall.utils import get_executor


-@dataclass
-class MyFbankConfig:
-    n_fft: int
-    n_mels: int
-    sampling_rate: int
-    hop_length: int
-    win_length: int
-    f_min: float
-    f_max: float
-
-
-@register_extractor
-class MyFbank(FeatureExtractor):
-
-    name = "MyFbank"
-    config_type = MyFbankConfig
-
-    def __init__(self, config):
-        super().__init__(config=config)
-
-    @property
-    def device(self) -> Union[str, torch.device]:
-        return self.config.device
-
-    def feature_dim(self, sampling_rate: int) -> int:
-        return self.config.n_mels
-
-    def extract(
-        self,
-        samples: np.ndarray,
-        sampling_rate: int,
-    ) -> torch.Tensor:
-        # Check for sampling rate compatibility.
-        expected_sr = self.config.sampling_rate
-        assert sampling_rate == expected_sr, (
-            f"Mismatched sampling rate: extractor expects {expected_sr}, "
-            f"got {sampling_rate}"
-        )
-        samples = torch.from_numpy(samples)
-        assert samples.ndim == 2, samples.shape
-        assert samples.shape[0] == 1, samples.shape
-
-        mel = (
-            mel_spectrogram(
-                samples,
-                self.config.n_fft,
-                self.config.n_mels,
-                self.config.sampling_rate,
-                self.config.hop_length,
-                self.config.win_length,
-                self.config.f_min,
-                self.config.f_max,
-                center=False,
-            )
-            .squeeze()
-            .t()
-        )
-
-        assert mel.ndim == 2, mel.shape
-        assert mel.shape[1] == self.config.n_mels, mel.shape
-
-        num_frames = compute_num_frames(
-            samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate
-        )
-
-        if mel.shape[0] > num_frames:
-            mel = mel[:num_frames]
-        elif mel.shape[0] < num_frames:
-            mel = mel.unsqueeze(0)
-            mel = torch.nn.functional.pad(
-                mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate"
-            ).squeeze(0)
-
-        return mel.numpy()
-
-    @property
-    def frame_shift(self) -> Seconds:
-        return self.config.hop_length / self.config.sampling_rate
-
-
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -149,7 +64,7 @@ def compute_fbank_ljspeech(num_jobs: int):
     logging.info(f"num_jobs: {num_jobs}")
     logging.info(f"src_dir: {src_dir}")
     logging.info(f"output_dir: {output_dir}")
-    config = MyFbankConfig(
+    config = MatchaFbankConfig(
         n_fft=1024,
         n_mels=80,
         sampling_rate=22050,
@@ -170,7 +85,7 @@ def compute_fbank_ljspeech(num_jobs: int):
             src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet
         )

-        extractor = MyFbank(config)
+        extractor = MatchaFbank(config)

        with get_executor() as ex:  # Initialize the executor only once.
            cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
egs/ljspeech/TTS/local/fbank.py (symbolic link, 1 changed line)
@@ -0,0 +1 @@
+../matcha/fbank.py
@@ -33,7 +33,6 @@ import argparse
 import logging
 from pathlib import Path

-from compute_fbank_ljspeech import MyFbank
 from lhotse import CutSet, load_manifest_lazy
 from lhotse.dataset.speech_synthesis import validate_for_tts

@@ -1 +0,0 @@
-../local/compute_fbank_ljspeech.py
@@ -7,7 +7,7 @@ from typing import Any, Dict

 import onnx
 import torch
-from inference import load_vocoder
+from infer import load_vocoder


 def add_meta_data(filename: str, meta_data: Dict[str, Any]):
egs/ljspeech/TTS/matcha/fbank.py (new normal file, 88 lines)
@@ -0,0 +1,88 @@
from dataclasses import dataclass
from typing import Union

import numpy as np
import torch
from audio import mel_spectrogram
from lhotse.features.base import FeatureExtractor, register_extractor
from lhotse.utils import Seconds, compute_num_frames


@dataclass
class MatchaFbankConfig:
    n_fft: int
    n_mels: int
    sampling_rate: int
    hop_length: int
    win_length: int
    f_min: float
    f_max: float


@register_extractor
class MatchaFbank(FeatureExtractor):

    name = "MatchaFbank"
    config_type = MatchaFbankConfig

    def __init__(self, config):
        super().__init__(config=config)

    @property
    def device(self) -> Union[str, torch.device]:
        return self.config.device

    def feature_dim(self, sampling_rate: int) -> int:
        return self.config.n_mels

    def extract(
        self,
        samples: np.ndarray,
        sampling_rate: int,
    ) -> torch.Tensor:
        # Check for sampling rate compatibility.
        expected_sr = self.config.sampling_rate
        assert sampling_rate == expected_sr, (
            f"Mismatched sampling rate: extractor expects {expected_sr}, "
            f"got {sampling_rate}"
        )
        samples = torch.from_numpy(samples)
        assert samples.ndim == 2, samples.shape
        assert samples.shape[0] == 1, samples.shape

        mel = (
            mel_spectrogram(
                samples,
                self.config.n_fft,
                self.config.n_mels,
                self.config.sampling_rate,
                self.config.hop_length,
                self.config.win_length,
                self.config.f_min,
                self.config.f_max,
                center=False,
            )
            .squeeze()
            .t()
        )

        assert mel.ndim == 2, mel.shape
        assert mel.shape[1] == self.config.n_mels, mel.shape

        num_frames = compute_num_frames(
            samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate
        )

        if mel.shape[0] > num_frames:
            mel = mel[:num_frames]
        elif mel.shape[0] < num_frames:
            mel = mel.unsqueeze(0)
            mel = torch.nn.functional.pad(
                mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate"
            ).squeeze(0)

        return mel.numpy()

    @property
    def frame_shift(self) -> Seconds:
        return self.config.hop_length / self.config.sampling_rate
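As a usage note, the new extractor follows Lhotse's `FeatureExtractor` interface, so it can be applied directly to raw samples. A minimal sketch: the config values mirror the LJSpeech recipe settings elsewhere in this commit, except `f_max=8000`, which is an assumed value, and the import assumes the script is run from `egs/ljspeech/TTS/matcha`:

```python
import numpy as np

from fbank import MatchaFbank, MatchaFbankConfig

config = MatchaFbankConfig(
    n_fft=1024,
    n_mels=80,
    sampling_rate=22050,
    hop_length=256,
    win_length=1024,
    f_min=0,
    f_max=8000,  # assumed; not shown in this diff
)
extractor = MatchaFbank(config)

# One second of silent mono audio, shape (1, num_samples), as expected by extract().
samples = np.zeros((1, 22050), dtype=np.float32)
mel = extractor.extract(samples, sampling_rate=22050)
print(mel.shape)  # (frames, n_mels)
```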
egs/ljspeech/TTS/matcha/infer.py (new executable file, 328 lines)
@@ -0,0 +1,328 @@
#!/usr/bin/env python3
# Copyright         2024  Xiaomi Corp.        (authors: Fangjun Kuang)

import argparse
import datetime as dt
import json
import logging
from pathlib import Path

import soundfile as sf
import torch
import torch.nn as nn
from hifigan.config import v1, v2, v3
from hifigan.denoiser import Denoiser
from hifigan.models import Generator as HiFiGAN
from tokenizer import Tokenizer
from train import get_model, get_params
from tts_datamodule import LJSpeechTtsDataModule

from icefall.checkpoint import load_checkpoint
from icefall.utils import AttributeDict, setup_logger


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=4000,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        """,
    )

    parser.add_argument(
        "--exp-dir",
        type=Path,
        default="matcha/exp",
        help="""The experiment dir.
        It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
        """,
    )

    parser.add_argument(
        "--vocoder",
        type=Path,
        default="./generator_v1",
        help="Path to the vocoder",
    )

    parser.add_argument(
        "--tokens",
        type=Path,
        default="data/tokens.txt",
    )

    parser.add_argument(
        "--cmvn",
        type=str,
        default="data/fbank/cmvn.json",
        help="""Path to vocabulary.""",
    )

    # The following arguments are used for inference on single text
    parser.add_argument(
        "--input-text",
        type=str,
        required=False,
        help="The text to generate speech for",
    )

    parser.add_argument(
        "--output-wav",
        type=str,
        required=False,
        help="The filename of the wave to save the generated speech",
    )

    parser.add_argument(
        "--sampling-rate",
        type=int,
        default=22050,
        help="The sampling rate of the generated speech (default: 22050 for LJSpeech)",
    )

    return parser


def load_vocoder(checkpoint_path: Path) -> nn.Module:
    checkpoint_path = str(checkpoint_path)
    if checkpoint_path.endswith("v1"):
        h = AttributeDict(v1)
    elif checkpoint_path.endswith("v2"):
        h = AttributeDict(v2)
    elif checkpoint_path.endswith("v3"):
        h = AttributeDict(v3)
    else:
        raise ValueError(f"supports only v1, v2, and v3, given {checkpoint_path}")

    hifigan = HiFiGAN(h).to("cpu")
    hifigan.load_state_dict(
        torch.load(checkpoint_path, map_location="cpu")["generator"]
    )
    _ = hifigan.eval()
    hifigan.remove_weight_norm()
    return hifigan


def to_waveform(
    mel: torch.Tensor, vocoder: nn.Module, denoiser: nn.Module
) -> torch.Tensor:
    audio = vocoder(mel).clamp(-1, 1)
    audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze()
    return audio.squeeze()


def process_text(text: str, tokenizer: Tokenizer, device: str = "cpu") -> dict:
    x = tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True)
    x = torch.tensor(x, dtype=torch.long, device=device)
    x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=device)
    return {"x_orig": text, "x": x, "x_lengths": x_lengths}


def synthesize(
    model: nn.Module,
    tokenizer: Tokenizer,
    n_timesteps: int,
    text: str,
    length_scale: float,
    temperature: float,
    device: str = "cpu",
    spks=None,
) -> dict:
    text_processed = process_text(text=text, tokenizer=tokenizer, device=device)
    start_t = dt.datetime.now()
    output = model.synthesise(
        text_processed["x"],
        text_processed["x_lengths"],
        n_timesteps=n_timesteps,
        temperature=temperature,
        spks=spks,
        length_scale=length_scale,
    )
    # merge everything to one dict
    output.update({"start_t": start_t, **text_processed})
    return output


def infer_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
    vocoder: nn.Module,
    denoiser: nn.Module,
    tokenizer: Tokenizer,
) -> None:
    """Decode dataset.
    The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode.
      params:
        It is returned by :func:`get_params`.
      model:
        The neural model.
      tokenizer:
        Used to convert text to phonemes.
    """

    device = next(model.parameters()).device
    num_cuts = 0
    log_interval = 5

    try:
        num_batches = len(dl)
    except TypeError:
        num_batches = "?"

    for batch_idx, batch in enumerate(dl):
        batch_size = len(batch["tokens"])

        texts = [c.supervisions[0].normalized_text for c in batch["cut"]]

        audio = batch["audio"]
        audio_lens = batch["audio_lens"].tolist()
        cut_ids = [cut.id for cut in batch["cut"]]

        for i in range(batch_size):
            output = synthesize(
                model=model,
                tokenizer=tokenizer,
                n_timesteps=params.n_timesteps,
                text=texts[i],
                length_scale=params.length_scale,
                temperature=params.temperature,
                device=device,
            )
            output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)

            sf.write(
                file=params.save_wav_dir / f"{cut_ids[i]}_pred.wav",
                data=output["waveform"],
                samplerate=params.data_args.sampling_rate,
                subtype="PCM_16",
            )
            sf.write(
                file=params.save_wav_dir / f"{cut_ids[i]}_gt.wav",
                data=audio[i].numpy(),
                samplerate=params.data_args.sampling_rate,
                subtype="PCM_16",
            )

        num_cuts += batch_size

        if batch_idx % log_interval == 0:
            batch_str = f"{batch_idx}/{num_batches}"

            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")


@torch.inference_mode()
def main():
    parser = get_parser()
    LJSpeechTtsDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    params.suffix = f"epoch-{params.epoch}"

    params.res_dir = params.exp_dir / "infer" / params.suffix
    params.save_wav_dir = params.res_dir / "wav"
    params.save_wav_dir.mkdir(parents=True, exist_ok=True)

    setup_logger(f"{params.res_dir}/log-infer-{params.suffix}")
    logging.info("Infer started")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    logging.info(f"Device: {device}")

    tokenizer = Tokenizer(params.tokens)
    params.blank_id = tokenizer.pad_id
    params.vocab_size = tokenizer.vocab_size
    params.model_args.n_vocab = params.vocab_size

    with open(params.cmvn) as f:
        stats = json.load(f)
        params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
        params.data_args.data_statistics.mel_std = stats["fbank_std"]

        params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
        params.model_args.data_statistics.mel_std = stats["fbank_std"]

    # Number of ODE Solver steps
    params.n_timesteps = 2

    # Changes to the speaking rate
    params.length_scale = 1.0

    # Sampling temperature
    params.temperature = 0.667
    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)

    load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    model.to(device)
    model.eval()

    # we need cut ids to organize tts results.
    args.return_cuts = True
    ljspeech = LJSpeechTtsDataModule(args)

    test_cuts = ljspeech.test_cuts()
    test_dl = ljspeech.test_dataloaders(test_cuts)

    if not Path(params.vocoder).is_file():
        raise ValueError(f"{params.vocoder} does not exist")

    vocoder = load_vocoder(params.vocoder)
    vocoder.to(device)

    denoiser = Denoiser(vocoder, mode="zeros")
    denoiser.to(device)

    if params.input_text is not None and params.output_wav is not None:
        logging.info("Synthesizing a single text")
        output = synthesize(
            model=model,
            tokenizer=tokenizer,
            n_timesteps=params.n_timesteps,
            text=params.input_text,
            length_scale=params.length_scale,
            temperature=params.temperature,
            device=device,
        )
        output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)

        sf.write(
            file=params.output_wav,
            data=output["waveform"],
            samplerate=params.sampling_rate,
            subtype="PCM_16",
        )
    else:
        logging.info("Decoding the test set")
        infer_dataset(
            dl=test_dl,
            params=params,
            model=model,
            vocoder=vocoder,
            denoiser=denoiser,
            tokenizer=tokenizer,
        )


if __name__ == "__main__":
    main()
@@ -1,199 +0,0 @@
#!/usr/bin/env python3
# Copyright         2024  Xiaomi Corp.        (authors: Fangjun Kuang)

import argparse
import datetime as dt
import json
import logging
from pathlib import Path

import soundfile as sf
import torch
from matcha.hifigan.config import v1, v2, v3
from matcha.hifigan.denoiser import Denoiser
from matcha.hifigan.models import Generator as HiFiGAN
from tokenizer import Tokenizer
from train import get_model, get_params

from icefall.checkpoint import load_checkpoint
from icefall.utils import AttributeDict


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=4000,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        """,
    )

    parser.add_argument(
        "--exp-dir",
        type=Path,
        default="matcha/exp-new-3",
        help="""The experiment dir.
        It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
        """,
    )

    parser.add_argument(
        "--vocoder",
        type=Path,
        default="./generator_v1",
        help="Path to the vocoder",
    )

    parser.add_argument(
        "--tokens",
        type=Path,
        default="data/tokens.txt",
    )

    parser.add_argument(
        "--cmvn",
        type=str,
        default="data/fbank/cmvn.json",
        help="""Path to vocabulary.""",
    )

    parser.add_argument(
        "--input-text",
        type=str,
        required=True,
        help="The text to generate speech for",
    )

    parser.add_argument(
        "--output-wav",
        type=str,
        required=True,
        help="The filename of the wave to save the generated speech",
    )

    return parser


def load_vocoder(checkpoint_path):
    checkpoint_path = str(checkpoint_path)
    if checkpoint_path.endswith("v1"):
        h = AttributeDict(v1)
    elif checkpoint_path.endswith("v2"):
        h = AttributeDict(v2)
    elif checkpoint_path.endswith("v3"):
        h = AttributeDict(v3)
    else:
        raise ValueError(f"supports only v1, v2, and v3, given {checkpoint_path}")

    hifigan = HiFiGAN(h).to("cpu")
    hifigan.load_state_dict(
        torch.load(checkpoint_path, map_location="cpu")["generator"]
    )
    _ = hifigan.eval()
    hifigan.remove_weight_norm()
    return hifigan


def to_waveform(mel, vocoder, denoiser):
    audio = vocoder(mel).clamp(-1, 1)
    audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze()
    return audio.cpu().squeeze()


def process_text(text: str, tokenizer):
    x = tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True)
    x = torch.tensor(x, dtype=torch.long)
    x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device="cpu")
    return {"x_orig": text, "x": x, "x_lengths": x_lengths}


def synthesise(
    model, tokenizer, n_timesteps, text, length_scale, temperature, spks=None
):
    text_processed = process_text(text, tokenizer)
    start_t = dt.datetime.now()
    output = model.synthesise(
        text_processed["x"],
        text_processed["x_lengths"],
        n_timesteps=n_timesteps,
        temperature=temperature,
        spks=spks,
        length_scale=length_scale,
    )
    # merge everything to one dict
    output.update({"start_t": start_t, **text_processed})
    return output


@torch.inference_mode()
def main():
    parser = get_parser()
    args = parser.parse_args()
    params = get_params()

    params.update(vars(args))

    tokenizer = Tokenizer(params.tokens)
    params.blank_id = tokenizer.pad_id
    params.vocab_size = tokenizer.vocab_size
    params.model_args.n_vocab = params.vocab_size

    with open(params.cmvn) as f:
        stats = json.load(f)
        params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
        params.data_args.data_statistics.mel_std = stats["fbank_std"]

        params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
        params.model_args.data_statistics.mel_std = stats["fbank_std"]
    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)

    if not Path(f"{params.exp_dir}/epoch-{params.epoch}.pt").is_file():
        raise ValueError("{params.exp_dir}/epoch-{params.epoch}.pt does not exist")

    load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    model.eval()

    if not Path(params.vocoder).is_file():
        raise ValueError(f"{params.vocoder} does not exist")

    vocoder = load_vocoder(params.vocoder)
    denoiser = Denoiser(vocoder, mode="zeros")

    # Number of ODE Solver steps
    n_timesteps = 2

    # Changes to the speaking rate
    length_scale = 1.0

    # Sampling temperature
    temperature = 0.667

    output = synthesise(
        model=model,
        tokenizer=tokenizer,
        n_timesteps=n_timesteps,
        text=params.input_text,
        length_scale=length_scale,
        temperature=temperature,
    )
    output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)

    sf.write(params.output_wav, output["waveform"], 22050, "PCM_16")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    main()
@@ -7,7 +7,7 @@ import torch.nn.functional as F
 from conformer import ConformerBlock
 from diffusers.models.activations import get_activation
 from einops import pack, rearrange, repeat
-from matcha.models.components.transformer import BasicTransformerBlock
+from models.components.transformer import BasicTransformerBlock


 class SinusoidalPosEmb(torch.nn.Module):
@@ -2,7 +2,7 @@ from abc import ABC

 import torch
 import torch.nn.functional as F
-from matcha.models.components.decoder import Decoder
+from models.components.decoder import Decoder


 class BASECFM(torch.nn.Module, ABC):
@@ -5,7 +5,7 @@ import math
 import torch
 import torch.nn as nn
 from einops import rearrange
-from matcha.model import sequence_mask
+from model import sequence_mask


 class LayerNorm(nn.Module):
@@ -2,17 +2,17 @@ import datetime as dt
 import math
 import random

-import matcha.monotonic_align as monotonic_align
+import monotonic_align as monotonic_align
 import torch
-from matcha.model import (
+from model import (
     denormalize,
     duration_loss,
     fix_len_compatibility,
     generate_path,
     sequence_mask,
 )
-from matcha.models.components.flow_matching import CFM
-from matcha.models.components.text_encoder import TextEncoder
+from models.components.flow_matching import CFM
+from models.components.text_encoder import TextEncoder


 class MatchaTTS(torch.nn.Module):  # 🍵
@@ -1,3 +1,3 @@
 build
 core.c
-*.so
+*.so
@@ -1,8 +1,7 @@
 # Copied from
 # https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/monotonic_align/__init__.py
 import numpy as np
 import torch
-from matcha.monotonic_align.core import maximum_path_c
-
+from .core import maximum_path_c


 def maximum_path(value, mask):
@@ -1,5 +1,3 @@
 # Copied from
 # https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/monotonic_align/core.pyx
-import numpy as np
-
 cimport cython
@@ -1,12 +1,30 @@
-# Copied from
+# Modified from
 # https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/monotonic_align/setup.py
-from distutils.core import setup
-
-import numpy
 from Cython.Build import cythonize
+from setuptools import Extension, setup
+from setuptools.command.build_ext import build_ext as _build_ext
+
+
+class build_ext(_build_ext):
+    """Overwrite build_ext."""
+
+    def finalize_options(self):
+        """Prevent numpy from thinking it is still in its setup process."""
+        _build_ext.finalize_options(self)
+        __builtins__.__NUMPY_SETUP__ = False
+        import numpy
+
+        self.include_dirs.append(numpy.get_include())
+
+
+exts = [
+    Extension(
+        name="core",
+        sources=["core.pyx"],
+    )
+]
 setup(
     name="monotonic_align",
-    ext_modules=cythonize("core.pyx"),
-    include_dirs=[numpy.get_include()],
+    ext_modules=cythonize(exts, language_level=3),
+    cmdclass={"build_ext": build_ext},
 )
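For context, the rewritten setup script defers the numpy import to build time and builds the Cython extension in place, which is also what the updated prepare.sh stage -1 now does for both the vits and matcha copies. A rough sketch of the equivalent build step driven from Python; the path assumes the egs/ljspeech/TTS layout:

```python
import subprocess
import sys

# Equivalent of `cd matcha/monotonic_align && python3 setup.py build_ext --inplace`,
# which prepare.sh runs during stage -1.
subprocess.run(
    [sys.executable, "setup.py", "build_ext", "--inplace"],
    cwd="matcha/monotonic_align",
    check=True,
)
```

Once built, `from monotonic_align import maximum_path` should resolve when scripts are run from the matcha directory, since the package's `__init__.py` now uses the relative `from .core import maximum_path_c`.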
@@ -8,7 +8,7 @@ import logging
 import onnxruntime as ort
 import soundfile as sf
 import torch
-from inference import load_vocoder
+from infer import load_vocoder
 from tokenizer import Tokenizer


@@ -1,3 +1,4 @@
 conformer==0.3.2
+diffusers  # developed using version ==0.25.0
 librosa
 einops
@@ -14,9 +14,9 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from lhotse.utils import fix_random_seed
-from matcha.model import fix_len_compatibility
-from matcha.models.matcha_tts import MatchaTTS
-from matcha.tokenizer import Tokenizer
+from model import fix_len_compatibility
+from models.matcha_tts import MatchaTTS
+from tokenizer import Tokenizer
 from torch.amp import GradScaler, autocast
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
@@ -150,7 +150,7 @@ def _get_data_params() -> AttributeDict:
             "n_spks": 1,
             "n_fft": 1024,
             "n_feats": 80,
-            "sample_rate": 22050,
+            "sampling_rate": 22050,
             "hop_length": 256,
             "win_length": 1024,
             "f_min": 0,
@@ -445,11 +445,6 @@ def train_one_epoch(

     saved_bad_model = False

-    # used to track the stats over iterations in one epoch
-    tot_loss = MetricsTracker()
-
-    saved_bad_model = False
-
     def save_bad_model(suffix: str = ""):
         save_checkpoint(
             filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
@@ -493,9 +488,10 @@ def train_one_epoch(

             loss = sum(losses.values())

-            optimizer.zero_grad()
             scaler.scale(loss).backward()
             scaler.step(optimizer)
             scaler.update()
+            optimizer.zero_grad()

             loss_info = MetricsTracker()
             loss_info["samples"] = batch_size
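The reordering above gives the conventional torch.amp sequence: scale the loss, backward, step through the scaler, update the scale factor, then clear gradients. A minimal, self-contained sketch of that ordering with stand-in model and optimizer (assuming a PyTorch version that provides `torch.amp.GradScaler`, matching the import used in this file):

```python
import torch
from torch.amp import GradScaler, autocast

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(80, 80).to(device)          # stand-in for the TTS model
optimizer = torch.optim.AdamW(model.parameters())   # stand-in for ScaledAdam
scaler = GradScaler(device, enabled=(device == "cuda"))

x = torch.randn(4, 80, device=device)
with autocast(device_type=device, enabled=(device == "cuda")):
    loss = model(x).pow(2).mean()

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()  # gradients cleared after the update, as in the patched loop
```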
@@ -24,7 +24,7 @@ from pathlib import Path
 from typing import Any, Dict, Optional

 import torch
-from compute_fbank_ljspeech import MyFbank, MyFbankConfig
+from fbank import MatchaFbank, MatchaFbankConfig
 from lhotse import CutSet, load_manifest_lazy
 from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
     CutConcatenate,
@@ -32,7 +32,6 @@ from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
     DynamicBucketingSampler,
     PrecomputedFeatures,
     SimpleCutSampler,
-    SpecAugment,
     SpeechSynthesisDataset,
 )
 from lhotse.dataset.input_strategies import (  # noqa F401 For AudioSamples
@@ -177,7 +176,7 @@ class LJSpeechTtsDataModule:

         if self.args.on_the_fly_feats:
             sampling_rate = 22050
-            config = MyFbankConfig(
+            config = MatchaFbankConfig(
                 n_fft=1024,
                 n_mels=80,
                 sampling_rate=sampling_rate,
@@ -189,7 +188,7 @@ class LJSpeechTtsDataModule:
             train = SpeechSynthesisDataset(
                 return_text=False,
                 return_tokens=True,
-                feature_input_strategy=OnTheFlyFeatures(MyFbank(config)),
+                feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
                 return_cuts=self.args.return_cuts,
             )

@@ -238,7 +237,7 @@ class LJSpeechTtsDataModule:
         logging.info("About to create dev dataset")
         if self.args.on_the_fly_feats:
             sampling_rate = 22050
-            config = MyFbankConfig(
+            config = MatchaFbankConfig(
                 n_fft=1024,
                 n_mels=80,
                 sampling_rate=sampling_rate,
@@ -250,7 +249,7 @@ class LJSpeechTtsDataModule:
             validate = SpeechSynthesisDataset(
                 return_text=False,
                 return_tokens=True,
-                feature_input_strategy=OnTheFlyFeatures(MyFbank(config)),
+                feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -282,7 +281,7 @@ class LJSpeechTtsDataModule:
         logging.info("About to create test dataset")
         if self.args.on_the_fly_feats:
             sampling_rate = 22050
-            config = MyFbankConfig(
+            config = MatchaFbankConfig(
                 n_fft=1024,
                 n_mels=80,
                 sampling_rate=sampling_rate,
@@ -294,7 +293,7 @@ class LJSpeechTtsDataModule:
             test = SpeechSynthesisDataset(
                 return_text=False,
                 return_tokens=True,
-                feature_input_strategy=OnTheFlyFeatures(MyFbank(config)),
+                feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -25,26 +25,16 @@ log() {
 log "dl_dir: $dl_dir"

 if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
-  log "Stage -1: build monotonic_align lib"
-  if [ ! -d vits/monotonic_align/build ]; then
-    cd vits/monotonic_align
-    python3 setup.py build_ext --inplace
-    cd ../../
-  else
-    log "monotonic_align lib for vits already built"
-  fi
-
-  if [ ! -f ./matcha/monotonic_align/core.cpython-38-x86_64-linux-gnu.so ]; then
-    pushd matcha/monotonic_align
-    python3 setup.py build
-    mv -v build/lib.*/matcha/monotonic_align/core.*.so .
-    rm -rf build
-    rm core.c
-    ls -lh
-    popd
-  else
-    log "monotonic_align lib for matcha-tts already built"
-  fi
+  log "Stage -1: build monotonic_align lib (used by vits and matcha recipes)"
+  for recipe in vits matcha; do
+    if [ ! -d $recipe/monotonic_align/build ]; then
+      cd $recipe/monotonic_align
+      python3 setup.py build_ext --inplace
+      cd ../../
+    else
+      log "monotonic_align lib for $recipe already built"
+    fi
+  done
 fi

 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
@@ -234,7 +234,7 @@ def main():
     logging.info(f"Number of parameters in discriminator: {num_param_d}")
     logging.info(f"Total number of parameters: {num_param_g + num_param_d}")

-    # we need cut ids to display recognition results.
+    # we need cut ids to organize tts results.
     args.return_cuts = True
     ljspeech = LJSpeechTtsDataModule(args)

egs/ljspeech/TTS/vits/monotonic_align/.gitignore (vendored, new normal file, 3 lines)
@@ -0,0 +1,3 @@
build
core.c
*.so
@@ -18,7 +18,6 @@

from tokenizer import Tokenizer
from train import get_model, get_params
from vits import VITS


def test_model_type(model_type):
@@ -52,13 +52,19 @@ def get_parser():
         default=80,
         help="""The number of mel bins for Fbank""",
     )

     parser.add_argument(
         "--whisper-fbank",
         type=str2bool,
         default=False,
         help="Use WhisperFbank instead of Fbank. Default: False.",
     )
+    parser.add_argument(
+        "--speed-perturb",
+        type=str2bool,
+        default=False,
+        help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
+    )

     return parser

||||
@ -104,6 +110,9 @@ def compute_fbank_kespeech_dev_test(args):
|
||||
keep_overlapping=False, min_duration=None
|
||||
)
|
||||
|
||||
if args.speed_perturb:
|
||||
cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||
|
||||
logging.info("Computing features")
|
||||
cut_set = cut_set.compute_and_store_features_batch(
|
||||
extractor=extractor,
|
||||
|
@@ -106,6 +106,14 @@ def get_parser():
         default=False,
         help="Use WhisperFbank instead of Fbank. Default: False.",
     )
+
+    parser.add_argument(
+        "--speed-perturb",
+        type=str2bool,
+        default=False,
+        help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
+    )

     return parser

||||
@ -158,6 +166,9 @@ def compute_fbank_kespeech_splits(args):
|
||||
keep_overlapping=False, min_duration=None
|
||||
)
|
||||
|
||||
if args.speed_perturb:
|
||||
cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||
|
||||
logging.info("Computing features")
|
||||
cut_set = cut_set.compute_and_store_features_batch(
|
||||
extractor=extractor,
|
||||
|
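Both KeSpeech scripts use the same Lhotse idiom: concatenating the original cuts with 0.9x and 1.1x speed-perturbed copies triples the amount of data before feature extraction. A small standalone sketch with a hypothetical manifest path:

```python
from lhotse import CutSet

# Hypothetical manifest path; any Lhotse cut manifest works the same way.
cut_set = CutSet.from_file("data/fbank/cuts_dev.jsonl.gz")

# Keep the originals and add 0.9x / 1.1x speed-perturbed copies (3x the cuts).
cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
```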
@@ -516,9 +516,19 @@ def main():
     for idx, part in enumerate(cut_sets):
         if args.audio_extractor:
             if args.audio_extractor == "Encodec":
-                storage_path = f"{args.output_dir}/{args.prefix}_encodec_{partition}_{idx if split > 1 else ''}"
+                if split > 1:
+                    storage_path = f"{args.output_dir}/{args.prefix}_encodec_{partition}_{idx}"
+                else:
+                    storage_path = (
+                        f"{args.output_dir}/{args.prefix}_encodec_{partition}"
+                    )
             else:
-                storage_path = f"{args.output_dir}/{args.prefix}_fbank_{partition}_{idx if split > 1 else ''}"
+                if split > 1:
+                    storage_path = f"{args.output_dir}/{args.prefix}_fbank_{partition}_{idx}"
+                else:
+                    storage_path = (
+                        f"{args.output_dir}/{args.prefix}_fbank_{partition}"
+                    )

         if args.prefix.lower() in [
             "ljspeech",
@@ -587,9 +597,11 @@ def main():
             ].normalized_text, "normalized_text is None"

         # Save each part with an index if split > 1
-        cuts_filename = (
-            f"{prefix}cuts_{partition}.{idx if split > 1 else ''}.{args.suffix}"
-        )
+        if split > 1:
+            cuts_filename = f"{prefix}cuts_{partition}.{idx}.{args.suffix}"
+        else:
+            cuts_filename = f"{prefix}cuts_{partition}.{args.suffix}"

         part.to_file(f"{args.output_dir}/{cuts_filename}")
         logging.info(f"Saved {cuts_filename}")
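The motivation for both naming changes above is easier to see with concrete values: when there is a single split, the old f-strings left a dangling index separator (a trailing underscore in the storage path and a double dot in the manifest name). A quick illustration with hypothetical values:

```python
prefix = "wenetspeech4tts_"  # hypothetical values, only to show the naming
partition = "Basic"
suffix = "jsonl.gz"
idx, split = 0, 1

old_name = f"{prefix}cuts_{partition}.{idx if split > 1 else ''}.{suffix}"
print(old_name)  # wenetspeech4tts_cuts_Basic..jsonl.gz  (note the double dot)

if split > 1:
    new_name = f"{prefix}cuts_{partition}.{idx}.{suffix}"
else:
    new_name = f"{prefix}cuts_{partition}.{suffix}"
print(new_name)  # wenetspeech4tts_cuts_Basic.jsonl.gz
```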
@@ -86,7 +86,7 @@ def get_args():
     parser.add_argument(
         "--checkpoint",
         type=str,
-        default="exp/vallf_nano_full/checkpoint-100000.pt",
+        default="./valle/exp/checkpoint-100000.pt",
         help="Path to the saved checkpoint.",
     )

egs/wenetspeech4tts/TTS/valle/requirements.txt (new normal file, 2 lines)
@@ -0,0 +1,2 @@
phonemizer==3.2.1
git+https://github.com/facebookresearch/encodec.git
@@ -4,6 +4,7 @@
#                                                       Mingshuang Luo)
# Copyright    2023                           (authors: Feiteng Li)
# Copyright    2024                           (authors: Yuekai Zhang)
# Copyright    2024 Tsinghua University       (authors: Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -48,10 +49,8 @@ python3 valle/train.py --max-duration 160 --filter-min-duration 0.5 --filter-max
import argparse
import copy
import logging
import os
import random
import warnings
from contextlib import nullcontext
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union
@@ -216,7 +215,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="exp/valle_dev",
+        default="./valle/exp",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
@@ -686,9 +685,9 @@ def compute_validation_loss(
         output_dir = Path(f"{params.exp_dir}/eval/step-{params.batch_idx_train:06d}")
         output_dir.mkdir(parents=True, exist_ok=True)
         if isinstance(model, DDP):
-            model.module.visualize(predicts, batch, output_dir=output_dir)
+            model.module.visualize(predicts, batch, tokenizer, output_dir=output_dir)
         else:
-            model.visualize(predicts, batch, output_dir=output_dir)
+            model.visualize(predicts, batch, tokenizer, output_dir=output_dir)

     return tot_loss

@@ -19,8 +19,11 @@ import random
 from functools import partial
 from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union

+import matplotlib.pyplot as plt
+import numpy as np
 import torch
 import torch.nn as nn
+from tokenizer import TextTokenCollater
 from torch import Tensor
 from torch.nn import Linear, Module
 from torch.nn import functional as F
@@ -1658,6 +1661,86 @@ class VALLE(nn.Module):
         assert len(codes) == 8
         return torch.stack(codes, dim=-1)

+    def visualize(
+        self,
+        predicts: Tuple[torch.Tensor],
+        batch: Dict[str, Union[List, torch.Tensor]],
+        tokenizer: TextTokenCollater,
+        output_dir: str,
+        limit: int = 4,
+    ) -> None:
+        audio_features = batch["features"].to("cpu").detach().numpy()
+        audio_features_lens = batch["features_lens"].to("cpu").detach().numpy()
+
+        tokens = batch["tokens"]
+        text_tokens, text_tokens_lens = tokenizer(tokens)
+        assert text_tokens.ndim == 2
+
+        texts = batch["text"]
+        utt_ids = [cut.id for cut in batch["cut"]]
+
+        encoder_outputs = predicts[0].to("cpu").type(torch.float32).detach().numpy()
+        decoder_outputs = predicts[1]
+        if isinstance(decoder_outputs, list):
+            decoder_outputs = decoder_outputs[-1]
+        decoder_outputs = decoder_outputs.to("cpu").type(torch.float32).detach().numpy()
+
+        vmin, vmax = 0, 1024  # Encodec
+
+        num_figures = 3
+        for b, (utt_id, text) in enumerate(zip(utt_ids[:limit], texts[:limit])):
+            _ = plt.figure(figsize=(14, 8 * num_figures))
+
+            S = text_tokens_lens[b]
+            T = audio_features_lens[b]
+
+            # encoder
+            plt.subplot(num_figures, 1, 1)
+            plt.title(f"Text: {text}")
+            plt.imshow(
+                X=np.transpose(encoder_outputs[b]),
+                cmap=plt.get_cmap("jet"),
+                aspect="auto",
+                interpolation="nearest",
+            )
+            plt.gca().invert_yaxis()
+            plt.axvline(x=S - 0.4, linewidth=2, color="r")
+            plt.xlabel("Encoder Output")
+            plt.colorbar()
+
+            # decoder
+            plt.subplot(num_figures, 1, 2)
+            plt.imshow(
+                X=np.transpose(decoder_outputs[b]),
+                cmap=plt.get_cmap("jet"),
+                aspect="auto",
+                interpolation="nearest",
+                vmin=vmin,
+                vmax=vmax,
+            )
+            plt.gca().invert_yaxis()
+            plt.axvline(x=T - 0.4, linewidth=2, color="r")
+            plt.xlabel("Decoder Output")
+            plt.colorbar()
+
+            # target
+            plt.subplot(num_figures, 1, 3)
+            plt.imshow(
+                X=np.transpose(audio_features[b]),
+                cmap=plt.get_cmap("jet"),
+                aspect="auto",
+                interpolation="nearest",
+                vmin=vmin,
+                vmax=vmax,
+            )
+            plt.gca().invert_yaxis()
+            plt.axvline(x=T - 0.4, linewidth=2, color="r")
+            plt.xlabel("Decoder Target")
+            plt.colorbar()
+
+            plt.savefig(f"{output_dir}/{utt_id}.png")
+            plt.close()
+

 # https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
 def top_k_top_p_filtering(
@@ -974,7 +974,16 @@ def run(rank, world_size, args):
         logging.info("Using DDP")
         model = DDP(model, device_ids=[rank], find_unused_parameters=True)

-    optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0)
+    parameters_names = []
+    parameters_names.append(
+        [name_param_pair[0] for name_param_pair in model.named_parameters()]
+    )
+    optimizer = ScaledAdam(
+        model.parameters(),
+        lr=params.base_lr,
+        clipping_scale=2.0,
+        parameters_names=parameters_names,
+    )

     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
