This commit is contained in:
Fangjun Kuang 2024-10-20 10:14:10 +08:00
parent 7077b4f99a
commit 6a4cb112dd
6 changed files with 220 additions and 64 deletions

View File

@ -106,6 +106,8 @@ def compute_fbank_ljspeech(num_jobs: int):
use_fft_mag=True,
low_freq=0,
high_freq=8000,
remove_dc_offset=False,
preemph_coeff=0,
# should be identical to n_feats in ../matcha/train.py
num_filters=80,
)

View File

@ -0,0 +1,84 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script compute the mean and std of the fbank features.
"""
import argparse
import json
import logging
from pathlib import Path
import torch
from lhotse import CutSet, load_manifest_lazy
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"manifest",
type=Path,
help="Path to the manifest file",
)
parser.add_argument(
"cmvn",
type=Path,
help="Path to the cmvn.json",
)
return parser.parse_args()
def main():
args = get_args()
manifest = args.manifest
logging.info(
f"Computing fbank mean and std for {manifest} and saving to {args.cmvn}"
)
assert manifest.is_file(), f"{manifest} does not exist"
cut_set = load_manifest_lazy(manifest)
assert isinstance(cut_set, CutSet), type(cut_set)
feat_dim = cut_set[0].features.num_features
num_frames = 0
s = 0
sq = 0
for c in cut_set:
f = torch.from_numpy(c.load_features())
num_frames += f.shape[0]
s += f.sum()
sq += f.square().sum()
fbank_mean = s / (num_frames * feat_dim)
fbank_var = sq / (num_frames * feat_dim) - fbank_mean * fbank_mean
print("fbank var", fbank_var)
fbank_std = fbank_var.sqrt()
with open(args.cmvn, "w") as f:
json.dump({"fbank_mean": fbank_mean.item(), "fbank_std": fbank_std.item()}, f)
f.write("\n")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -10,6 +10,7 @@ import soundfile as sf
import torch
from matcha.hifigan.config import v1
from matcha.hifigan.denoiser import Denoiser
from tokenizer import Tokenizer
from matcha.hifigan.models import Generator as HiFiGAN
from matcha.text import sequence_to_text, text_to_sequence
from matcha.utils.utils import intersperse
@ -28,7 +29,7 @@ def get_parser():
parser.add_argument(
"--epoch",
type=int,
default=140,
default=1320,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
""",
@ -37,13 +38,19 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=Path,
default="matcha/exp",
default="matcha/exp-fbank",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--tokens",
type=Path,
default="data/tokens.txt",
)
return parser
@ -71,19 +78,17 @@ def save_to_folder(filename: str, output: dict, folder: str):
sf.write(folder / f"{filename}.wav", output["waveform"], 22050, "PCM_24")
def process_text(text: str):
x = torch.tensor(
intersperse(text_to_sequence(text, ["english_cleaners2"])[0], 0),
dtype=torch.long,
device="cpu",
)[None]
def process_text(text: str, tokenizer):
x = tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True)
x = torch.tensor(x, dtype=torch.long)
x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device="cpu")
x_phones = sequence_to_text(x.squeeze(0).tolist())
return {"x_orig": text, "x": x, "x_lengths": x_lengths, "x_phones": x_phones}
return {"x_orig": text, "x": x, "x_lengths": x_lengths}
def synthesise(model, n_timesteps, text, length_scale, temperature, spks=None):
text_processed = process_text(text)
def synthesise(
model, tokenizer, n_timesteps, text, length_scale, temperature, spks=None
):
text_processed = process_text(text, tokenizer)
start_t = dt.datetime.now()
output = model.synthesise(
text_processed["x"],
@ -108,6 +113,11 @@ def main():
params.update(vars(args))
logging.info(params)
tokenizer = Tokenizer(params.tokens)
params.blank_id = tokenizer.pad_id
params.vocab_size = tokenizer.vocab_size
params.model_args.n_vocab = params.vocab_size
logging.info("About to create model")
model = get_model(params)
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
@ -117,12 +127,13 @@ def main():
denoiser = Denoiser(vocoder, mode="zeros")
texts = [
"The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
"Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.",
"How are you doing, my friend",
# "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
# "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.",
]
# Number of ODE Solver steps
n_timesteps = 2
n_timesteps = 3
# Changes to the speaking rate
length_scale = 1.0
@ -135,6 +146,7 @@ def main():
for i, text in enumerate(tqdm(texts)):
output = synthesise(
model=model,
tokenizer=tokenizer,
n_timesteps=n_timesteps,
text=text,
length_scale=length_scale,
@ -154,7 +166,7 @@ def main():
print(f"{'*' * 53}")
print(f"Phonetised text - {i}")
print(f"{'-' * 53}")
print(output["x_phones"])
print(output["x"])
print(f"{'*' * 53}")
print(f"RTF:\t\t{output['rtf']:.6f}")
print(f"RTF Waveform:\t{rtf_w:.6f}")
@ -162,7 +174,7 @@ def main():
rtfs_w.append(rtf_w)
# Save the generated waveform
save_to_folder(i, output, folder="./my-output")
save_to_folder(i, output, folder="./my-output-1320")
print(f"Number of ODE steps: {n_timesteps}")
print(f"Mean RTF:\t\t\t\t{np.mean(rtfs):.6f} ± {np.std(rtfs):.6f}")

View File

@ -13,6 +13,7 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
from lhotse.utils import fix_random_seed
from matcha.data.text_mel_datamodule import TextMelDataModule
from matcha.models.matcha_tts import MatchaTTS
from matcha.tokenizer import Tokenizer
from matcha.utils.model import fix_len_compatibility
@ -122,8 +123,11 @@ def get_parser():
def get_data_statistics():
return AttributeDict(
{
"mel_mean": 0.0,
"mel_std": 1.0,
# "mel_mean": -5.517028331756592, # matcha-tts
# "mel_std": 2.0643954277038574,
# ours
"mel_mean": -1.168782114982605,
"mel_std": 1.9283572435379028,
}
)
@ -134,7 +138,8 @@ def _get_data_params() -> AttributeDict:
"name": "ljspeech",
"train_filelist_path": "./filelists/ljs_audio_text_train_filelist.txt",
"valid_filelist_path": "./filelists/ljs_audio_text_val_filelist.txt",
"num_workers": 3,
"batch_size": 64,
"num_workers": 1,
"pin_memory": False,
"cleaners": ["english_cleaners2"],
"add_blank": True,
@ -289,8 +294,17 @@ def load_checkpoint_if_available(
return saved_params
def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device, params):
"""Parse batch data"""
mel_mean = params.data_args.data_statistics.mel_mean
mel_std_inv = 1 / params.data_args.data_statistics.mel_std
for i in range(batch["features"].shape[0]):
n = batch["features_lens"][i]
batch["features"][i : i + 1, :n, :] = (
batch["features"][i : i + 1, :n, :] - mel_mean
) * mel_std_inv
batch["features"][i : i + 1, n:, :] = 0
audio = batch["audio"].to(device)
features = batch["features"].to(device)
audio_lens = batch["audio_lens"].to(device)
@ -298,7 +312,7 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
tokens = batch["tokens"]
tokens = tokenizer.tokens_to_token_ids(
tokens, intersperse_blank=True, add_sos=True, add_eos=True
tokens, intersperse_blank=True, add_sos=False, add_eos=False
)
tokens = k2.RaggedTensor(tokens)
row_splits = tokens.shape.row_splits(1)
@ -315,7 +329,7 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
# features_lens[features_lens.argmax()] += pad
return audio, audio_lens, features, features_lens, tokens, tokens_lens
return audio, audio_lens, features, features_lens.long(), tokens, tokens_lens.long()
def compute_validation_loss(
@ -336,28 +350,36 @@ def compute_validation_loss(
with torch.no_grad():
for batch_idx, batch in enumerate(valid_dl):
if "tokens" in batch:
(
audio,
audio_lens,
features,
features_lens,
tokens,
tokens_lens,
) = prepare_input(batch, tokenizer, device)
(
audio,
audio_lens,
features,
features_lens,
tokens,
tokens_lens,
) = prepare_input(batch, tokenizer, device, params)
losses = get_losses(
{
"x": tokens,
"x_lengths": tokens_lens,
"y": features.permute(0, 2, 1),
"y_lengths": features_lens,
"spks": None, # should change it for multi-speakers
"durations": None,
}
)
losses = get_losses(
{
"x": tokens,
"x_lengths": tokens_lens,
"y": features.permute(0, 2, 1),
"y_lengths": features_lens,
"spks": None, # should change it for multi-speakers
"durations": None,
}
)
batch_size = len(batch["tokens"])
batch_size = len(batch["tokens"])
else:
batch_size = batch["x"].shape[0]
batch["x"] = batch["x"].to(device)
batch["x_lengths"] = batch["x_lengths"].to(device)
batch["y"] = batch["y"].to(device)
batch["y_lengths"] = batch["y_lengths"].to(device)
losses = get_losses(batch)
loss_info = MetricsTracker()
loss_info["samples"] = batch_size
@ -451,24 +473,38 @@ def train_one_epoch(
# features_lens, (N,), int32
# tokens: List[List[str]], len(tokens) == N
batch_size = len(batch["tokens"])
audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
batch, tokenizer, device
)
if "tokens" in batch:
batch_size = len(batch["tokens"])
(
audio,
audio_lens,
features,
features_lens,
tokens,
tokens_lens,
) = prepare_input(batch, tokenizer, device, params)
else:
batch_size = batch["x"].shape[0]
try:
with autocast(enabled=params.use_fp16):
losses = get_losses(
{
"x": tokens,
"x_lengths": tokens_lens,
"y": features.permute(0, 2, 1),
"y_lengths": features_lens,
"spks": None, # should change it for multi-speakers
"durations": None,
}
)
if "tokens" in batch:
losses = get_losses(
{
"x": tokens,
"x_lengths": tokens_lens,
"y": features.permute(0, 2, 1),
"y_lengths": features_lens,
"spks": None, # should change it for multi-speakers
"durations": None,
}
)
else:
batch["x"] = batch["x"].to(device)
batch["x_lengths"] = batch["x_lengths"].to(device)
batch["y"] = batch["y"].to(device)
batch["y_lengths"] = batch["y_lengths"].to(device)
losses = get_losses(batch)
loss = sum(losses.values())
@ -586,6 +622,7 @@ def run(rank, world_size, args):
params.blank_id = tokenizer.pad_id
params.vocab_size = tokenizer.vocab_size
params.model_args.n_vocab = params.vocab_size
params.model_args.n_vocab = 178
logging.info(params)
print(params)
@ -595,7 +632,6 @@ def run(rank, world_size, args):
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of parameters: {num_param}")
print(f"Number of parameters: {num_param}")
assert params.start_epoch > 0, params.start_epoch
checkpoints = load_checkpoint_if_available(params=params, model=model)
@ -609,13 +645,21 @@ def run(rank, world_size, args):
optimizer = torch.optim.Adam(model.parameters(), **params.model_args.optimizer)
logging.info("About to create datamodule")
ljspeech = LJSpeechTtsDataModule(args)
train_cuts = ljspeech.train_cuts()
train_dl = ljspeech.train_dataloaders(train_cuts)
if False:
params.data_args.tokenizer = tokenizer
data_module = TextMelDataModule(hparams=params.data_args)
del params.data_args.tokenizer
train_dl = data_module.train_dataloader()
valid_dl = data_module.val_dataloader()
else:
ljspeech = LJSpeechTtsDataModule(args)
valid_cuts = ljspeech.valid_cuts()
valid_dl = ljspeech.valid_dataloaders(valid_cuts)
train_cuts = ljspeech.train_cuts()
train_dl = ljspeech.train_dataloaders(train_cuts)
valid_cuts = ljspeech.valid_cuts()
valid_dl = ljspeech.valid_dataloaders(valid_cuts)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
@ -625,7 +669,8 @@ def run(rank, world_size, args):
for epoch in range(params.start_epoch, params.num_epochs + 1):
logging.info(f"Start epoch {epoch}")
fix_random_seed(params.seed + epoch - 1)
train_dl.sampler.set_epoch(epoch - 1)
if "sampler" in train_dl:
train_dl.sampler.set_epoch(epoch - 1)
params.cur_epoch = epoch

View File

@ -181,6 +181,8 @@ class LJSpeechTtsDataModule:
frame_length=1024 / sampling_rate, # (in second),
frame_shift=256 / sampling_rate, # (in second)
use_fft_mag=True,
remove_dc_offset=False,
preemph_coeff=0,
low_freq=0,
high_freq=8000,
# should be identical to n_feats in ./train.py
@ -242,6 +244,8 @@ class LJSpeechTtsDataModule:
frame_length=1024 / sampling_rate, # (in second),
frame_shift=256 / sampling_rate, # (in second)
use_fft_mag=True,
remove_dc_offset=False,
preemph_coeff=0,
low_freq=0,
high_freq=8000,
# should be identical to n_feats in ./train.py
@ -286,6 +290,8 @@ class LJSpeechTtsDataModule:
frame_length=1024 / sampling_rate, # (in second),
frame_shift=256 / sampling_rate, # (in second)
use_fft_mag=True,
remove_dc_offset=False,
preemph_coeff=0,
low_freq=0,
high_freq=8000,
# should be identical to n_feats in ./train.py

View File

@ -191,3 +191,10 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
touch data/fbank/.ljspeech_split.done
fi
fi
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
log "Stage 9: Compute fbank mean and std (used by ./matcha)"
if [ ! -f ./data/fbank/cmvn.json ]; then
./local/compute_fbank_statistics.py ./data/fbank/ljspeech_cuts_train.jsonl.gz ./data/fbank/cmvn.json
fi
fi