a little bit coarse commit

JinZr 2024-09-05 22:30:07 +08:00
parent dd82686a0f
commit 6e4a9ea85a
9 changed files with 273 additions and 130 deletions

View File

@@ -25,19 +25,16 @@ It looks for manifests in the directory data/manifests.
The generated spectrogram features are saved in data/spectrogram.
"""
import argparse
import logging
import os
from pathlib import Path
from typing import Optional
import torch
from lhotse import (
CutSet,
LilcomChunkyWriter,
Spectrogram,
SpectrogramConfig,
load_manifest,
)
from lhotse import CutSet, LilcomChunkyWriter, Spectrogram, SpectrogramConfig
from lhotse.audio import RecordingSet
from lhotse.recipes.utils import read_manifests_if_cached
from lhotse.supervision import SupervisionSet
from icefall.utils import get_executor
@@ -49,26 +46,62 @@ from icefall.utils import get_executor
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def get_args():
parser = argparse.ArgumentParser()
def compute_spectrogram_libritts():
parser.add_argument(
"--dataset",
type=str,
help="""Dataset parts to compute fbank. If None, we will use all""",
)
parser.add_argument(
"--sampling-rate",
type=int,
default=24000,
help="""Sampling rate of the audio for computing fbank, the default value for LibriTTS is 24000, audio files will be resampled if a different sample rate is provided""",
)
return parser.parse_args()
def compute_spectrogram_libritts(dataset: Optional[str] = None, sampling_rate: int = 24000,):
src_dir = Path("data/manifests")
output_dir = Path("data/spectrogram")
num_jobs = min(32, os.cpu_count())
sampling_rate = 24000
frame_length = 1024 / sampling_rate  # (in seconds)
frame_shift = 256 / sampling_rate  # (in seconds)
use_fft_mag = True
prefix = "libritts"
suffix = "jsonl.gz"
partition = "all"
if dataset is None:
dataset_parts = (
"dev-clean",
"dev-other",
"test-clean",
"test-other",
"train-clean-100",
"train-clean-360",
"train-other-500",
)
else:
dataset_parts = dataset.split(" ", -1)
recordings = load_manifest(
src_dir / f"{prefix}_recordings_{partition}.jsonl.gz", RecordingSet
).resample(sampling_rate=sampling_rate)
supervisions = load_manifest(
src_dir / f"{prefix}_supervisions_{partition}.jsonl.gz", SupervisionSet
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=src_dir,
prefix=prefix,
suffix=suffix,
)
assert manifests is not None
assert len(manifests) == len(dataset_parts), (
len(manifests),
len(dataset_parts),
list(manifests.keys()),
dataset_parts,
)
config = SpectrogramConfig(
@@ -80,24 +113,29 @@ def compute_spectrogram_libritts():
extractor = Spectrogram(config)
with get_executor() as ex: # Initialize the executor only once.
cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
if (output_dir / cuts_filename).is_file():
logging.info(f"{partition} already exists - skipping.")
return
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
recordings=recordings, supervisions=supervisions
)
for partition, m in manifests.items():
cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
if (output_dir / cuts_filename).is_file():
logging.info(f"{partition} already exists - skipping.")
return
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
)
if sampling_rate != 24000:
logging.info(f"Resampling audio to {sampling_rate}")
cut_set = cut_set.resample(sampling_rate)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
# when an executor is specified, make more partitions
num_jobs=num_jobs if ex is None else 80,
executor=ex,
storage_type=LilcomChunkyWriter,
)
cut_set.to_file(output_dir / cuts_filename)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
# when an executor is specified, make more partitions
num_jobs=num_jobs if ex is None else 80,
executor=ex,
storage_type=LilcomChunkyWriter,
)
cut_set.to_file(output_dir / cuts_filename)
if __name__ == "__main__":
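
A note on the extraction settings above: the window is 1024 samples and the hop is 256 samples, i.e. roughly 42.7 ms and 10.7 ms at 24 kHz. The SpectrogramConfig call itself is elided from this hunk, so the following is only a minimal sketch of how those values are typically passed to lhotse (field names assumed):

from lhotse import Spectrogram, SpectrogramConfig

sampling_rate = 24000
config = SpectrogramConfig(
    sampling_rate=sampling_rate,
    frame_length=1024 / sampling_rate,  # ~42.7 ms window
    frame_shift=256 / sampling_rate,    # ~10.7 ms hop
    use_fft_mag=True,                   # keep FFT magnitudes rather than power
)
extractor = Spectrogram(config)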

View File

@@ -8,6 +8,7 @@ set -eou pipefail
stage=0
stop_stage=100
sampling_rate=24000
nj=32
perturb_speed=true
dl_dir=$PWD/download
@@ -54,7 +55,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
# to $dl_dir/LibriTTS
mkdir -p data/manifests
if [ ! -e data/manifests/.libritts.done ]; then
lhotse prepare libritts $dl_dir/LibriTTS data/manifests
lhotse prepare libritts --num-jobs 32 $dl_dir/LibriTTS data/manifests
touch data/manifests/.libritts.done
fi
fi
@@ -84,10 +85,10 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
# Here we shuffle and combine the train-clean-100, train-clean-360 and
# train-other-500 together to form the training set.
if [ ! -f data/fbank/libritts_cuts_train-all-shuf.jsonl.gz ]; then
cat <(gunzip -c data/fbank/libritts_cuts_train-clean-100.jsonl.gz) \
<(gunzip -c data/fbank/libritts_cuts_train-clean-360.jsonl.gz) \
<(gunzip -c data/fbank/libritts_cuts_train-other-500.jsonl.gz) | \
shuf | gzip -c > data/fbank/libritts_cuts_train-all-shuf.jsonl.gz
cat <(gunzip -c ./libritts_cuts_train-clean-100.jsonl.gz) \
<(gunzip -c ./libritts_cuts_train-clean-360.jsonl.gz) \
<(gunzip -c ./libritts_cuts_train-other-500.jsonl.gz) | \
shuf | gzip -c > ./libritts_cuts_train-all-shuf.jsonl.gz
fi
if [ ! -e data/fbank/.libritts-validated.done ]; then
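
The combine-and-shuffle step in stage 3 is plain shell (cat/gunzip/shuf/gzip). A rough lhotse equivalent, shown only for reference and assuming the cut manifests sit under data/fbank/:

from lhotse import CutSet, load_manifest_lazy

parts = ("train-clean-100", "train-clean-360", "train-other-500")
cuts = CutSet.from_cuts(
    cut
    for part in parts
    for cut in load_manifest_lazy(f"data/fbank/libritts_cuts_{part}.jsonl.gz")
)
cuts.shuffle().to_file("data/fbank/libritts_cuts_train-all-shuf.jsonl.gz")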

View File

@@ -5,9 +5,18 @@ import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from einops import rearrange
from utils import get_2d_padding, get_padding
from modules.conv import NormConv1d, NormConv2d
from ..modules import NormConv1d, NormConv2d
def get_padding(kernel_size, dilation=1) -> int:
return int((kernel_size * dilation - dilation) / 2)
def get_2d_padding(kernel_size: Tuple[int, int], dilation: Tuple[int, int] = (1, 1)):
return (
((kernel_size[0] - 1) * dilation[0]) // 2,
((kernel_size[1] - 1) * dilation[1]) // 2,
)
class DiscriminatorP(nn.Module):
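
The two helpers above compute "same"-style padding for dilated convolutions. A quick worked example of what they return (values follow directly from the formulas):

# Using get_padding / get_2d_padding as defined above.
# 1-D: a kernel of 7 with dilation 1 needs 3 samples of padding per side.
assert get_padding(7) == 3
assert get_padding(5, dilation=2) == 4  # effective span (5 - 1) * 2 + 1 = 9

# 2-D: per-dimension padding for a (3, 9) kernel.
assert get_2d_padding((3, 9)) == (1, 4)
assert get_2d_padding((3, 9), dilation=(1, 2)) == (1, 8)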

View File

@@ -80,6 +80,13 @@ class LibriTTSCodecDataModule:
"augmentations, etc.",
)
group.add_argument(
"--full-libri",
type=str2bool,
default=True,
help="""When enabled, use the entire LibriTTS training set.
Otherwise, use the clean-100 subset.""",
)
group.add_argument(
"--manifest-dir",
type=Path,
@@ -210,8 +217,8 @@ class LibriTTSCodecDataModule:
validate = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
return_spk_ids=True,
return_tokens=False,
return_spk_ids=False,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
@@ -236,8 +243,8 @@ class LibriTTSCodecDataModule:
test = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
return_spk_ids=True,
return_tokens=False,
return_spk_ids=False,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
@@ -256,16 +263,60 @@
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_train.jsonl.gz")
def train_clean_100_cuts(self) -> CutSet:
logging.info("About to get train-clean-100 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_train-clean-100.jsonl.gz"
)
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get validation cuts")
return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_valid.jsonl.gz")
def train_clean_360_cuts(self) -> CutSet:
logging.info("About to get train-clean-360 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_train-clean-360.jsonl.gz"
)
@lru_cache()
def test_cuts(self) -> CutSet:
logging.info("About to get test cuts")
return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_test.jsonl.gz")
def train_other_500_cuts(self) -> CutSet:
logging.info("About to get train-other-500 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_train-other-500.jsonl.gz"
)
@lru_cache()
def train_all_shuf_cuts(self) -> CutSet:
logging.info(
"About to get the shuffled train-clean-100, \
train-clean-360 and train-other-500 cuts"
)
return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_train-all-shuf.jsonl.gz"
)
@lru_cache()
def dev_clean_cuts(self) -> CutSet:
logging.info("About to get dev-clean cuts")
return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_dev-clean.jsonl.gz"
)
@lru_cache()
def dev_other_cuts(self) -> CutSet:
logging.info("About to get dev-other cuts")
return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_dev-other.jsonl.gz"
)
@lru_cache()
def test_clean_cuts(self) -> CutSet:
logging.info("About to get test-clean cuts")
return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_test-clean.jsonl.gz"
)
@lru_cache()
def test_other_cuts(self) -> CutSet:
logging.info("About to get test-other cuts")
return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_test-other.jsonl.gz"
)
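
These per-subset accessors are consumed by encodec/train.py later in this commit; a condensed usage sketch (the argparse wiring and the args object are assumed):

libritts = LibriTTSCodecDataModule(args)
train_cuts = (
    libritts.train_all_shuf_cuts()
    if args.full_libri
    else libritts.train_clean_100_cuts()
)
train_dl = libritts.train_dataloaders(train_cuts)
valid_dl = libritts.valid_dataloaders(libritts.dev_clean_cuts())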

View File

@@ -1,8 +1,8 @@
from typing import List, Tuple
from typing import List
import torch
import torch.nn as nn
from models.discriminators import DiscriminatorP, DiscriminatorS, DiscriminatorSTFT
from base_discriminators import DiscriminatorP, DiscriminatorS, DiscriminatorSTFT
from torch.nn import AvgPool1d
@@ -81,7 +81,7 @@ class MultiScaleSTFTDiscriminator(nn.Module):
def __init__(
self,
filters: int,
n_filters: int,
in_channels: int = 1,
out_channels: int = 1,
n_ffts: List[int] = [1024, 2048, 512, 256, 128],
@@ -94,7 +94,7 @@ class MultiScaleSTFTDiscriminator(nn.Module):
self.discriminators = nn.ModuleList(
[
DiscriminatorSTFT(
filters,
n_filters,
in_channels=in_channels,
out_channels=out_channels,
n_fft=n_ffts[i],

View File

@@ -12,7 +12,7 @@ from torch.cuda.amp import autocast
class Encodec(nn.Module):
def __init__(
self,
sample_rate: int,
sampling_rate: int,
target_bandwidths: List[float],
params: dict,
encoder: nn.Module,
@@ -21,21 +21,21 @@ class Encodec(nn.Module):
multi_scale_discriminator: nn.Module,
multi_period_discriminator: nn.Module,
multi_scale_stft_discriminator: nn.Module,
cache_generator_outputs: bool = True,
cache_generator_outputs: bool = False,
):
super(Encodec, self).__init__()
self.params = params
# setup the generator
self.sample_rate = sample_rate
self.sampling_rate = sampling_rate
self.encoder = encoder
self.quantizer = quantizer
self.decoder = decoder
self.ratios = encoder.ratios
self.hop_length = np.prod(self.ratios)
self.frame_rate = math.ceil(self.sample_rate / np.prod(self.ratios))
self.frame_rate = math.ceil(self.sampling_rate / np.prod(self.ratios))
self.target_bandwidths = target_bandwidths
# discriminators
@@ -133,10 +133,10 @@ class Encodec(nn.Module):
if return_sample:
stats["returned_sample"] = (
speech_hat[0].data.cpu().numpy(),
speech[0].data.cpu().numpy(),
fmap_hat[0][0].data.cpu().numpy(),
fmap[0][0].data.cpu().numpy(),
speech_hat.cpu(),
speech.cpu(),
fmap_hat[0][0].data.cpu(),
fmap[0][0].data.cpu(),
)
# reset cache
@@ -259,3 +259,11 @@ class Encodec(nn.Module):
quantized = self.quantizer.decode(codes)
o = self.decoder(quantized)
return o
def inference(self, x, target_bw=None, st=None):
# setup
x = x.unsqueeze(1)
codes = self.encode(x, target_bw, st)
o = self.decode(codes)
return o
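
The new inference helper simply inserts a channel dimension and chains encode and decode at a chosen bandwidth. A minimal usage sketch (an already-constructed Encodec instance is assumed; the 7.5 kbps target matches the inference_params added in train.py below):

import torch

# `model` is an already-constructed Encodec instance (see get_model in train.py).
wav = torch.randn(1, 24000)  # (batch, samples): one second of audio at 24 kHz
with torch.no_grad():
    recon = model.inference(x=wav, target_bw=7.5)  # reconstructed waveform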

View File

@@ -59,9 +59,9 @@ def sim_loss(y_disc_r, y_disc_gen):
# return torch.sum(loss) / x.shape[0]
def reconstruction_loss(x, G_x, args, eps=1e-7):
def reconstruction_loss(x, x_hat, args, eps=1e-7):
# NOTE (lsx): hard-coded now
L = args.lambda_wav * F.mse_loss(x, G_x) # wav L1 loss
L = args.lambda_wav * F.mse_loss(x, x_hat)  # wav MSE loss
# loss_sisnr = sisnr_loss(G_x, x) #
# L += 0.01*loss_sisnr
# 2^6=64 -> 2^10=1024
@@ -70,15 +70,15 @@ def reconstruction_loss(x, G_x, args, eps=1e-7):
# for i in range(5, 12): # Encodec setting
s = 2**i
melspec = MelSpectrogram(
sample_rate=args.sr,
sample_rate=args.sampling_rate,
n_fft=max(s, 512),
win_length=s,
hop_length=s // 4,
n_mels=64,
wkwargs={"device": args.device},
).to(args.device)
wkwargs={"device": x_hat.device},
).to(x_hat.device)
S_x = melspec(x)
S_G_x = melspec(G_x)
S_G_x = melspec(x_hat)
l1_loss = (S_x - S_G_x).abs().mean()
l2_loss = (
((torch.log(S_x.abs() + eps) - torch.log(S_G_x.abs() + eps)) ** 2).mean(
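
For context, reconstruction_loss combines the waveform MSE term above with a multi-scale mel term: mel spectrograms at window sizes 2^i (the comments suggest 64 through 1024 samples) are compared with an L1 plus a log-domain L2 distance. A self-contained sketch of that multi-scale term, with loop bounds and weighting treated as assumptions since they are not fully shown in this hunk:

import torch
from torchaudio.transforms import MelSpectrogram

def multi_scale_mel_loss(x, x_hat, sampling_rate=24000, eps=1e-7):
    """L1 + log-L2 distance between mel spectrograms at several window sizes."""
    loss = x.new_zeros(())
    for i in range(6, 11):  # window sizes 64, 128, 256, 512, 1024
        s = 2**i
        mel = MelSpectrogram(
            sample_rate=sampling_rate,
            n_fft=max(s, 512),
            win_length=s,
            hop_length=s // 4,
            n_mels=64,
        ).to(x_hat.device)
        S_x, S_hat = mel(x), mel(x_hat)
        loss = loss + (S_x - S_hat).abs().mean()
        loss = loss + ((torch.log(S_x + eps) - torch.log(S_hat + eps)) ** 2).mean()
    return loss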

View File

@@ -1,12 +0,0 @@
from typing import Tuple
def get_padding(kernel_size, dilation=1) -> int:
return int((kernel_size * dilation - dilation) / 2)
def get_2d_padding(kernel_size: Tuple[int, int], dilation: Tuple[int, int] = (1, 1)):
return (
((kernel_size[0] - 1) * dilation[0]) // 2,
((kernel_size[1] - 1) * dilation[1]) // 2,
)

View File

@@ -2,6 +2,7 @@ import argparse
import itertools
import logging
import math
import random
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union
@@ -10,6 +11,7 @@ import numpy as np
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from codec_datamodule import LibriTTSCodecDataModule
from encodec import Encodec
from lhotse.cut import Cut
from lhotse.utils import fix_random_seed
@@ -76,7 +78,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="vits/exp",
default="encodec/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, logs, etc., are saved
@@ -127,6 +129,12 @@ def get_parser():
default=False,
help="Whether to use half precision training.",
)
parser.add_argument(
"--chunk-size",
type=int,
default=1,
help="The chunk size for the dataset (in second).",
)
return parser
@@ -249,23 +257,32 @@ def get_model(params: AttributeDict) -> nn.Module:
}
discriminator_params = {
"stft_discriminator_n_filters": 32,
"discriminator_iter_start": 500,
}
inference_params = {
"target_bw": 7.5,
}
params.update(generator_params)
params.update(discriminator_params)
params.update(inference_params)
hop_length = np.prod(params.ratios)
n_q = int(
1000
* params.target_bandwidths[-1]
// (math.ceil(params.sample_rate / hop_length) * 10)
// (math.ceil(params.sampling_rate / hop_length) * 10)
)
encoder = SEANetEncoder(
n_filters=params.n_filters, dimension=params.dimension, ratios=params.ratios
n_filters=params.generator_n_filters,
dimension=params.dimension,
ratios=params.ratios,
)
decoder = SEANetDecoder(
n_filters=params.n_filters, dimension=params.dimension, ratios=params.ratios
n_filters=params.generator_n_filters,
dimension=params.dimension,
ratios=params.ratios,
)
quantizer = ResidualVectorQuantizer(
dimension=params.dimension, n_q=n_q, bins=params.bins
@@ -273,21 +290,25 @@ def get_model(params: AttributeDict) -> nn.Module:
model = Encodec(
params=params,
sample_rate=params.sampling_rate,
sampling_rate=params.sampling_rate,
target_bandwidths=params.target_bandwidths,
encoder=encoder,
quantizer=quantizer,
decoder=decoder,
multi_scale_discriminator=MultiScaleDiscriminator(),
multi_period_discriminator=MultiPeriodDiscriminator(),
multi_scale_stft_discriminator=MultiScaleSTFTDiscriminator(),
multi_scale_stft_discriminator=MultiScaleSTFTDiscriminator(
n_filters=params.stft_discriminator_n_filters
),
)
return model
def prepare_input(
params: AttributeDict,
batch: dict,
device: torch.device,
is_training: bool = True,
):
"""Parse batch data"""
audio = batch["audio"].to(device, memory_format=torch.contiguous_format)
@@ -295,6 +316,18 @@ def prepare_input(
audio_lens = batch["audio_lens"].to(device)
features_lens = batch["features_lens"].to(device)
if is_training:
audio_dims = audio.size(-1)
start_idx = random.randint(
0, max(0, audio_dims - params.chunk_size * params.sampling_rate)
)
audio = audio[:, start_idx : params.sampling_rate + start_idx]
else:
# NOTE: a very coarse setup
audio = audio[
:, params.sampling_rate : params.sampling_rate + params.sampling_rate
]
return audio, audio_lens, features, features_lens
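
The training branch of prepare_input now crops a random fixed-length window from each utterance so that the codec always trains on short chunks. A standalone sketch of that cropping logic (hypothetical helper; the chunk length is taken as an explicit argument):

import random

import torch

def random_chunk(audio: torch.Tensor, sampling_rate: int, chunk_seconds: int = 1) -> torch.Tensor:
    """Crop a (batch, samples) tensor to a random window of chunk_seconds seconds."""
    chunk = chunk_seconds * sampling_rate
    start = random.randint(0, max(0, audio.size(-1) - chunk))
    return audio[:, start : start + chunk]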
@@ -371,13 +404,13 @@ def train_one_epoch(
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["tokens"])
batch_size = len(batch["audio"])
(
audio,
audio_lens,
_,
_,
) = prepare_input(batch, device)
) = prepare_input(params, batch, device)
loss_info = MetricsTracker()
loss_info["samples"] = batch_size
@@ -476,31 +509,38 @@ def train_one_epoch(
"train/grad_scale", cur_grad_scale, params.batch_idx_train
)
if "returned_sample" in stats_g:
speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"]
# speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"]
speech_hat_, speech_, _, _ = stats_g["returned_sample"]
speech_hat_i = speech_hat_[0]
speech_i = speech_[0]
if speech_hat_i.dim() > 1:
speech_hat_i = speech_hat_i.squeeze(0)
speech_i = speech_i.squeeze(0)
tb_writer.add_audio(
"train/speech_hat_",
speech_hat_,
f"train/speech_hat_",
speech_hat_i,
params.batch_idx_train,
params.sampling_rate,
)
tb_writer.add_audio(
"train/speech_",
speech_,
f"train/speech_",
speech_i,
params.batch_idx_train,
params.sampling_rate,
)
tb_writer.add_image(
"train/mel_hat_",
plot_feature(mel_hat_),
params.batch_idx_train,
dataformats="HWC",
)
tb_writer.add_image(
"train/mel_",
plot_feature(mel_),
params.batch_idx_train,
dataformats="HWC",
)
# tb_writer.add_image(
# "train/mel_hat_",
# plot_feature(mel_hat_),
# params.batch_idx_train,
# dataformats="HWC",
# )
# tb_writer.add_image(
# "train/mel_",
# plot_feature(mel_),
# params.batch_idx_train,
# dataformats="HWC",
# )
if (
params.batch_idx_train % params.valid_interval == 0
@@ -522,15 +562,20 @@ def train_one_epoch(
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
speech_hat_i = speech_hat[0]
speech_i = speech[0]
if speech_hat_i.dim() > 1:
speech_hat_i = speech_hat_i.squeeze(0)
speech_i = speech_i.squeeze(0)
tb_writer.add_audio(
"train/valdi_speech_hat",
speech_hat,
speech_hat_i,
params.batch_idx_train,
params.sampling_rate,
)
tb_writer.add_audio(
"train/valdi_speech",
speech,
speech_i,
params.batch_idx_train,
params.sampling_rate,
)
@@ -559,13 +604,13 @@ def compute_validation_loss(
with torch.no_grad():
for batch_idx, batch in enumerate(valid_dl):
batch_size = len(batch["tokens"])
batch_size = len(batch["audio"])
(
audio,
audio_lens,
_,
_,
) = prepare_input(batch, device)
) = prepare_input(params, batch, device, is_training=False)
loss_info = MetricsTracker()
loss_info["samples"] = batch_size
@@ -588,7 +633,7 @@ def compute_validation_loss(
speech_lengths=audio_lens,
global_step=params.batch_idx_train,
forward_generator=True,
return_sample=batch_idx == 0,
return_sample=False,
)
assert loss_g.requires_grad is False
for k, v in stats_g.items():
@@ -599,9 +644,9 @@
# infer for first batch:
if batch_idx == 0 and rank == 0:
speech_hat_, speech_, _, _ = stats_g["returned_sample"]
returned_sample = (speech_hat_, speech_)
inner_model = model.module if isinstance(model, DDP) else model
audio_pred = inner_model.inference(x=audio, target_bw=params.target_bw)
returned_sample = (audio_pred, audio)
if world_size > 1:
tot_loss.reduce(device)
@@ -635,7 +680,7 @@ def scan_pessimistic_batches_for_oom(
audio_lens,
_,
_,
) = prepare_input(batch, device)
) = prepare_input(params, batch, device)
try:
# for discriminator
with autocast(enabled=params.use_fp16):
@@ -706,9 +751,12 @@ def run(rank, world_size, args):
device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
vctk = VctkTtsDataModule(args)
libritts = LibriTTSCodecDataModule(args)
train_cuts = vctk.train_cuts()
if params.full_libri:
train_cuts = libritts.train_all_shuf_cuts()
else:
train_cuts = libritts.train_clean_100_cuts()
logging.info(params)
@@ -798,19 +846,19 @@ def run(rank, world_size, args):
if params.inf_check:
register_inf_check_hooks(model)
train_dl = vctk.train_dataloaders(train_cuts)
train_dl = libritts.train_dataloaders(train_cuts)
valid_cuts = vctk.valid_cuts()
valid_dl = vctk.valid_dataloaders(valid_cuts)
valid_cuts = libritts.dev_clean_cuts()
valid_dl = libritts.valid_dataloaders(valid_cuts)
if not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
optimizer_g=optimizer_g,
optimizer_d=optimizer_d,
params=params,
)
# if not params.print_diagnostics:
# scan_pessimistic_batches_for_oom(
# model=model,
# train_dl=train_dl,
# optimizer_g=optimizer_g,
# optimizer_d=optimizer_d,
# params=params,
# )
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
@@ -883,7 +931,7 @@ def main():
def main():
parser = get_parser()
VctkTtsDataModule.add_arguments(parser)
LibriTTSCodecDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)