mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-26 18:24:18 +00:00

commit 6e4a9ea85a (parent dd82686a0f): "a little bit coarse commit"
@@ -25,19 +25,16 @@ It looks for manifests in the directory data/manifests.
 The generated fbank features are saved in data/spectrogram.
 """
 
+import argparse
 import logging
 import os
 from pathlib import Path
+from typing import Optional
 
 import torch
-from lhotse import (
-    CutSet,
-    LilcomChunkyWriter,
-    Spectrogram,
-    SpectrogramConfig,
-    load_manifest,
-)
+from lhotse import CutSet, LilcomChunkyWriter, Spectrogram, SpectrogramConfig
 from lhotse.audio import RecordingSet
+from lhotse.recipes.utils import read_manifests_if_cached
 from lhotse.supervision import SupervisionSet
 
 from icefall.utils import get_executor
@@ -49,26 +46,62 @@ from icefall.utils import get_executor
 torch.set_num_threads(1)
 torch.set_num_interop_threads(1)
 
 
-def compute_spectrogram_libritts():
+def get_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        help="""Dataset parts to compute fbank. If None, we will use all""",
+    )
+    parser.add_argument(
+        "--sampling-rate",
+        type=int,
+        default=24000,
+        help="""Sampling rate of the audio for computing fbank, the default value for LibriTTS is 24000, audio files will be resampled if a different sample rate is provided""",
+    )
+
+    return parser.parse_args()
+
+
+def compute_spectrogram_libritts(dataset: Optional[str] = None, sampling_rate: int = 24000,):
     src_dir = Path("data/manifests")
     output_dir = Path("data/spectrogram")
     num_jobs = min(32, os.cpu_count())
 
-    sampling_rate = 24000
     frame_length = 1024 / sampling_rate  # (in second)
     frame_shift = 256 / sampling_rate  # (in second)
     use_fft_mag = True
 
     prefix = "libritts"
     suffix = "jsonl.gz"
-    partition = "all"
+    if dataset is None:
+        dataset_parts = (
+            "dev-clean",
+            "dev-other",
+            "test-clean",
+            "test-other",
+            "train-clean-100",
+            "train-clean-360",
+            "train-other-500",
+        )
+    else:
+        dataset_parts = dataset.split(" ", -1)
 
-    recordings = load_manifest(
-        src_dir / f"{prefix}_recordings_{partition}.jsonl.gz", RecordingSet
-    ).resample(sampling_rate=sampling_rate)
-    supervisions = load_manifest(
-        src_dir / f"{prefix}_supervisions_{partition}.jsonl.gz", SupervisionSet
+    manifests = read_manifests_if_cached(
+        dataset_parts=dataset_parts,
+        output_dir=src_dir,
+        prefix=prefix,
+        suffix=suffix,
     )
+    assert manifests is not None
+
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
 
     config = SpectrogramConfig(
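Note: `--dataset` takes a space-separated list of partitions, split via `dataset.split(" ", -1)`. With `get_args()` in place, the script presumably wires the CLI into `compute_spectrogram_libritts` through a standard entry point; a minimal sketch of that wiring (hypothetical, not shown in this hunk):

if __name__ == "__main__":
    args = get_args()
    compute_spectrogram_libritts(
        dataset=args.dataset,
        sampling_rate=args.sampling_rate,
    )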
@@ -80,14 +113,19 @@ def compute_spectrogram_libritts():
     extractor = Spectrogram(config)
 
     with get_executor() as ex:  # Initialize the executor only once.
+        for partition, m in manifests.items():
             cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
             if (output_dir / cuts_filename).is_file():
                 logging.info(f"{partition} already exists - skipping.")
                 return
             logging.info(f"Processing {partition}")
             cut_set = CutSet.from_manifests(
-                recordings=recordings, supervisions=supervisions
+                recordings=m["recordings"],
+                supervisions=m["supervisions"],
             )
+            if sampling_rate != 24000:
+                logging.info(f"Resampling audio to {sampling_rate}")
+                cut_set = cut_set.resample(sampling_rate)
 
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
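Note: `frame_length` and `frame_shift` are given in seconds, so at the 24 kHz default they correspond to a 1024-sample window and a 256-sample hop; a quick sanity check using the values from this hunk:

sampling_rate = 24000
frame_length = 1024 / sampling_rate  # ~0.0427 s
frame_shift = 256 / sampling_rate  # ~0.0107 s
assert round(frame_length * sampling_rate) == 1024
assert round(frame_shift * sampling_rate) == 256

Also observe that the early `return` now sits inside the per-partition loop, so the first partition whose cuts file already exists ends the whole run rather than skipping just that partition.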
@@ -8,6 +8,7 @@ set -eou pipefail
 stage=0
 stop_stage=100
 sampling_rate=24000
+nj=32
 perturb_speed=true
 
 dl_dir=$PWD/download
@@ -54,7 +55,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
   # to $dl_dir/LibriTTS
   mkdir -p data/manifests
   if [ ! -e data/manifests/.libritts.done ]; then
-    lhotse prepare libritts $dl_dir/LibriTTS data/manifests
+    lhotse prepare libritts --num-jobs 32 $dl_dir/LibriTTS data/manifests
     touch data/manifests/.libritts.done
   fi
 fi
@@ -84,10 +85,10 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
   # Here we shuffle and combine the train-clean-100, train-clean-360 and
   # train-other-500 together to form the training set.
   if [ ! -f data/fbank/libritts_cuts_train-all-shuf.jsonl.gz ]; then
-    cat <(gunzip -c data/fbank/libritts_cuts_train-clean-100.jsonl.gz) \
-      <(gunzip -c data/fbank/libritts_cuts_train-clean-360.jsonl.gz) \
-      <(gunzip -c data/fbank/libritts_cuts_train-other-500.jsonl.gz) | \
-      shuf | gzip -c > data/fbank/libritts_cuts_train-all-shuf.jsonl.gz
+    cat <(gunzip -c ./libritts_cuts_train-clean-100.jsonl.gz) \
+      <(gunzip -c ./libritts_cuts_train-clean-360.jsonl.gz) \
+      <(gunzip -c ./libritts_cuts_train-other-500.jsonl.gz) | \
+      shuf | gzip -c > ./libritts_cuts_train-all-shuf.jsonl.gz
   fi
 
   if [ ! -e data/fbank/.libritts-validated.done ]; then
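Note: the rewritten pipeline now reads and writes the manifests relative to the current directory while the `if [ ! -f data/fbank/... ]` guard still points at data/fbank, so this step presumably expects to run from inside data/fbank. A rough Python equivalent of the concatenate-and-shuffle step using lhotse (eager and in-memory; a sketch, not part of the recipe):

from lhotse import CutSet, combine

cuts = combine(
    CutSet.from_file("libritts_cuts_train-clean-100.jsonl.gz"),
    CutSet.from_file("libritts_cuts_train-clean-360.jsonl.gz"),
    CutSet.from_file("libritts_cuts_train-other-500.jsonl.gz"),
)
cuts.shuffle().to_file("libritts_cuts_train-all-shuf.jsonl.gz")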
@@ -5,9 +5,18 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torchaudio
 from einops import rearrange
-from utils import get_2d_padding, get_padding
+from modules.conv import NormConv1d, NormConv2d
 
-from ..modules import NormConv1d, NormConv2d
+
+def get_padding(kernel_size, dilation=1) -> int:
+    return int((kernel_size * dilation - dilation) / 2)
+
+
+def get_2d_padding(kernel_size: Tuple[int, int], dilation: Tuple[int, int] = (1, 1)):
+    return (
+        ((kernel_size[0] - 1) * dilation[0]) // 2,
+        ((kernel_size[1] - 1) * dilation[1]) // 2,
+    )
 
 
 class DiscriminatorP(nn.Module):
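Note: both helpers compute "same" padding for a possibly dilated convolution, ((kernel_size - 1) * dilation) // 2 per axis; they replace the standalone utils module that is deleted later in this commit. A couple of worked values:

assert get_padding(5) == 2  # kernel 5, dilation 1
assert get_padding(5, dilation=2) == 4  # effective kernel size 9
assert get_2d_padding((3, 9), dilation=(1, 2)) == (1, 8)  # per-axis for 2-D convs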
@@ -80,6 +80,13 @@ class LibriTTSCodecDataModule:
             "augmentations, etc.",
         )
 
+        group.add_argument(
+            "--full-libri",
+            type=str2bool,
+            default=True,
+            help="""When enabled, use the entire LibriTTS training set.
+            Otherwise, use the clean-100 subset.""",
+        )
         group.add_argument(
             "--manifest-dir",
             type=Path,
@@ -210,8 +217,8 @@ class LibriTTSCodecDataModule:
 
         validate = SpeechSynthesisDataset(
             return_text=False,
-            return_tokens=True,
-            return_spk_ids=True,
+            return_tokens=False,
+            return_spk_ids=False,
             feature_input_strategy=eval(self.args.input_strategy)(),
             return_cuts=self.args.return_cuts,
         )
@@ -236,8 +243,8 @@ class LibriTTSCodecDataModule:
 
         test = SpeechSynthesisDataset(
             return_text=False,
-            return_tokens=True,
-            return_spk_ids=True,
+            return_tokens=False,
+            return_spk_ids=False,
             feature_input_strategy=eval(self.args.input_strategy)(),
             return_cuts=self.args.return_cuts,
         )
@@ -256,16 +263,60 @@ class LibriTTSCodecDataModule:
         return test_dl
 
     @lru_cache()
-    def train_cuts(self) -> CutSet:
-        logging.info("About to get train cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_train.jsonl.gz")
+    def train_clean_100_cuts(self) -> CutSet:
+        logging.info("About to get train-clean-100 cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "libritts_cuts_train-clean-100.jsonl.gz"
+        )
 
     @lru_cache()
-    def valid_cuts(self) -> CutSet:
-        logging.info("About to get validation cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_valid.jsonl.gz")
+    def train_clean_360_cuts(self) -> CutSet:
+        logging.info("About to get train-clean-360 cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "libritts_cuts_train-clean-360.jsonl.gz"
+        )
 
     @lru_cache()
-    def test_cuts(self) -> CutSet:
-        logging.info("About to get test cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_test.jsonl.gz")
+    def train_other_500_cuts(self) -> CutSet:
+        logging.info("About to get train-other-500 cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "libritts_cuts_train-other-500.jsonl.gz"
+        )
+
+    @lru_cache()
+    def train_all_shuf_cuts(self) -> CutSet:
+        logging.info(
+            "About to get the shuffled train-clean-100, \
+            train-clean-360 and train-other-500 cuts"
+        )
+        return load_manifest_lazy(
+            self.args.manifest_dir / "libritts_cuts_train-all-shuf.jsonl.gz"
+        )
+
+    @lru_cache()
+    def dev_clean_cuts(self) -> CutSet:
+        logging.info("About to get dev-clean cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "libritts_cuts_dev-clean.jsonl.gz"
+        )
+
+    @lru_cache()
+    def dev_other_cuts(self) -> CutSet:
+        logging.info("About to get dev-other cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "libritts_cuts_dev-other.jsonl.gz"
+        )
+
+    @lru_cache()
+    def test_clean_cuts(self) -> CutSet:
+        logging.info("About to get test-clean cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "libritts_cuts_test-clean.jsonl.gz"
+        )
+
+    @lru_cache()
+    def test_other_cuts(self) -> CutSet:
+        logging.info("About to get test-other cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "libritts_cuts_test-other.jsonl.gz"
+        )
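Note: a minimal sketch of how these cut getters are presumably consumed, mirroring the train.py changes further down (hypothetical driver code):

import argparse

parser = argparse.ArgumentParser()
LibriTTSCodecDataModule.add_arguments(parser)
args = parser.parse_args()

libritts = LibriTTSCodecDataModule(args)
train_cuts = (
    libritts.train_all_shuf_cuts()
    if args.full_libri
    else libritts.train_clean_100_cuts()
)
train_dl = libritts.train_dataloaders(train_cuts)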
@@ -1,8 +1,8 @@
-from typing import List, Tuple
+from typing import List
 
 import torch
 import torch.nn as nn
-from models.discriminators import DiscriminatorP, DiscriminatorS, DiscriminatorSTFT
+from base_discriminators import DiscriminatorP, DiscriminatorS, DiscriminatorSTFT
 from torch.nn import AvgPool1d
 
 
@@ -81,7 +81,7 @@ class MultiScaleSTFTDiscriminator(nn.Module):
 
     def __init__(
         self,
-        filters: int,
+        n_filters: int,
         in_channels: int = 1,
         out_channels: int = 1,
         n_ffts: List[int] = [1024, 2048, 512, 256, 128],
@@ -94,7 +94,7 @@ class MultiScaleSTFTDiscriminator(nn.Module):
         self.discriminators = nn.ModuleList(
             [
                 DiscriminatorSTFT(
-                    filters,
+                    n_filters,
                     in_channels=in_channels,
                     out_channels=out_channels,
                     n_fft=n_ffts[i],
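Note: after the rename, the base channel count is passed as `n_filters`, matching the call added in get_model below; construction then looks like:

# 32 matches "stft_discriminator_n_filters" in get_model below
msstftd = MultiScaleSTFTDiscriminator(n_filters=32)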
@@ -12,7 +12,7 @@ from torch.cuda.amp import autocast
 class Encodec(nn.Module):
     def __init__(
         self,
-        sample_rate: int,
+        sampling_rate: int,
         target_bandwidths: List[float],
         params: dict,
         encoder: nn.Module,
@@ -21,21 +21,21 @@ class Encodec(nn.Module):
         multi_scale_discriminator: nn.Module,
         multi_period_discriminator: nn.Module,
         multi_scale_stft_discriminator: nn.Module,
-        cache_generator_outputs: bool = True,
+        cache_generator_outputs: bool = False,
     ):
         super(Encodec, self).__init__()
 
         self.params = params
 
         # setup the generator
-        self.sample_rate = sample_rate
+        self.sampling_rate = sampling_rate
         self.encoder = encoder
         self.quantizer = quantizer
         self.decoder = decoder
 
         self.ratios = encoder.ratios
         self.hop_length = np.prod(self.ratios)
-        self.frame_rate = math.ceil(self.sample_rate / np.prod(self.ratios))
+        self.frame_rate = math.ceil(self.sampling_rate / np.prod(self.ratios))
         self.target_bandwidths = target_bandwidths
 
         # discriminators
@@ -133,10 +133,10 @@ class Encodec(nn.Module):
 
         if return_sample:
             stats["returned_sample"] = (
-                speech_hat[0].data.cpu().numpy(),
-                speech[0].data.cpu().numpy(),
-                fmap_hat[0][0].data.cpu().numpy(),
-                fmap[0][0].data.cpu().numpy(),
+                speech_hat.cpu(),
+                speech.cpu(),
+                fmap_hat[0][0].data.cpu(),
+                fmap[0][0].data.cpu(),
             )
 
         # reset cache
@@ -259,3 +259,11 @@ class Encodec(nn.Module):
         quantized = self.quantizer.decode(codes)
         o = self.decoder(quantized)
         return o
+
+    def inference(self, x, target_bw=None, st=None):
+        # setup
+        x = x.unsqueeze(1)
+
+        codes = self.encode(x, target_bw, st)
+        o = self.decode(codes)
+        return o
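Note: `inference` is a convenience wrapper chaining encode and decode into a plain waveform round trip; compute_validation_loss below calls it as `inner_model.inference(x=audio, target_bw=params.target_bw)`. A minimal usage sketch (the input shape is an assumption based on the `unsqueeze(1)` above; `model` stands for a constructed Encodec):

import torch

audio = torch.randn(2, 24000)  # (batch, time), one second at 24 kHz
with torch.no_grad():
    audio_hat = model.inference(x=audio, target_bw=7.5)  # 7.5 = "target_bw" in get_model below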
@@ -59,9 +59,9 @@ def sim_loss(y_disc_r, y_disc_gen):
 # return torch.sum(loss) / x.shape[0]
 
 
-def reconstruction_loss(x, G_x, args, eps=1e-7):
+def reconstruction_loss(x, x_hat, args, eps=1e-7):
     # NOTE (lsx): hard-coded now
-    L = args.lambda_wav * F.mse_loss(x, G_x)  # wav L1 loss
+    L = args.lambda_wav * F.mse_loss(x, x_hat)  # wav L1 loss
     # loss_sisnr = sisnr_loss(G_x, x) #
     # L += 0.01*loss_sisnr
     # 2^6=64 -> 2^10=1024
@@ -70,15 +70,15 @@ def reconstruction_loss(x, G_x, args, eps=1e-7):
     # for i in range(5, 12): # Encodec setting
         s = 2**i
         melspec = MelSpectrogram(
-            sample_rate=args.sr,
+            sample_rate=args.sampling_rate,
             n_fft=max(s, 512),
             win_length=s,
             hop_length=s // 4,
             n_mels=64,
-            wkwargs={"device": args.device},
-        ).to(args.device)
+            wkwargs={"device": x_hat.device},
+        ).to(x_hat.device)
         S_x = melspec(x)
-        S_G_x = melspec(G_x)
+        S_G_x = melspec(x_hat)
         l1_loss = (S_x - S_G_x).abs().mean()
         l2_loss = (
             ((torch.log(S_x.abs() + eps) - torch.log(S_G_x.abs() + eps)) ** 2).mean(
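Note: per the `2^6=64 -> 2^10=1024` comment, the surrounding loop sweeps mel-spectrogram resolutions with window sizes from 64 to 1024 samples and a quarter-window hop; expanding the values used above (loop bounds inferred from that comment):

for i in range(6, 11):
    s = 2**i
    print(s, max(s, 512), s // 4)  # win_length, n_fft, hop_length
# 64 512 16
# 128 512 32
# 256 512 64
# 512 512 128
# 1024 1024 256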
@@ -1,12 +0,0 @@
-from typing import Tuple
-
-
-def get_padding(kernel_size, dilation=1) -> int:
-    return int((kernel_size * dilation - dilation) / 2)
-
-
-def get_2d_padding(kernel_size: Tuple[int, int], dilation: Tuple[int, int] = (1, 1)):
-    return (
-        ((kernel_size[0] - 1) * dilation[0]) // 2,
-        ((kernel_size[1] - 1) * dilation[1]) // 2,
-    )
@@ -2,6 +2,7 @@ import argparse
 import itertools
 import logging
 import math
+import random
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, Optional, Tuple, Union
@@ -10,6 +11,7 @@ import numpy as np
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
+from codec_datamodule import LibriTTSCodecDataModule
 from encodec import Encodec
 from lhotse.cut import Cut
 from lhotse.utils import fix_random_seed
@@ -76,7 +78,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="vits/exp",
+        default="encodec/exp",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
@@ -127,6 +129,12 @@ def get_parser():
         default=False,
         help="Whether to use half precision training.",
     )
+    parser.add_argument(
+        "--chunk-size",
+        type=int,
+        default=1,
+        help="The chunk size for the dataset (in second).",
+    )
 
     return parser
 
@@ -249,23 +257,32 @@ def get_model(params: AttributeDict) -> nn.Module:
     }
     discriminator_params = {
         "stft_discriminator_n_filters": 32,
+        "discriminator_iter_start": 500,
+    }
+    inference_params = {
+        "target_bw": 7.5,
     }
 
     params.update(generator_params)
     params.update(discriminator_params)
+    params.update(inference_params)
 
     hop_length = np.prod(params.ratios)
     n_q = int(
         1000
         * params.target_bandwidths[-1]
-        // (math.ceil(params.sample_rate / hop_length) * 10)
+        // (math.ceil(params.sampling_rate / hop_length) * 10)
     )
 
     encoder = SEANetEncoder(
-        n_filters=params.n_filters, dimension=params.dimension, ratios=params.ratios
+        n_filters=params.generator_n_filters,
+        dimension=params.dimension,
+        ratios=params.ratios,
     )
     decoder = SEANetDecoder(
-        n_filters=params.n_filters, dimension=params.dimension, ratios=params.ratios
+        n_filters=params.generator_n_filters,
+        dimension=params.dimension,
+        ratios=params.ratios,
     )
     quantizer = ResidualVectorQuantizer(
         dimension=params.dimension, n_q=n_q, bins=params.bins
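Note: `n_q` is the number of residual quantizer stages needed to hit the largest target bandwidth, assuming each stage spends roughly 10 bits per frame. A worked example with illustrative values (the ratios and bandwidth below are assumptions, not read from this diff):

import math

sampling_rate = 24000
ratios = [8, 5, 4, 2]  # assumed SEANet strides; hop_length = 320
hop_length = math.prod(ratios)
frame_rate = math.ceil(sampling_rate / hop_length)  # 75 frames/s
bw_max = 24.0  # assumed largest entry of target_bandwidths, in kbps
n_q = int(1000 * bw_max // (frame_rate * 10))
print(n_q)  # 32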
@@ -273,21 +290,25 @@ def get_model(params: AttributeDict) -> nn.Module:
 
     model = Encodec(
         params=params,
-        sample_rate=params.sampling_rate,
+        sampling_rate=params.sampling_rate,
         target_bandwidths=params.target_bandwidths,
         encoder=encoder,
         quantizer=quantizer,
         decoder=decoder,
         multi_scale_discriminator=MultiScaleDiscriminator(),
         multi_period_discriminator=MultiPeriodDiscriminator(),
-        multi_scale_stft_discriminator=MultiScaleSTFTDiscriminator(),
+        multi_scale_stft_discriminator=MultiScaleSTFTDiscriminator(
+            n_filters=params.stft_discriminator_n_filters
+        ),
     )
     return model
 
 
 def prepare_input(
+    params: AttributeDict,
     batch: dict,
     device: torch.device,
+    is_training: bool = True,
 ):
     """Parse batch data"""
     audio = batch["audio"].to(device, memory_format=torch.contiguous_format)
@@ -295,6 +316,18 @@ def prepare_input(
     audio_lens = batch["audio_lens"].to(device)
     features_lens = batch["features_lens"].to(device)
 
+    if is_training:
+        audio_dims = audio.size(-1)
+        start_idx = random.randint(
+            0, max(0, audio_dims - params.chunk_size * params.sampling_rate)
+        )
+        audio = audio[:, start_idx : params.sampling_rate + start_idx]
+    else:
+        # NOTE: a very coarse setup
+        audio = audio[
+            :, params.sampling_rate : params.sampling_rate + params.sampling_rate
+        ]
+
     return audio, audio_lens, features, features_lens
 
 
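Note: during training this crops a random window from each batch; the crop length is `params.sampling_rate` samples (exactly one second), while `chunk_size` only bounds the random start index, so the two agree only at the default `--chunk-size 1`. A standalone sketch of the same cropping:

import random

import torch

sampling_rate, chunk_size = 24000, 1
audio = torch.randn(4, 3 * sampling_rate)  # batch of 3 s waveforms

start_idx = random.randint(0, max(0, audio.size(-1) - chunk_size * sampling_rate))
chunk = audio[:, start_idx : sampling_rate + start_idx]
print(chunk.shape)  # torch.Size([4, 24000])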
@@ -371,13 +404,13 @@ def train_one_epoch(
     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
 
-        batch_size = len(batch["tokens"])
+        batch_size = len(batch["audio"])
         (
             audio,
             audio_lens,
             _,
             _,
-        ) = prepare_input(batch, device)
+        ) = prepare_input(params, batch, device)
 
         loss_info = MetricsTracker()
         loss_info["samples"] = batch_size
@@ -476,31 +509,38 @@
                     "train/grad_scale", cur_grad_scale, params.batch_idx_train
                 )
                 if "returned_sample" in stats_g:
-                    speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"]
+                    # speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"]
+                    speech_hat_, speech_, _, _ = stats_g["returned_sample"]
+
+                    speech_hat_i = speech_hat_[0]
+                    speech_i = speech_[0]
+                    if speech_hat_i.dim() > 1:
+                        speech_hat_i = speech_hat_i.squeeze(0)
+                        speech_i = speech_i.squeeze(0)
                     tb_writer.add_audio(
-                        "train/speech_hat_",
-                        speech_hat_,
+                        f"train/speech_hat_",
+                        speech_hat_i,
                         params.batch_idx_train,
                         params.sampling_rate,
                     )
                     tb_writer.add_audio(
-                        "train/speech_",
-                        speech_,
+                        f"train/speech_",
+                        speech_i,
                         params.batch_idx_train,
                         params.sampling_rate,
                     )
-                    tb_writer.add_image(
-                        "train/mel_hat_",
-                        plot_feature(mel_hat_),
-                        params.batch_idx_train,
-                        dataformats="HWC",
-                    )
-                    tb_writer.add_image(
-                        "train/mel_",
-                        plot_feature(mel_),
-                        params.batch_idx_train,
-                        dataformats="HWC",
-                    )
+                    # tb_writer.add_image(
+                    #     "train/mel_hat_",
+                    #     plot_feature(mel_hat_),
+                    #     params.batch_idx_train,
+                    #     dataformats="HWC",
+                    # )
+                    # tb_writer.add_image(
+                    #     "train/mel_",
+                    #     plot_feature(mel_),
+                    #     params.batch_idx_train,
+                    #     dataformats="HWC",
+                    # )
 
             if (
                 params.batch_idx_train % params.valid_interval == 0
@@ -522,15 +562,20 @@
                 valid_info.write_summary(
                     tb_writer, "train/valid_", params.batch_idx_train
                 )
+                speech_hat_i = speech_hat[0]
+                speech_i = speech[0]
+                if speech_hat_i.dim() > 1:
+                    speech_hat_i = speech_hat_i.squeeze(0)
+                    speech_i = speech_i.squeeze(0)
                 tb_writer.add_audio(
                     "train/valdi_speech_hat",
-                    speech_hat,
+                    speech_hat_i,
                     params.batch_idx_train,
                     params.sampling_rate,
                 )
                 tb_writer.add_audio(
                     "train/valdi_speech",
-                    speech,
+                    speech_i,
                     params.batch_idx_train,
                     params.sampling_rate,
                 )
@@ -559,13 +604,13 @@ def compute_validation_loss(
 
     with torch.no_grad():
        for batch_idx, batch in enumerate(valid_dl):
-            batch_size = len(batch["tokens"])
+            batch_size = len(batch["audio"])
            (
                audio,
                audio_lens,
                _,
                _,
-            ) = prepare_input(batch, device)
+            ) = prepare_input(params, batch, device, is_training=False)
 
            loss_info = MetricsTracker()
            loss_info["samples"] = batch_size
@@ -588,7 +633,7 @@
                speech_lengths=audio_lens,
                global_step=params.batch_idx_train,
                forward_generator=True,
-                return_sample=batch_idx == 0,
+                return_sample=False,
            )
            assert loss_g.requires_grad is False
            for k, v in stats_g.items():
@@ -599,9 +644,9 @@
 
            # infer for first batch:
            if batch_idx == 0 and rank == 0:
-                speech_hat_, speech_, _, _ = stats_g["returned_sample"]
-                returned_sample = (speech_hat_, speech_)
+                inner_model = model.module if isinstance(model, DDP) else model
+                audio_pred = inner_model.inference(x=audio, target_bw=params.target_bw)
+                returned_sample = (audio_pred, audio)
 
    if world_size > 1:
        tot_loss.reduce(device)
@@ -635,7 +680,7 @@ def scan_pessimistic_batches_for_oom(
            audio_lens,
            _,
            _,
-        ) = prepare_input(batch, device)
+        ) = prepare_input(params, batch, device)
        try:
            # for discriminator
            with autocast(enabled=params.use_fp16):
@@ -706,9 +751,12 @@ def run(rank, world_size, args):
        device = torch.device("cuda", rank)
    logging.info(f"Device: {device}")
 
-    vctk = VctkTtsDataModule(args)
+    libritts = LibriTTSCodecDataModule(args)
 
-    train_cuts = vctk.train_cuts()
+    if params.full_libri:
+        train_cuts = libritts.train_all_shuf_cuts()
+    else:
+        train_cuts = libritts.train_clean_100_cuts()
 
    logging.info(params)
 
@@ -798,19 +846,19 @@ def run(rank, world_size, args):
    if params.inf_check:
        register_inf_check_hooks(model)
 
-    train_dl = vctk.train_dataloaders(train_cuts)
+    train_dl = libritts.train_dataloaders(train_cuts)
 
-    valid_cuts = vctk.valid_cuts()
-    valid_dl = vctk.valid_dataloaders(valid_cuts)
+    valid_cuts = libritts.dev_clean_cuts()
+    valid_dl = libritts.valid_dataloaders(valid_cuts)
 
-    if not params.print_diagnostics:
-        scan_pessimistic_batches_for_oom(
-            model=model,
-            train_dl=train_dl,
-            optimizer_g=optimizer_g,
-            optimizer_d=optimizer_d,
-            params=params,
-        )
+    # if not params.print_diagnostics:
+    #     scan_pessimistic_batches_for_oom(
+    #         model=model,
+    #         train_dl=train_dl,
+    #         optimizer_g=optimizer_g,
+    #         optimizer_d=optimizer_d,
+    #         params=params,
+    #     )
 
    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
    if checkpoints and "grad_scaler" in checkpoints:
@@ -883,7 +931,7 @@ def run(rank, world_size, args):
 
 def main():
    parser = get_parser()
-    VctkTtsDataModule.add_arguments(parser)
+    LibriTTSCodecDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)
 