mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00

use CMVN

This commit is contained in:
parent 7077b4f99a
commit 6a4cb112dd
@@ -106,6 +106,8 @@ def compute_fbank_ljspeech(num_jobs: int):
         use_fft_mag=True,
         low_freq=0,
         high_freq=8000,
+        remove_dc_offset=False,
+        preemph_coeff=0,
         # should be identical to n_feats in ../matcha/train.py
         num_filters=80,
     )
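The two new options turn off Kaldi-style waveform preprocessing so that frames reach the FFT unmodified, presumably to match the plain-STFT mel pipeline used by Matcha-TTS. A minimal sketch of what is being disabled (illustrative only, not code from the recipe; coeff=0.97 is the usual Kaldi default):

import torch

def kaldi_style_preprocess(frame: torch.Tensor, coeff: float = 0.97) -> torch.Tensor:
    # remove_dc_offset: subtract the per-frame mean before the FFT
    frame = frame - frame.mean()
    # preemph_coeff: y[t] = x[t] - coeff * x[t-1]
    out = frame.clone()
    out[1:] = frame[1:] - coeff * frame[:-1]
    return out

# With remove_dc_offset=False and preemph_coeff=0, both steps become no-ops.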
egs/ljspeech/TTS/local/compute_fbank_statistics.py (new executable file, 84 lines)
@@ -0,0 +1,84 @@
+#!/usr/bin/env python3
+# Copyright    2024  Xiaomi Corp.  (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script computes the mean and std of the fbank features.
+"""
+
+import argparse
+import json
+import logging
+from pathlib import Path
+
+import torch
+from lhotse import CutSet, load_manifest_lazy
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "manifest",
+        type=Path,
+        help="Path to the manifest file",
+    )
+
+    parser.add_argument(
+        "cmvn",
+        type=Path,
+        help="Path to the cmvn.json",
+    )
+
+    return parser.parse_args()
+
+
+def main():
+    args = get_args()
+
+    manifest = args.manifest
+    logging.info(
+        f"Computing fbank mean and std for {manifest} and saving to {args.cmvn}"
+    )
+
+    assert manifest.is_file(), f"{manifest} does not exist"
+    cut_set = load_manifest_lazy(manifest)
+    assert isinstance(cut_set, CutSet), type(cut_set)
+
+    feat_dim = cut_set[0].features.num_features
+    num_frames = 0
+    s = 0
+    sq = 0
+    for c in cut_set:
+        f = torch.from_numpy(c.load_features())
+        num_frames += f.shape[0]
+        s += f.sum()
+        sq += f.square().sum()
+
+    fbank_mean = s / (num_frames * feat_dim)
+    fbank_var = sq / (num_frames * feat_dim) - fbank_mean * fbank_mean
+    print("fbank var", fbank_var)
+    fbank_std = fbank_var.sqrt()
+    with open(args.cmvn, "w") as f:
+        json.dump({"fbank_mean": fbank_mean.item(), "fbank_std": fbank_std.item()}, f)
+        f.write("\n")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    main()
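Since the script keeps only a running sum and sum of squares over every frame and mel bin, the resulting mean and std are global scalars. A self-contained sanity check (not part of the recipe) that the streaming formula above matches a direct computation over the stacked frames:

import torch

feats = [torch.randn(n, 80) for n in (13, 7, 29)]  # stand-ins for per-cut fbank matrices

num_frames, s, sq = 0, 0.0, 0.0
for f in feats:
    num_frames += f.shape[0]
    s += f.sum()
    sq += f.square().sum()

mean = s / (num_frames * 80)
std = (sq / (num_frames * 80) - mean * mean).sqrt()

all_frames = torch.cat(feats)
assert torch.isclose(mean, all_frames.mean(), atol=1e-5)
assert torch.isclose(std, all_frames.std(unbiased=False), atol=1e-4)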
@@ -10,6 +10,7 @@ import soundfile as sf
 import torch
 from matcha.hifigan.config import v1
 from matcha.hifigan.denoiser import Denoiser
+from tokenizer import Tokenizer
 from matcha.hifigan.models import Generator as HiFiGAN
 from matcha.text import sequence_to_text, text_to_sequence
 from matcha.utils.utils import intersperse
@@ -28,7 +29,7 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=140,
+        default=1320,
         help="""It specifies the checkpoint to use for decoding.
         Note: Epoch counts from 1.
         """,
@@ -37,13 +38,19 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=Path,
-        default="matcha/exp",
+        default="matcha/exp-fbank",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
         """,
     )
 
+    parser.add_argument(
+        "--tokens",
+        type=Path,
+        default="data/tokens.txt",
+    )
+
     return parser
 
 
@@ -71,19 +78,17 @@ def save_to_folder(filename: str, output: dict, folder: str):
     sf.write(folder / f"{filename}.wav", output["waveform"], 22050, "PCM_24")
 
 
-def process_text(text: str):
-    x = torch.tensor(
-        intersperse(text_to_sequence(text, ["english_cleaners2"])[0], 0),
-        dtype=torch.long,
-        device="cpu",
-    )[None]
+def process_text(text: str, tokenizer):
+    x = tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True)
+    x = torch.tensor(x, dtype=torch.long)
     x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device="cpu")
-    x_phones = sequence_to_text(x.squeeze(0).tolist())
-    return {"x_orig": text, "x": x, "x_lengths": x_lengths, "x_phones": x_phones}
+    return {"x_orig": text, "x": x, "x_lengths": x_lengths}
 
 
-def synthesise(model, n_timesteps, text, length_scale, temperature, spks=None):
-    text_processed = process_text(text)
+def synthesise(
+    model, tokenizer, n_timesteps, text, length_scale, temperature, spks=None
+):
+    text_processed = process_text(text, tokenizer)
     start_t = dt.datetime.now()
     output = model.synthesise(
         text_processed["x"],
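process_text() now delegates phonemisation to the recipe's Tokenizer (from ./tokenizer.py) instead of text_to_sequence + intersperse, and drops the x_phones field. A hedged sketch of the shapes it returns, using a hypothetical stand-in tokenizer:

import torch

class FakeTokenizer:
    def texts_to_token_ids(self, texts, add_sos=True, add_eos=True):
        return [[1, 5, 0, 9, 0, 7, 2] for _ in texts]  # one id list per input text

x = FakeTokenizer().texts_to_token_ids(["hi"], add_sos=True, add_eos=True)
x = torch.tensor(x, dtype=torch.long)  # shape (1, T): a batch of one utterance
x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device="cpu")
print(x.shape, x_lengths)  # torch.Size([1, 7]) tensor([7])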
@@ -108,6 +113,11 @@ def main():
     params.update(vars(args))
     logging.info(params)
 
+    tokenizer = Tokenizer(params.tokens)
+    params.blank_id = tokenizer.pad_id
+    params.vocab_size = tokenizer.vocab_size
+    params.model_args.n_vocab = params.vocab_size
+
     logging.info("About to create model")
     model = get_model(params)
     load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
@@ -117,12 +127,13 @@ def main():
     denoiser = Denoiser(vocoder, mode="zeros")
 
     texts = [
-        "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
-        "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.",
+        "How are you doing, my friend",
+        # "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
+        # "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.",
     ]
 
     # Number of ODE Solver steps
-    n_timesteps = 2
+    n_timesteps = 3
 
     # Changes to the speaking rate
     length_scale = 1.0
@@ -135,6 +146,7 @@ def main():
     for i, text in enumerate(tqdm(texts)):
         output = synthesise(
             model=model,
+            tokenizer=tokenizer,
             n_timesteps=n_timesteps,
             text=text,
             length_scale=length_scale,
@@ -154,7 +166,7 @@ def main():
         print(f"{'*' * 53}")
         print(f"Phonetised text - {i}")
         print(f"{'-' * 53}")
-        print(output["x_phones"])
+        print(output["x"])
         print(f"{'*' * 53}")
         print(f"RTF:\t\t{output['rtf']:.6f}")
         print(f"RTF Waveform:\t{rtf_w:.6f}")
@@ -162,7 +174,7 @@ def main():
         rtfs_w.append(rtf_w)
 
         # Save the generated waveform
-        save_to_folder(i, output, folder="./my-output")
+        save_to_folder(i, output, folder="./my-output-1320")
 
     print(f"Number of ODE steps: {n_timesteps}")
     print(f"Mean RTF:\t\t\t\t{np.mean(rtfs):.6f} ± {np.std(rtfs):.6f}")
@@ -13,6 +13,7 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from lhotse.utils import fix_random_seed
+from matcha.data.text_mel_datamodule import TextMelDataModule
 from matcha.models.matcha_tts import MatchaTTS
 from matcha.tokenizer import Tokenizer
 from matcha.utils.model import fix_len_compatibility
@@ -122,8 +123,11 @@ def get_parser():
 def get_data_statistics():
     return AttributeDict(
         {
-            "mel_mean": 0.0,
-            "mel_std": 1.0,
+            # "mel_mean": -5.517028331756592,  # matcha-tts
+            # "mel_std": 2.0643954277038574,
+            # ours
+            "mel_mean": -1.168782114982605,
+            "mel_std": 1.9283572435379028,
         }
     )
 
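These hard-coded numbers presumably come from the cmvn.json written by the new Stage 9 in prepare.sh (the JSON keys there are fbank_mean/fbank_std). A hedged helper that would read them instead of hard-coding, assuming that file layout (load_data_statistics is hypothetical, not part of the recipe):

import json

def load_data_statistics(path: str = "data/fbank/cmvn.json") -> dict:
    with open(path) as f:
        stats = json.load(f)
    return {"mel_mean": stats["fbank_mean"], "mel_std": stats["fbank_std"]}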
@@ -134,7 +138,8 @@ def _get_data_params() -> AttributeDict:
             "name": "ljspeech",
             "train_filelist_path": "./filelists/ljs_audio_text_train_filelist.txt",
             "valid_filelist_path": "./filelists/ljs_audio_text_val_filelist.txt",
-            "num_workers": 3,
+            "batch_size": 64,
+            "num_workers": 1,
             "pin_memory": False,
             "cleaners": ["english_cleaners2"],
             "add_blank": True,
@@ -289,8 +294,17 @@ def load_checkpoint_if_available(
     return saved_params
 
 
-def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
+def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device, params):
     """Parse batch data"""
+    mel_mean = params.data_args.data_statistics.mel_mean
+    mel_std_inv = 1 / params.data_args.data_statistics.mel_std
+    for i in range(batch["features"].shape[0]):
+        n = batch["features_lens"][i]
+        batch["features"][i : i + 1, :n, :] = (
+            batch["features"][i : i + 1, :n, :] - mel_mean
+        ) * mel_std_inv
+        batch["features"][i : i + 1, n:, :] = 0
+
     audio = batch["audio"].to(device)
     features = batch["features"].to(device)
     audio_lens = batch["audio_lens"].to(device)
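The loop added above standardises only the valid frames of each utterance with the global CMVN statistics and zeroes the padded tail. A self-contained toy check of that behaviour (illustrative; the mean/std values are the ones hard-coded in get_data_statistics):

import torch

mel_mean = -1.168782114982605
mel_std_inv = 1 / 1.9283572435379028

features = torch.randn(2, 5, 3)        # (batch, frames, mel bins)
features_lens = torch.tensor([5, 3])   # second item has 2 padded frames

for i in range(features.shape[0]):
    n = features_lens[i]
    features[i : i + 1, :n, :] = (features[i : i + 1, :n, :] - mel_mean) * mel_std_inv
    features[i : i + 1, n:, :] = 0

assert torch.all(features[1, 3:] == 0)  # padding zeroed, valid frames standardised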
@@ -298,7 +312,7 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
     tokens = batch["tokens"]
 
     tokens = tokenizer.tokens_to_token_ids(
-        tokens, intersperse_blank=True, add_sos=True, add_eos=True
+        tokens, intersperse_blank=True, add_sos=False, add_eos=False
     )
     tokens = k2.RaggedTensor(tokens)
     row_splits = tokens.shape.row_splits(1)
@@ -315,7 +329,7 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
 
     # features_lens[features_lens.argmax()] += pad
 
-    return audio, audio_lens, features, features_lens, tokens, tokens_lens
+    return audio, audio_lens, features, features_lens.long(), tokens, tokens_lens.long()
 
 
 def compute_validation_loss(
@@ -336,28 +350,36 @@ def compute_validation_loss(
 
     with torch.no_grad():
         for batch_idx, batch in enumerate(valid_dl):
+            if "tokens" in batch:
                 (
                     audio,
                     audio_lens,
                     features,
                     features_lens,
                     tokens,
                     tokens_lens,
-                ) = prepare_input(batch, tokenizer, device)
+                ) = prepare_input(batch, tokenizer, device, params)
 
                 losses = get_losses(
                     {
                         "x": tokens,
                         "x_lengths": tokens_lens,
                         "y": features.permute(0, 2, 1),
                         "y_lengths": features_lens,
                         "spks": None,  # should change it for multi-speakers
                         "durations": None,
                     }
                 )
 
                 batch_size = len(batch["tokens"])
+            else:
+                batch_size = batch["x"].shape[0]
+                batch["x"] = batch["x"].to(device)
+                batch["x_lengths"] = batch["x_lengths"].to(device)
+                batch["y"] = batch["y"].to(device)
+                batch["y_lengths"] = batch["y_lengths"].to(device)
+                losses = get_losses(batch)
 
             loss_info = MetricsTracker()
             loss_info["samples"] = batch_size
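The branch added above presumably distinguishes the two datamodules this commit can train with: lhotse batches carry raw "tokens" (plus audio/features) and go through prepare_input(), while TextMelDataModule batches arrive already tokenised in the x/y layout that get_losses() expects. A trivial sketch of that dispatch:

def batch_kind(batch: dict) -> str:
    # lhotse-style batches are tokenised in prepare_input();
    # TextMelDataModule batches are already in x/x_lengths, y/y_lengths form.
    return "lhotse" if "tokens" in batch else "text_mel"

assert batch_kind({"tokens": [["h", "i"]]}) == "lhotse"
assert batch_kind({"x": None, "y": None}) == "text_mel"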
@@ -451,24 +473,38 @@ def train_one_epoch(
             # features_lens, (N,), int32
             # tokens: List[List[str]], len(tokens) == N
 
-            batch_size = len(batch["tokens"])
-            audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
-                batch, tokenizer, device
-            )
+            if "tokens" in batch:
+                batch_size = len(batch["tokens"])
+                (
+                    audio,
+                    audio_lens,
+                    features,
+                    features_lens,
+                    tokens,
+                    tokens_lens,
+                ) = prepare_input(batch, tokenizer, device, params)
+            else:
+                batch_size = batch["x"].shape[0]
 
             try:
                 with autocast(enabled=params.use_fp16):
-                    losses = get_losses(
-                        {
-                            "x": tokens,
-                            "x_lengths": tokens_lens,
-                            "y": features.permute(0, 2, 1),
-                            "y_lengths": features_lens,
-                            "spks": None,  # should change it for multi-speakers
-                            "durations": None,
-                        }
-                    )
+                    if "tokens" in batch:
+                        losses = get_losses(
+                            {
+                                "x": tokens,
+                                "x_lengths": tokens_lens,
+                                "y": features.permute(0, 2, 1),
+                                "y_lengths": features_lens,
+                                "spks": None,  # should change it for multi-speakers
+                                "durations": None,
+                            }
+                        )
+                    else:
+                        batch["x"] = batch["x"].to(device)
+                        batch["x_lengths"] = batch["x_lengths"].to(device)
+                        batch["y"] = batch["y"].to(device)
+                        batch["y_lengths"] = batch["y_lengths"].to(device)
+                        losses = get_losses(batch)
 
                     loss = sum(losses.values())
 
@@ -586,6 +622,7 @@ def run(rank, world_size, args):
     params.blank_id = tokenizer.pad_id
     params.vocab_size = tokenizer.vocab_size
     params.model_args.n_vocab = params.vocab_size
+    params.model_args.n_vocab = 178
 
     logging.info(params)
     print(params)
@@ -595,7 +632,6 @@ def run(rank, world_size, args):
 
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of parameters: {num_param}")
-    print(f"Number of parameters: {num_param}")
 
     assert params.start_epoch > 0, params.start_epoch
     checkpoints = load_checkpoint_if_available(params=params, model=model)
@@ -609,13 +645,21 @@ def run(rank, world_size, args):
     optimizer = torch.optim.Adam(model.parameters(), **params.model_args.optimizer)
 
     logging.info("About to create datamodule")
-    ljspeech = LJSpeechTtsDataModule(args)
 
-    train_cuts = ljspeech.train_cuts()
-    train_dl = ljspeech.train_dataloaders(train_cuts)
+    if False:
+        params.data_args.tokenizer = tokenizer
+        data_module = TextMelDataModule(hparams=params.data_args)
+        del params.data_args.tokenizer
+        train_dl = data_module.train_dataloader()
+        valid_dl = data_module.val_dataloader()
+    else:
+        ljspeech = LJSpeechTtsDataModule(args)
 
-    valid_cuts = ljspeech.valid_cuts()
-    valid_dl = ljspeech.valid_dataloaders(valid_cuts)
+        train_cuts = ljspeech.train_cuts()
+        train_dl = ljspeech.train_dataloaders(train_cuts)
+
+        valid_cuts = ljspeech.valid_cuts()
+        valid_dl = ljspeech.valid_dataloaders(valid_cuts)
 
     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
@@ -625,7 +669,8 @@ def run(rank, world_size, args):
     for epoch in range(params.start_epoch, params.num_epochs + 1):
         logging.info(f"Start epoch {epoch}")
         fix_random_seed(params.seed + epoch - 1)
-        train_dl.sampler.set_epoch(epoch - 1)
+        if "sampler" in train_dl:
+            train_dl.sampler.set_epoch(epoch - 1)
 
         params.cur_epoch = epoch
 
@@ -181,6 +181,8 @@ class LJSpeechTtsDataModule:
             frame_length=1024 / sampling_rate,  # (in second),
             frame_shift=256 / sampling_rate,  # (in second)
             use_fft_mag=True,
+            remove_dc_offset=False,
+            preemph_coeff=0,
             low_freq=0,
             high_freq=8000,
             # should be identical to n_feats in ./train.py
@@ -242,6 +244,8 @@ class LJSpeechTtsDataModule:
             frame_length=1024 / sampling_rate,  # (in second),
             frame_shift=256 / sampling_rate,  # (in second)
             use_fft_mag=True,
+            remove_dc_offset=False,
+            preemph_coeff=0,
             low_freq=0,
             high_freq=8000,
             # should be identical to n_feats in ./train.py
@@ -286,6 +290,8 @@ class LJSpeechTtsDataModule:
             frame_length=1024 / sampling_rate,  # (in second),
             frame_shift=256 / sampling_rate,  # (in second)
             use_fft_mag=True,
+            remove_dc_offset=False,
+            preemph_coeff=0,
             low_freq=0,
             high_freq=8000,
             # should be identical to n_feats in ./train.py
@@ -191,3 +191,10 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
     touch data/fbank/.ljspeech_split.done
   fi
 fi
+
+if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
+  log "Stage 9: Compute fbank mean and std (used by ./matcha)"
+  if [ ! -f ./data/fbank/cmvn.json ]; then
+    ./local/compute_fbank_statistics.py ./data/fbank/ljspeech_cuts_train.jsonl.gz ./data/fbank/cmvn.json
+  fi
+fi
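Once Stage 9 has run, the statistics live in data/fbank/cmvn.json and can be inspected directly. The values shown below are the "ours" numbers hard-coded in matcha/train.py above, on the assumption that this stage produced them:

import json

with open("data/fbank/cmvn.json") as f:
    print(json.load(f))
# e.g. {"fbank_mean": -1.168782114982605, "fbank_std": 1.9283572435379028}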