mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00

use CMVN

commit 6a4cb112dd (parent 7077b4f99a)
egs/ljspeech/TTS/local/compute_fbank_ljspeech.py
@@ -106,6 +106,8 @@ def compute_fbank_ljspeech(num_jobs: int):
         use_fft_mag=True,
         low_freq=0,
         high_freq=8000,
+        remove_dc_offset=False,
+        preemph_coeff=0,
         # should be identical to n_feats in ../matcha/train.py
         num_filters=80,
     )
egs/ljspeech/TTS/local/compute_fbank_statistics.py (new executable file, 84 lines)
@@ -0,0 +1,84 @@
+#!/usr/bin/env python3
+# Copyright      2024  Xiaomi Corp.  (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script computes the mean and std of the fbank features.
+"""
+
+import argparse
+import json
+import logging
+from pathlib import Path
+
+import torch
+from lhotse import CutSet, load_manifest_lazy
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "manifest",
+        type=Path,
+        help="Path to the manifest file",
+    )
+
+    parser.add_argument(
+        "cmvn",
+        type=Path,
+        help="Path to the cmvn.json",
+    )
+
+    return parser.parse_args()
+
+
+def main():
+    args = get_args()
+
+    manifest = args.manifest
+    logging.info(
+        f"Computing fbank mean and std for {manifest} and saving to {args.cmvn}"
+    )
+
+    assert manifest.is_file(), f"{manifest} does not exist"
+    cut_set = load_manifest_lazy(manifest)
+    assert isinstance(cut_set, CutSet), type(cut_set)
+
+    feat_dim = cut_set[0].features.num_features
+    num_frames = 0
+    s = 0
+    sq = 0
+    for c in cut_set:
+        f = torch.from_numpy(c.load_features())
+        num_frames += f.shape[0]
+        s += f.sum()
+        sq += f.square().sum()
+
+    fbank_mean = s / (num_frames * feat_dim)
+    fbank_var = sq / (num_frames * feat_dim) - fbank_mean * fbank_mean
+    print("fbank var", fbank_var)
+    fbank_std = fbank_var.sqrt()
+    with open(args.cmvn, "w") as f:
+        json.dump({"fbank_mean": fbank_mean.item(), "fbank_std": fbank_std.item()}, f)
+        f.write("\n")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    main()
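The script accumulates a running sum and sum of squares over all frames in a single pass and uses the identity var = E[x^2] - E[x]^2. A minimal sketch checking that identity against torch's own reducers; the toy data and names below are illustrative only, not part of the commit:

# Sketch: verify the streaming mean/var identity used above on toy data.
import torch

feats = torch.randn(1000, 80)  # stand-in fbank features: (frames, bins)
n = feats.numel()

s = feats.sum()
sq = feats.square().sum()

mean = s / n
var = sq / n - mean * mean  # E[x^2] - E[x]^2

assert torch.allclose(mean, feats.mean(), atol=1e-6)
assert torch.allclose(var, feats.var(unbiased=False), atol=1e-5)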
@@ -10,6 +10,7 @@ import soundfile as sf
 import torch
 from matcha.hifigan.config import v1
 from matcha.hifigan.denoiser import Denoiser
+from tokenizer import Tokenizer
 from matcha.hifigan.models import Generator as HiFiGAN
 from matcha.text import sequence_to_text, text_to_sequence
 from matcha.utils.utils import intersperse
@@ -28,7 +29,7 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=140,
+        default=1320,
         help="""It specifies the checkpoint to use for decoding.
         Note: Epoch counts from 1.
         """,
@@ -37,13 +38,19 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=Path,
-        default="matcha/exp",
+        default="matcha/exp-fbank",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
         """,
     )

+    parser.add_argument(
+        "--tokens",
+        type=Path,
+        default="data/tokens.txt",
+    )
+
     return parser


@@ -71,19 +78,17 @@ def save_to_folder(filename: str, output: dict, folder: str):
     sf.write(folder / f"{filename}.wav", output["waveform"], 22050, "PCM_24")


-def process_text(text: str):
-    x = torch.tensor(
-        intersperse(text_to_sequence(text, ["english_cleaners2"])[0], 0),
-        dtype=torch.long,
-        device="cpu",
-    )[None]
+def process_text(text: str, tokenizer):
+    x = tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True)
+    x = torch.tensor(x, dtype=torch.long)
     x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device="cpu")
-    x_phones = sequence_to_text(x.squeeze(0).tolist())
-    return {"x_orig": text, "x": x, "x_lengths": x_lengths, "x_phones": x_phones}
+    return {"x_orig": text, "x": x, "x_lengths": x_lengths}


-def synthesise(model, n_timesteps, text, length_scale, temperature, spks=None):
-    text_processed = process_text(text)
+def synthesise(
+    model, tokenizer, n_timesteps, text, length_scale, temperature, spks=None
+):
+    text_processed = process_text(text, tokenizer)
     start_t = dt.datetime.now()
     output = model.synthesise(
         text_processed["x"],
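With the tokenizer path, process_text returns a single-utterance batch: x has shape (1, num_tokens) and x_lengths is a one-element tensor, which is what model.synthesise consumes. A hedged sketch of those shapes with a stub tokenizer (the real Tokenizer comes from tokenizer.py; the IDs below are made up):

# Illustrative stub mirroring the texts_to_token_ids call used above.
import torch

class FakeTokenizer:
    def texts_to_token_ids(self, texts, add_sos=True, add_eos=True):
        return [[1, 42, 7, 13, 2] for _ in texts]  # dummy token IDs

tokenizer = FakeTokenizer()
x = torch.tensor(tokenizer.texts_to_token_ids(["hello"]), dtype=torch.long)
x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long)

print(x.shape)    # torch.Size([1, 5]) -- (batch, tokens)
print(x_lengths)  # tensor([5])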
@@ -108,6 +113,11 @@ def main():
     params.update(vars(args))
     logging.info(params)

+    tokenizer = Tokenizer(params.tokens)
+    params.blank_id = tokenizer.pad_id
+    params.vocab_size = tokenizer.vocab_size
+    params.model_args.n_vocab = params.vocab_size
+
     logging.info("About to create model")
     model = get_model(params)
     load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
@@ -117,12 +127,13 @@ def main():
     denoiser = Denoiser(vocoder, mode="zeros")

     texts = [
-        "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
-        "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.",
+        "How are you doing, my friend",
+        # "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
+        # "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.",
     ]

     # Number of ODE Solver steps
-    n_timesteps = 2
+    n_timesteps = 3

     # Changes to the speaking rate
     length_scale = 1.0
@@ -135,6 +146,7 @@ def main():
     for i, text in enumerate(tqdm(texts)):
         output = synthesise(
             model=model,
+            tokenizer=tokenizer,
             n_timesteps=n_timesteps,
             text=text,
             length_scale=length_scale,
@@ -154,7 +166,7 @@ def main():
         print(f"{'*' * 53}")
         print(f"Phonetised text - {i}")
         print(f"{'-' * 53}")
-        print(output["x_phones"])
+        print(output["x"])
         print(f"{'*' * 53}")
         print(f"RTF:\t\t{output['rtf']:.6f}")
         print(f"RTF Waveform:\t{rtf_w:.6f}")
@@ -162,7 +174,7 @@ def main():
         rtfs_w.append(rtf_w)

         # Save the generated waveform
-        save_to_folder(i, output, folder="./my-output")
+        save_to_folder(i, output, folder="./my-output-1320")

     print(f"Number of ODE steps: {n_timesteps}")
     print(f"Mean RTF:\t\t\t\t{np.mean(rtfs):.6f} ± {np.std(rtfs):.6f}")
@@ -13,6 +13,7 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from lhotse.utils import fix_random_seed
+from matcha.data.text_mel_datamodule import TextMelDataModule
 from matcha.models.matcha_tts import MatchaTTS
 from matcha.tokenizer import Tokenizer
 from matcha.utils.model import fix_len_compatibility
@@ -122,8 +123,11 @@ def get_parser():
 def get_data_statistics():
     return AttributeDict(
         {
-            "mel_mean": 0.0,
-            "mel_std": 1.0,
+            # "mel_mean": -5.517028331756592,  # matcha-tts
+            # "mel_std": 2.0643954277038574,
+            # ours
+            "mel_mean": -1.168782114982605,
+            "mel_std": 1.9283572435379028,
         }
     )
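These hard-coded values are the output of local/compute_fbank_statistics.py for this setup. A hedged sketch of how they could be read from cmvn.json instead of being pasted in; the helper below is hypothetical, not part of the commit:

# Hypothetical helper: load the stats written by
# local/compute_fbank_statistics.py instead of hard-coding them.
import json
from pathlib import Path

def load_data_statistics(cmvn_path: str = "data/fbank/cmvn.json") -> dict:
    stats = json.loads(Path(cmvn_path).read_text())
    # cmvn.json stores {"fbank_mean": ..., "fbank_std": ...}
    return {
        "mel_mean": stats["fbank_mean"],
        "mel_std": stats["fbank_std"],
    }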
@@ -134,7 +138,8 @@ def _get_data_params() -> AttributeDict:
             "name": "ljspeech",
             "train_filelist_path": "./filelists/ljs_audio_text_train_filelist.txt",
             "valid_filelist_path": "./filelists/ljs_audio_text_val_filelist.txt",
-            "num_workers": 3,
+            "batch_size": 64,
+            "num_workers": 1,
             "pin_memory": False,
             "cleaners": ["english_cleaners2"],
             "add_blank": True,
@@ -289,8 +294,17 @@ def load_checkpoint_if_available(
     return saved_params


-def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
+def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device, params):
     """Parse batch data"""
+    mel_mean = params.data_args.data_statistics.mel_mean
+    mel_std_inv = 1 / params.data_args.data_statistics.mel_std
+    for i in range(batch["features"].shape[0]):
+        n = batch["features_lens"][i]
+        batch["features"][i : i + 1, :n, :] = (
+            batch["features"][i : i + 1, :n, :] - mel_mean
+        ) * mel_std_inv
+        batch["features"][i : i + 1, n:, :] = 0
+
     audio = batch["audio"].to(device)
     features = batch["features"].to(device)
     audio_lens = batch["audio_lens"].to(device)
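This is where the CMVN is actually applied: each utterance's valid frames are shifted by mel_mean and scaled by 1/mel_std, and the padding frames are zeroed. A minimal standalone sketch of the same normalization; the shapes are illustrative, the stats are the ones hard-coded above:

# Standalone sketch of the per-batch CMVN applied in prepare_input.
import torch

mel_mean = -1.168782114982605
mel_std_inv = 1 / 1.9283572435379028

features = torch.randn(2, 100, 80)       # (batch, frames, bins)
features_lens = torch.tensor([100, 73])  # valid frames per utterance

for i in range(features.shape[0]):
    n = int(features_lens[i])
    features[i, :n] = (features[i, :n] - mel_mean) * mel_std_inv
    features[i, n:] = 0.0  # zero out padding frames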
@@ -298,7 +312,7 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
     tokens = batch["tokens"]

     tokens = tokenizer.tokens_to_token_ids(
-        tokens, intersperse_blank=True, add_sos=True, add_eos=True
+        tokens, intersperse_blank=True, add_sos=False, add_eos=False
     )
     tokens = k2.RaggedTensor(tokens)
     row_splits = tokens.shape.row_splits(1)
@@ -315,7 +329,7 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):

     # features_lens[features_lens.argmax()] += pad

-    return audio, audio_lens, features, features_lens, tokens, tokens_lens
+    return audio, audio_lens, features, features_lens.long(), tokens, tokens_lens.long()


 def compute_validation_loss(
@@ -336,28 +350,36 @@

     with torch.no_grad():
         for batch_idx, batch in enumerate(valid_dl):
-            (
-                audio,
-                audio_lens,
-                features,
-                features_lens,
-                tokens,
-                tokens_lens,
-            ) = prepare_input(batch, tokenizer, device)
-
-            losses = get_losses(
-                {
-                    "x": tokens,
-                    "x_lengths": tokens_lens,
-                    "y": features.permute(0, 2, 1),
-                    "y_lengths": features_lens,
-                    "spks": None,  # should change it for multi-speakers
-                    "durations": None,
-                }
-            )
-
-            batch_size = len(batch["tokens"])
+            if "tokens" in batch:
+                (
+                    audio,
+                    audio_lens,
+                    features,
+                    features_lens,
+                    tokens,
+                    tokens_lens,
+                ) = prepare_input(batch, tokenizer, device, params)
+
+                losses = get_losses(
+                    {
+                        "x": tokens,
+                        "x_lengths": tokens_lens,
+                        "y": features.permute(0, 2, 1),
+                        "y_lengths": features_lens,
+                        "spks": None,  # should change it for multi-speakers
+                        "durations": None,
+                    }
+                )
+
+                batch_size = len(batch["tokens"])
+            else:
+                batch_size = batch["x"].shape[0]
+                batch["x"] = batch["x"].to(device)
+                batch["x_lengths"] = batch["x_lengths"].to(device)
+                batch["y"] = batch["y"].to(device)
+                batch["y_lengths"] = batch["y_lengths"].to(device)
+                losses = get_losses(batch)

             loss_info = MetricsTracker()
             loss_info["samples"] = batch_size
@@ -451,24 +473,38 @@ def train_one_epoch(
         # features_lens, (N,), int32
         # tokens: List[List[str]], len(tokens) == N

-        batch_size = len(batch["tokens"])
-
-        audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
-            batch, tokenizer, device
-        )
+        if "tokens" in batch:
+            batch_size = len(batch["tokens"])
+
+            (
+                audio,
+                audio_lens,
+                features,
+                features_lens,
+                tokens,
+                tokens_lens,
+            ) = prepare_input(batch, tokenizer, device, params)
+        else:
+            batch_size = batch["x"].shape[0]
         try:
             with autocast(enabled=params.use_fp16):
-                losses = get_losses(
-                    {
-                        "x": tokens,
-                        "x_lengths": tokens_lens,
-                        "y": features.permute(0, 2, 1),
-                        "y_lengths": features_lens,
-                        "spks": None,  # should change it for multi-speakers
-                        "durations": None,
-                    }
-                )
+                if "tokens" in batch:
+                    losses = get_losses(
+                        {
+                            "x": tokens,
+                            "x_lengths": tokens_lens,
+                            "y": features.permute(0, 2, 1),
+                            "y_lengths": features_lens,
+                            "spks": None,  # should change it for multi-speakers
+                            "durations": None,
+                        }
+                    )
+                else:
+                    batch["x"] = batch["x"].to(device)
+                    batch["x_lengths"] = batch["x_lengths"].to(device)
+                    batch["y"] = batch["y"].to(device)
+                    batch["y_lengths"] = batch["y_lengths"].to(device)
+                    losses = get_losses(batch)

             loss = sum(losses.values())
@@ -586,6 +622,7 @@ def run(rank, world_size, args):
     params.blank_id = tokenizer.pad_id
     params.vocab_size = tokenizer.vocab_size
     params.model_args.n_vocab = params.vocab_size
+    params.model_args.n_vocab = 178

     logging.info(params)
     print(params)
@@ -595,7 +632,6 @@ def run(rank, world_size, args):

     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of parameters: {num_param}")
-    print(f"Number of parameters: {num_param}")

     assert params.start_epoch > 0, params.start_epoch
     checkpoints = load_checkpoint_if_available(params=params, model=model)
@@ -609,13 +645,21 @@ def run(rank, world_size, args):
     optimizer = torch.optim.Adam(model.parameters(), **params.model_args.optimizer)

     logging.info("About to create datamodule")
-    ljspeech = LJSpeechTtsDataModule(args)
-
-    train_cuts = ljspeech.train_cuts()
-    train_dl = ljspeech.train_dataloaders(train_cuts)
+    if False:
+        params.data_args.tokenizer = tokenizer
+        data_module = TextMelDataModule(hparams=params.data_args)
+        del params.data_args.tokenizer
+        train_dl = data_module.train_dataloader()
+        valid_dl = data_module.val_dataloader()
+    else:
+        ljspeech = LJSpeechTtsDataModule(args)

-    valid_cuts = ljspeech.valid_cuts()
-    valid_dl = ljspeech.valid_dataloaders(valid_cuts)
+        train_cuts = ljspeech.train_cuts()
+        train_dl = ljspeech.train_dataloaders(train_cuts)
+
+        valid_cuts = ljspeech.valid_cuts()
+        valid_dl = ljspeech.valid_dataloaders(valid_cuts)

     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
@@ -625,7 +669,8 @@ def run(rank, world_size, args):
     for epoch in range(params.start_epoch, params.num_epochs + 1):
         logging.info(f"Start epoch {epoch}")
         fix_random_seed(params.seed + epoch - 1)
-        train_dl.sampler.set_epoch(epoch - 1)
+        if "sampler" in train_dl:
+            train_dl.sampler.set_epoch(epoch - 1)

         params.cur_epoch = epoch
@@ -181,6 +181,8 @@ class LJSpeechTtsDataModule:
             frame_length=1024 / sampling_rate,  # (in second),
             frame_shift=256 / sampling_rate,  # (in second)
             use_fft_mag=True,
+            remove_dc_offset=False,
+            preemph_coeff=0,
             low_freq=0,
             high_freq=8000,
             # should be identical to n_feats in ./train.py
@@ -242,6 +244,8 @@ class LJSpeechTtsDataModule:
             frame_length=1024 / sampling_rate,  # (in second),
             frame_shift=256 / sampling_rate,  # (in second)
             use_fft_mag=True,
+            remove_dc_offset=False,
+            preemph_coeff=0,
             low_freq=0,
             high_freq=8000,
             # should be identical to n_feats in ./train.py
@@ -286,6 +290,8 @@ class LJSpeechTtsDataModule:
             frame_length=1024 / sampling_rate,  # (in second),
             frame_shift=256 / sampling_rate,  # (in second)
            use_fft_mag=True,
+            remove_dc_offset=False,
+            preemph_coeff=0,
             low_freq=0,
             high_freq=8000,
             # should be identical to n_feats in ./train.py
egs/ljspeech/TTS/prepare.sh
@@ -191,3 +191,10 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
     touch data/fbank/.ljspeech_split.done
   fi
 fi
+
+if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
+  log "Stage 9: Compute fbank mean and std (used by ./matcha)"
+  if [ ! -f ./data/fbank/cmvn.json ]; then
+    ./local/compute_fbank_statistics.py ./data/fbank/ljspeech_cuts_train.jsonl.gz ./data/fbank/cmvn.json
+  fi
+fi
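After stage 9 runs, cmvn.json holds a single global mean/std pair for the whole training set. A quick hedged check of the output, with the path and keys as used by the script above:

# Inspect the stage-9 output; path as used in prepare.sh.
import json

with open("data/fbank/cmvn.json") as f:
    stats = json.load(f)

# Keys written by local/compute_fbank_statistics.py
print(stats["fbank_mean"], stats["fbank_std"])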