mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
a little bit coarse commit
This commit is contained in:
parent
dd82686a0f
commit
6e4a9ea85a
@ -25,19 +25,16 @@ It looks for manifests in the directory data/manifests.
|
||||
The generated fbank features are saved in data/spectrogram.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
LilcomChunkyWriter,
|
||||
Spectrogram,
|
||||
SpectrogramConfig,
|
||||
load_manifest,
|
||||
)
|
||||
from lhotse import CutSet, LilcomChunkyWriter, Spectrogram, SpectrogramConfig
|
||||
from lhotse.audio import RecordingSet
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
from lhotse.supervision import SupervisionSet
|
||||
|
||||
from icefall.utils import get_executor
|
||||
@ -49,26 +46,62 @@ from icefall.utils import get_executor
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
def compute_spectrogram_libritts():
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
help="""Dataset parts to compute fbank. If None, we will use all""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sampling-rate",
|
||||
type=int,
|
||||
default=24000,
|
||||
help="""Sampling rate of the audio for computing fbank, the default value for LibriTTS is 24000, audio files will be resampled if a different sample rate is provided""",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def compute_spectrogram_libritts(dataset: Optional[str] = None, sampling_rate: int = 24000,):
|
||||
src_dir = Path("data/manifests")
|
||||
output_dir = Path("data/spectrogram")
|
||||
num_jobs = min(32, os.cpu_count())
|
||||
|
||||
sampling_rate = 24000
|
||||
|
||||
frame_length = 1024 / sampling_rate # (in second)
|
||||
frame_shift = 256 / sampling_rate # (in second)
|
||||
use_fft_mag = True
|
||||
|
||||
prefix = "libritts"
|
||||
suffix = "jsonl.gz"
|
||||
partition = "all"
|
||||
if dataset is None:
|
||||
dataset_parts = (
|
||||
"dev-clean",
|
||||
"dev-other",
|
||||
"test-clean",
|
||||
"test-other",
|
||||
"train-clean-100",
|
||||
"train-clean-360",
|
||||
"train-other-500",
|
||||
)
|
||||
else:
|
||||
dataset_parts = dataset.split(" ", -1)
|
||||
|
||||
recordings = load_manifest(
|
||||
src_dir / f"{prefix}_recordings_{partition}.jsonl.gz", RecordingSet
|
||||
).resample(sampling_rate=sampling_rate)
|
||||
supervisions = load_manifest(
|
||||
src_dir / f"{prefix}_supervisions_{partition}.jsonl.gz", SupervisionSet
|
||||
manifests = read_manifests_if_cached(
|
||||
dataset_parts=dataset_parts,
|
||||
output_dir=src_dir,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
)
|
||||
assert manifests is not None
|
||||
|
||||
assert len(manifests) == len(dataset_parts), (
|
||||
len(manifests),
|
||||
len(dataset_parts),
|
||||
list(manifests.keys()),
|
||||
dataset_parts,
|
||||
)
|
||||
|
||||
config = SpectrogramConfig(
|
||||
@ -80,14 +113,19 @@ def compute_spectrogram_libritts():
|
||||
extractor = Spectrogram(config)
|
||||
|
||||
with get_executor() as ex: # Initialize the executor only once.
|
||||
for partition, m in manifests.items():
|
||||
cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
|
||||
if (output_dir / cuts_filename).is_file():
|
||||
logging.info(f"{partition} already exists - skipping.")
|
||||
return
|
||||
logging.info(f"Processing {partition}")
|
||||
cut_set = CutSet.from_manifests(
|
||||
recordings=recordings, supervisions=supervisions
|
||||
recordings=m["recordings"],
|
||||
supervisions=m["supervisions"],
|
||||
)
|
||||
if sampling_rate != 24000:
|
||||
logging.info(f"Resampling audio to {sampling_rate}")
|
||||
cut_set = cut_set.resample(sampling_rate)
|
||||
|
||||
cut_set = cut_set.compute_and_store_features(
|
||||
extractor=extractor,
|
||||
|
@ -8,6 +8,7 @@ set -eou pipefail
|
||||
stage=0
|
||||
stop_stage=100
|
||||
sampling_rate=24000
|
||||
nj=32
|
||||
perturb_speed=true
|
||||
|
||||
dl_dir=$PWD/download
|
||||
@ -54,7 +55,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||
# to $dl_dir/LibriTTS
|
||||
mkdir -p data/manifests
|
||||
if [ ! -e data/manifests/.libritts.done ]; then
|
||||
lhotse prepare libritts $dl_dir/LibriTTS data/manifests
|
||||
lhotse prepare libritts --num-jobs 32 $dl_dir/LibriTTS data/manifests
|
||||
touch data/manifests/.libritts.done
|
||||
fi
|
||||
fi
|
||||
@ -84,10 +85,10 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||
# Here we shuffle and combine the train-clean-100, train-clean-360 and
|
||||
# train-other-500 together to form the training set.
|
||||
if [ ! -f data/fbank/libritts_cuts_train-all-shuf.jsonl.gz ]; then
|
||||
cat <(gunzip -c data/fbank/libritts_cuts_train-clean-100.jsonl.gz) \
|
||||
<(gunzip -c data/fbank/libritts_cuts_train-clean-360.jsonl.gz) \
|
||||
<(gunzip -c data/fbank/libritts_cuts_train-other-500.jsonl.gz) | \
|
||||
shuf | gzip -c > data/fbank/libritts_cuts_train-all-shuf.jsonl.gz
|
||||
cat <(gunzip -c ./libritts_cuts_train-clean-100.jsonl.gz) \
|
||||
<(gunzip -c ./libritts_cuts_train-clean-360.jsonl.gz) \
|
||||
<(gunzip -c ./libritts_cuts_train-other-500.jsonl.gz) | \
|
||||
shuf | gzip -c > ./libritts_cuts_train-all-shuf.jsonl.gz
|
||||
fi
|
||||
|
||||
if [ ! -e data/fbank/.libritts-validated.done ]; then
|
||||
|
@ -5,9 +5,18 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
from einops import rearrange
|
||||
from utils import get_2d_padding, get_padding
|
||||
from modules.conv import NormConv1d, NormConv2d
|
||||
|
||||
from ..modules import NormConv1d, NormConv2d
|
||||
|
||||
def get_padding(kernel_size, dilation=1) -> int:
|
||||
return int((kernel_size * dilation - dilation) / 2)
|
||||
|
||||
|
||||
def get_2d_padding(kernel_size: Tuple[int, int], dilation: Tuple[int, int] = (1, 1)):
|
||||
return (
|
||||
((kernel_size[0] - 1) * dilation[0]) // 2,
|
||||
((kernel_size[1] - 1) * dilation[1]) // 2,
|
||||
)
|
||||
|
||||
|
||||
class DiscriminatorP(nn.Module):
|
@ -80,6 +80,13 @@ class LibriTTSCodecDataModule:
|
||||
"augmentations, etc.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--full-libri",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="""When enabled, use the entire LibriTTS training set.
|
||||
Otherwise, use the clean-100 subset.""",
|
||||
)
|
||||
group.add_argument(
|
||||
"--manifest-dir",
|
||||
type=Path,
|
||||
@ -210,8 +217,8 @@ class LibriTTSCodecDataModule:
|
||||
|
||||
validate = SpeechSynthesisDataset(
|
||||
return_text=False,
|
||||
return_tokens=True,
|
||||
return_spk_ids=True,
|
||||
return_tokens=False,
|
||||
return_spk_ids=False,
|
||||
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
@ -236,8 +243,8 @@ class LibriTTSCodecDataModule:
|
||||
|
||||
test = SpeechSynthesisDataset(
|
||||
return_text=False,
|
||||
return_tokens=True,
|
||||
return_spk_ids=True,
|
||||
return_tokens=False,
|
||||
return_spk_ids=False,
|
||||
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
@ -256,16 +263,60 @@ class LibriTTSCodecDataModule:
|
||||
return test_dl
|
||||
|
||||
@lru_cache()
|
||||
def train_cuts(self) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_train.jsonl.gz")
|
||||
def train_clean_100_cuts(self) -> CutSet:
|
||||
logging.info("About to get train-clean-100 cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "libritts_cuts_train-clean-100.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def valid_cuts(self) -> CutSet:
|
||||
logging.info("About to get validation cuts")
|
||||
return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_valid.jsonl.gz")
|
||||
def train_clean_360_cuts(self) -> CutSet:
|
||||
logging.info("About to get train-clean-360 cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "libritts_cuts_train-clean-360.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def test_cuts(self) -> CutSet:
|
||||
logging.info("About to get test cuts")
|
||||
return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_test.jsonl.gz")
|
||||
def train_other_500_cuts(self) -> CutSet:
|
||||
logging.info("About to get train-other-500 cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "libritts_cuts_train-other-500.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def train_all_shuf_cuts(self) -> CutSet:
|
||||
logging.info(
|
||||
"About to get the shuffled train-clean-100, \
|
||||
train-clean-360 and train-other-500 cuts"
|
||||
)
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "libritts_cuts_train-all-shuf.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def dev_clean_cuts(self) -> CutSet:
|
||||
logging.info("About to get dev-clean cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "libritts_cuts_dev-clean.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def dev_other_cuts(self) -> CutSet:
|
||||
logging.info("About to get dev-other cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "libritts_cuts_dev-other.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def test_clean_cuts(self) -> CutSet:
|
||||
logging.info("About to get test-clean cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "libritts_cuts_test-clean.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def test_other_cuts(self) -> CutSet:
|
||||
logging.info("About to get test-other cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "libritts_cuts_test-other.jsonl.gz"
|
||||
)
|
||||
|
@ -1,8 +1,8 @@
|
||||
from typing import List, Tuple
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from models.discriminators import DiscriminatorP, DiscriminatorS, DiscriminatorSTFT
|
||||
from base_discriminators import DiscriminatorP, DiscriminatorS, DiscriminatorSTFT
|
||||
from torch.nn import AvgPool1d
|
||||
|
||||
|
||||
@ -81,7 +81,7 @@ class MultiScaleSTFTDiscriminator(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
filters: int,
|
||||
n_filters: int,
|
||||
in_channels: int = 1,
|
||||
out_channels: int = 1,
|
||||
n_ffts: List[int] = [1024, 2048, 512, 256, 128],
|
||||
@ -94,7 +94,7 @@ class MultiScaleSTFTDiscriminator(nn.Module):
|
||||
self.discriminators = nn.ModuleList(
|
||||
[
|
||||
DiscriminatorSTFT(
|
||||
filters,
|
||||
n_filters,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
n_fft=n_ffts[i],
|
||||
|
@ -12,7 +12,7 @@ from torch.cuda.amp import autocast
|
||||
class Encodec(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
sample_rate: int,
|
||||
sampling_rate: int,
|
||||
target_bandwidths: List[float],
|
||||
params: dict,
|
||||
encoder: nn.Module,
|
||||
@ -21,21 +21,21 @@ class Encodec(nn.Module):
|
||||
multi_scale_discriminator: nn.Module,
|
||||
multi_period_discriminator: nn.Module,
|
||||
multi_scale_stft_discriminator: nn.Module,
|
||||
cache_generator_outputs: bool = True,
|
||||
cache_generator_outputs: bool = False,
|
||||
):
|
||||
super(Encodec, self).__init__()
|
||||
|
||||
self.params = params
|
||||
|
||||
# setup the generator
|
||||
self.sample_rate = sample_rate
|
||||
self.sampling_rate = sampling_rate
|
||||
self.encoder = encoder
|
||||
self.quantizer = quantizer
|
||||
self.decoder = decoder
|
||||
|
||||
self.ratios = encoder.ratios
|
||||
self.hop_length = np.prod(self.ratios)
|
||||
self.frame_rate = math.ceil(self.sample_rate / np.prod(self.ratios))
|
||||
self.frame_rate = math.ceil(self.sampling_rate / np.prod(self.ratios))
|
||||
self.target_bandwidths = target_bandwidths
|
||||
|
||||
# discriminators
|
||||
@ -133,10 +133,10 @@ class Encodec(nn.Module):
|
||||
|
||||
if return_sample:
|
||||
stats["returned_sample"] = (
|
||||
speech_hat[0].data.cpu().numpy(),
|
||||
speech[0].data.cpu().numpy(),
|
||||
fmap_hat[0][0].data.cpu().numpy(),
|
||||
fmap[0][0].data.cpu().numpy(),
|
||||
speech_hat.cpu(),
|
||||
speech.cpu(),
|
||||
fmap_hat[0][0].data.cpu(),
|
||||
fmap[0][0].data.cpu(),
|
||||
)
|
||||
|
||||
# reset cache
|
||||
@ -259,3 +259,11 @@ class Encodec(nn.Module):
|
||||
quantized = self.quantizer.decode(codes)
|
||||
o = self.decoder(quantized)
|
||||
return o
|
||||
|
||||
def inference(self, x, target_bw=None, st=None):
|
||||
# setup
|
||||
x = x.unsqueeze(1)
|
||||
|
||||
codes = self.encode(x, target_bw, st)
|
||||
o = self.decode(codes)
|
||||
return o
|
||||
|
@ -59,9 +59,9 @@ def sim_loss(y_disc_r, y_disc_gen):
|
||||
# return torch.sum(loss) / x.shape[0]
|
||||
|
||||
|
||||
def reconstruction_loss(x, G_x, args, eps=1e-7):
|
||||
def reconstruction_loss(x, x_hat, args, eps=1e-7):
|
||||
# NOTE (lsx): hard-coded now
|
||||
L = args.lambda_wav * F.mse_loss(x, G_x) # wav L1 loss
|
||||
L = args.lambda_wav * F.mse_loss(x, x_hat) # wav L1 loss
|
||||
# loss_sisnr = sisnr_loss(G_x, x) #
|
||||
# L += 0.01*loss_sisnr
|
||||
# 2^6=64 -> 2^10=1024
|
||||
@ -70,15 +70,15 @@ def reconstruction_loss(x, G_x, args, eps=1e-7):
|
||||
# for i in range(5, 12): # Encodec setting
|
||||
s = 2**i
|
||||
melspec = MelSpectrogram(
|
||||
sample_rate=args.sr,
|
||||
sample_rate=args.sampling_rate,
|
||||
n_fft=max(s, 512),
|
||||
win_length=s,
|
||||
hop_length=s // 4,
|
||||
n_mels=64,
|
||||
wkwargs={"device": args.device},
|
||||
).to(args.device)
|
||||
wkwargs={"device": x_hat.device},
|
||||
).to(x_hat.device)
|
||||
S_x = melspec(x)
|
||||
S_G_x = melspec(G_x)
|
||||
S_G_x = melspec(x_hat)
|
||||
l1_loss = (S_x - S_G_x).abs().mean()
|
||||
l2_loss = (
|
||||
((torch.log(S_x.abs() + eps) - torch.log(S_G_x.abs() + eps)) ** 2).mean(
|
||||
|
@ -1,12 +0,0 @@
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
def get_padding(kernel_size, dilation=1) -> int:
|
||||
return int((kernel_size * dilation - dilation) / 2)
|
||||
|
||||
|
||||
def get_2d_padding(kernel_size: Tuple[int, int], dilation: Tuple[int, int] = (1, 1)):
|
||||
return (
|
||||
((kernel_size[0] - 1) * dilation[0]) // 2,
|
||||
((kernel_size[1] - 1) * dilation[1]) // 2,
|
||||
)
|
@ -2,6 +2,7 @@ import argparse
|
||||
import itertools
|
||||
import logging
|
||||
import math
|
||||
import random
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
@ -10,6 +11,7 @@ import numpy as np
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from codec_datamodule import LibriTTSCodecDataModule
|
||||
from encodec import Encodec
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.utils import fix_random_seed
|
||||
@ -76,7 +78,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="vits/exp",
|
||||
default="encodec/exp",
|
||||
help="""The experiment dir.
|
||||
It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
@ -127,6 +129,12 @@ def get_parser():
|
||||
default=False,
|
||||
help="Whether to use half precision training.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--chunk-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The chunk size for the dataset (in second).",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
@ -249,23 +257,32 @@ def get_model(params: AttributeDict) -> nn.Module:
|
||||
}
|
||||
discriminator_params = {
|
||||
"stft_discriminator_n_filters": 32,
|
||||
"discriminator_iter_start": 500,
|
||||
}
|
||||
inference_params = {
|
||||
"target_bw": 7.5,
|
||||
}
|
||||
|
||||
params.update(generator_params)
|
||||
params.update(discriminator_params)
|
||||
params.update(inference_params)
|
||||
|
||||
hop_length = np.prod(params.ratios)
|
||||
n_q = int(
|
||||
1000
|
||||
* params.target_bandwidths[-1]
|
||||
// (math.ceil(params.sample_rate / hop_length) * 10)
|
||||
// (math.ceil(params.sampling_rate / hop_length) * 10)
|
||||
)
|
||||
|
||||
encoder = SEANetEncoder(
|
||||
n_filters=params.n_filters, dimension=params.dimension, ratios=params.ratios
|
||||
n_filters=params.generator_n_filters,
|
||||
dimension=params.dimension,
|
||||
ratios=params.ratios,
|
||||
)
|
||||
decoder = SEANetDecoder(
|
||||
n_filters=params.n_filters, dimension=params.dimension, ratios=params.ratios
|
||||
n_filters=params.generator_n_filters,
|
||||
dimension=params.dimension,
|
||||
ratios=params.ratios,
|
||||
)
|
||||
quantizer = ResidualVectorQuantizer(
|
||||
dimension=params.dimension, n_q=n_q, bins=params.bins
|
||||
@ -273,21 +290,25 @@ def get_model(params: AttributeDict) -> nn.Module:
|
||||
|
||||
model = Encodec(
|
||||
params=params,
|
||||
sample_rate=params.sampling_rate,
|
||||
sampling_rate=params.sampling_rate,
|
||||
target_bandwidths=params.target_bandwidths,
|
||||
encoder=encoder,
|
||||
quantizer=quantizer,
|
||||
decoder=decoder,
|
||||
multi_scale_discriminator=MultiScaleDiscriminator(),
|
||||
multi_period_discriminator=MultiPeriodDiscriminator(),
|
||||
multi_scale_stft_discriminator=MultiScaleSTFTDiscriminator(),
|
||||
multi_scale_stft_discriminator=MultiScaleSTFTDiscriminator(
|
||||
n_filters=params.stft_discriminator_n_filters
|
||||
),
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
def prepare_input(
|
||||
params: AttributeDict,
|
||||
batch: dict,
|
||||
device: torch.device,
|
||||
is_training: bool = True,
|
||||
):
|
||||
"""Parse batch data"""
|
||||
audio = batch["audio"].to(device, memory_format=torch.contiguous_format)
|
||||
@ -295,6 +316,18 @@ def prepare_input(
|
||||
audio_lens = batch["audio_lens"].to(device)
|
||||
features_lens = batch["features_lens"].to(device)
|
||||
|
||||
if is_training:
|
||||
audio_dims = audio.size(-1)
|
||||
start_idx = random.randint(
|
||||
0, max(0, audio_dims - params.chunk_size * params.sampling_rate)
|
||||
)
|
||||
audio = audio[:, start_idx : params.sampling_rate + start_idx]
|
||||
else:
|
||||
# NOTE: a very coarse setup
|
||||
audio = audio[
|
||||
:, params.sampling_rate : params.sampling_rate + params.sampling_rate
|
||||
]
|
||||
|
||||
return audio, audio_lens, features, features_lens
|
||||
|
||||
|
||||
@ -371,13 +404,13 @@ def train_one_epoch(
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
params.batch_idx_train += 1
|
||||
|
||||
batch_size = len(batch["tokens"])
|
||||
batch_size = len(batch["audio"])
|
||||
(
|
||||
audio,
|
||||
audio_lens,
|
||||
_,
|
||||
_,
|
||||
) = prepare_input(batch, device)
|
||||
) = prepare_input(params, batch, device)
|
||||
|
||||
loss_info = MetricsTracker()
|
||||
loss_info["samples"] = batch_size
|
||||
@ -476,31 +509,38 @@ def train_one_epoch(
|
||||
"train/grad_scale", cur_grad_scale, params.batch_idx_train
|
||||
)
|
||||
if "returned_sample" in stats_g:
|
||||
speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"]
|
||||
# speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"]
|
||||
speech_hat_, speech_, _, _ = stats_g["returned_sample"]
|
||||
|
||||
speech_hat_i = speech_hat_[0]
|
||||
speech_i = speech_[0]
|
||||
if speech_hat_i.dim() > 1:
|
||||
speech_hat_i = speech_hat_i.squeeze(0)
|
||||
speech_i = speech_i.squeeze(0)
|
||||
tb_writer.add_audio(
|
||||
"train/speech_hat_",
|
||||
speech_hat_,
|
||||
f"train/speech_hat_",
|
||||
speech_hat_i,
|
||||
params.batch_idx_train,
|
||||
params.sampling_rate,
|
||||
)
|
||||
tb_writer.add_audio(
|
||||
"train/speech_",
|
||||
speech_,
|
||||
f"train/speech_",
|
||||
speech_i,
|
||||
params.batch_idx_train,
|
||||
params.sampling_rate,
|
||||
)
|
||||
tb_writer.add_image(
|
||||
"train/mel_hat_",
|
||||
plot_feature(mel_hat_),
|
||||
params.batch_idx_train,
|
||||
dataformats="HWC",
|
||||
)
|
||||
tb_writer.add_image(
|
||||
"train/mel_",
|
||||
plot_feature(mel_),
|
||||
params.batch_idx_train,
|
||||
dataformats="HWC",
|
||||
)
|
||||
# tb_writer.add_image(
|
||||
# "train/mel_hat_",
|
||||
# plot_feature(mel_hat_),
|
||||
# params.batch_idx_train,
|
||||
# dataformats="HWC",
|
||||
# )
|
||||
# tb_writer.add_image(
|
||||
# "train/mel_",
|
||||
# plot_feature(mel_),
|
||||
# params.batch_idx_train,
|
||||
# dataformats="HWC",
|
||||
# )
|
||||
|
||||
if (
|
||||
params.batch_idx_train % params.valid_interval == 0
|
||||
@ -522,15 +562,20 @@ def train_one_epoch(
|
||||
valid_info.write_summary(
|
||||
tb_writer, "train/valid_", params.batch_idx_train
|
||||
)
|
||||
speech_hat_i = speech_hat[0]
|
||||
speech_i = speech[0]
|
||||
if speech_hat_i.dim() > 1:
|
||||
speech_hat_i = speech_hat_i.squeeze(0)
|
||||
speech_i = speech_i.squeeze(0)
|
||||
tb_writer.add_audio(
|
||||
"train/valdi_speech_hat",
|
||||
speech_hat,
|
||||
speech_hat_i,
|
||||
params.batch_idx_train,
|
||||
params.sampling_rate,
|
||||
)
|
||||
tb_writer.add_audio(
|
||||
"train/valdi_speech",
|
||||
speech,
|
||||
speech_i,
|
||||
params.batch_idx_train,
|
||||
params.sampling_rate,
|
||||
)
|
||||
@ -559,13 +604,13 @@ def compute_validation_loss(
|
||||
|
||||
with torch.no_grad():
|
||||
for batch_idx, batch in enumerate(valid_dl):
|
||||
batch_size = len(batch["tokens"])
|
||||
batch_size = len(batch["audio"])
|
||||
(
|
||||
audio,
|
||||
audio_lens,
|
||||
_,
|
||||
_,
|
||||
) = prepare_input(batch, device)
|
||||
) = prepare_input(params, batch, device, is_training=False)
|
||||
|
||||
loss_info = MetricsTracker()
|
||||
loss_info["samples"] = batch_size
|
||||
@ -588,7 +633,7 @@ def compute_validation_loss(
|
||||
speech_lengths=audio_lens,
|
||||
global_step=params.batch_idx_train,
|
||||
forward_generator=True,
|
||||
return_sample=batch_idx == 0,
|
||||
return_sample=False,
|
||||
)
|
||||
assert loss_g.requires_grad is False
|
||||
for k, v in stats_g.items():
|
||||
@ -599,9 +644,9 @@ def compute_validation_loss(
|
||||
|
||||
# infer for first batch:
|
||||
if batch_idx == 0 and rank == 0:
|
||||
speech_hat_, speech_, _, _ = stats_g["returned_sample"]
|
||||
|
||||
returned_sample = (speech_hat_, speech_)
|
||||
inner_model = model.module if isinstance(model, DDP) else model
|
||||
audio_pred = inner_model.inference(x=audio, target_bw=params.target_bw)
|
||||
returned_sample = (audio_pred, audio)
|
||||
|
||||
if world_size > 1:
|
||||
tot_loss.reduce(device)
|
||||
@ -635,7 +680,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
audio_lens,
|
||||
_,
|
||||
_,
|
||||
) = prepare_input(batch, device)
|
||||
) = prepare_input(params, batch, device)
|
||||
try:
|
||||
# for discriminator
|
||||
with autocast(enabled=params.use_fp16):
|
||||
@ -706,9 +751,12 @@ def run(rank, world_size, args):
|
||||
device = torch.device("cuda", rank)
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
vctk = VctkTtsDataModule(args)
|
||||
libritts = LibriTTSCodecDataModule(args)
|
||||
|
||||
train_cuts = vctk.train_cuts()
|
||||
if params.full_libri:
|
||||
train_cuts = libritts.train_all_shuf_cuts()
|
||||
else:
|
||||
train_cuts = libritts.train_clean_100_cuts()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
@ -798,19 +846,19 @@ def run(rank, world_size, args):
|
||||
if params.inf_check:
|
||||
register_inf_check_hooks(model)
|
||||
|
||||
train_dl = vctk.train_dataloaders(train_cuts)
|
||||
train_dl = libritts.train_dataloaders(train_cuts)
|
||||
|
||||
valid_cuts = vctk.valid_cuts()
|
||||
valid_dl = vctk.valid_dataloaders(valid_cuts)
|
||||
valid_cuts = libritts.dev_clean_cuts()
|
||||
valid_dl = libritts.valid_dataloaders(valid_cuts)
|
||||
|
||||
if not params.print_diagnostics:
|
||||
scan_pessimistic_batches_for_oom(
|
||||
model=model,
|
||||
train_dl=train_dl,
|
||||
optimizer_g=optimizer_g,
|
||||
optimizer_d=optimizer_d,
|
||||
params=params,
|
||||
)
|
||||
# if not params.print_diagnostics:
|
||||
# scan_pessimistic_batches_for_oom(
|
||||
# model=model,
|
||||
# train_dl=train_dl,
|
||||
# optimizer_g=optimizer_g,
|
||||
# optimizer_d=optimizer_d,
|
||||
# params=params,
|
||||
# )
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
@ -883,7 +931,7 @@ def run(rank, world_size, args):
|
||||
|
||||
def main():
|
||||
parser = get_parser()
|
||||
VctkTtsDataModule.add_arguments(parser)
|
||||
LibriTTSCodecDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user