a little bit coarse commit

This commit is contained in:
JinZr 2024-09-05 22:30:07 +08:00
parent dd82686a0f
commit 6e4a9ea85a
9 changed files with 273 additions and 130 deletions

View File

@@ -25,19 +25,16 @@ It looks for manifests in the directory data/manifests.

 The generated fbank features are saved in data/spectrogram.
 """
+import argparse
 import logging
 import os
 from pathlib import Path
+from typing import Optional

 import torch
-from lhotse import (
-    CutSet,
-    LilcomChunkyWriter,
-    Spectrogram,
-    SpectrogramConfig,
-    load_manifest,
-)
+from lhotse import CutSet, LilcomChunkyWriter, Spectrogram, SpectrogramConfig
 from lhotse.audio import RecordingSet
+from lhotse.recipes.utils import read_manifests_if_cached
 from lhotse.supervision import SupervisionSet

 from icefall.utils import get_executor
@@ -49,26 +46,62 @@ from icefall.utils import get_executor
 torch.set_num_threads(1)
 torch.set_num_interop_threads(1)


+def get_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        help="""Dataset parts to compute fbank. If None, we will use all.""",
+    )
+    parser.add_argument(
+        "--sampling-rate",
+        type=int,
+        default=24000,
+        help="""Sampling rate of the audio for computing fbank; the default
+        value for LibriTTS is 24000. Audio files will be resampled if a
+        different sampling rate is provided.""",
+    )
+
+    return parser.parse_args()
+
+
-def compute_spectrogram_libritts():
+def compute_spectrogram_libritts(
+    dataset: Optional[str] = None,
+    sampling_rate: int = 24000,
+):
     src_dir = Path("data/manifests")
     output_dir = Path("data/spectrogram")
     num_jobs = min(32, os.cpu_count())
-    sampling_rate = 24000
     frame_length = 1024 / sampling_rate  # (in second)
     frame_shift = 256 / sampling_rate  # (in second)
     use_fft_mag = True

     prefix = "libritts"
     suffix = "jsonl.gz"
-    partition = "all"
+    if dataset is None:
+        dataset_parts = (
+            "dev-clean",
+            "dev-other",
+            "test-clean",
+            "test-other",
+            "train-clean-100",
+            "train-clean-360",
+            "train-other-500",
+        )
+    else:
+        dataset_parts = dataset.split(" ", -1)

-    recordings = load_manifest(
-        src_dir / f"{prefix}_recordings_{partition}.jsonl.gz", RecordingSet
-    ).resample(sampling_rate=sampling_rate)
-    supervisions = load_manifest(
-        src_dir / f"{prefix}_supervisions_{partition}.jsonl.gz", SupervisionSet
-    )
+    manifests = read_manifests_if_cached(
+        dataset_parts=dataset_parts,
+        output_dir=src_dir,
+        prefix=prefix,
+        suffix=suffix,
+    )
+    assert manifests is not None
+
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
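For context, `read_manifests_if_cached` loads (or reuses) the per-partition manifests written by `lhotse prepare libritts`. A minimal sketch of the return shape this script relies on — not part of the commit, with partitions and paths chosen for illustration:

```python
from lhotse.recipes.utils import read_manifests_if_cached

# Each partition maps to {"recordings": RecordingSet, "supervisions": SupervisionSet},
# which is exactly what the feature-extraction loop below consumes.
manifests = read_manifests_if_cached(
    dataset_parts=("dev-clean", "test-clean"),  # e.g. from --dataset "dev-clean test-clean"
    output_dir="data/manifests",
    prefix="libritts",
    suffix="jsonl.gz",
)
for partition, m in manifests.items():
    print(partition, len(m["recordings"]), len(m["supervisions"]))
```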
     config = SpectrogramConfig(
@@ -80,14 +113,19 @@ def compute_spectrogram_libritts():
     extractor = Spectrogram(config)

     with get_executor() as ex:  # Initialize the executor only once.
-        cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
-        if (output_dir / cuts_filename).is_file():
-            logging.info(f"{partition} already exists - skipping.")
-            return
-        logging.info(f"Processing {partition}")
-        cut_set = CutSet.from_manifests(
-            recordings=recordings, supervisions=supervisions
-        )
+        for partition, m in manifests.items():
+            cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
+            if (output_dir / cuts_filename).is_file():
+                logging.info(f"{partition} already exists - skipping.")
+                return
+            logging.info(f"Processing {partition}")
+            cut_set = CutSet.from_manifests(
+                recordings=m["recordings"],
+                supervisions=m["supervisions"],
+            )
+            if sampling_rate != 24000:
+                logging.info(f"Resampling audio to {sampling_rate}")
+                cut_set = cut_set.resample(sampling_rate)
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
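A quick check of the frame geometry implied by the constants above (1024-sample window, 256-sample hop) at the LibriTTS default rate:

```python
sampling_rate = 24000
frame_length = 1024 / sampling_rate  # ~0.0427 s, i.e. a 42.7 ms analysis window
frame_shift = 256 / sampling_rate    # ~0.0107 s, i.e. a 10.7 ms hop
print(sampling_rate / 256)           # 93.75 spectrogram frames per second
```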

View File

@@ -8,6 +8,7 @@ set -eou pipefail
 stage=0
 stop_stage=100
 sampling_rate=24000
+nj=32
 perturb_speed=true

 dl_dir=$PWD/download
@@ -54,7 +55,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
   # to $dl_dir/LibriTTS
   mkdir -p data/manifests
   if [ ! -e data/manifests/.libritts.done ]; then
-    lhotse prepare libritts $dl_dir/LibriTTS data/manifests
+    lhotse prepare libritts --num-jobs 32 $dl_dir/LibriTTS data/manifests
     touch data/manifests/.libritts.done
   fi
 fi
@@ -84,10 +85,10 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
   # Here we shuffle and combine the train-clean-100, train-clean-360 and
   # train-other-500 together to form the training set.
   if [ ! -f data/fbank/libritts_cuts_train-all-shuf.jsonl.gz ]; then
-    cat <(gunzip -c data/fbank/libritts_cuts_train-clean-100.jsonl.gz) \
-      <(gunzip -c data/fbank/libritts_cuts_train-clean-360.jsonl.gz) \
-      <(gunzip -c data/fbank/libritts_cuts_train-other-500.jsonl.gz) | \
-      shuf | gzip -c > data/fbank/libritts_cuts_train-all-shuf.jsonl.gz
+    cat <(gunzip -c ./libritts_cuts_train-clean-100.jsonl.gz) \
+      <(gunzip -c ./libritts_cuts_train-clean-360.jsonl.gz) \
+      <(gunzip -c ./libritts_cuts_train-other-500.jsonl.gz) | \
+      shuf | gzip -c > ./libritts_cuts_train-all-shuf.jsonl.gz
   fi

   if [ ! -e data/fbank/.libritts-validated.done ]; then

View File

@@ -5,9 +5,18 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torchaudio
 from einops import rearrange
-from utils import get_2d_padding, get_padding
-from ..modules import NormConv1d, NormConv2d
+from modules.conv import NormConv1d, NormConv2d
+
+
+def get_padding(kernel_size, dilation=1) -> int:
+    return int((kernel_size * dilation - dilation) / 2)
+
+
+def get_2d_padding(kernel_size: Tuple[int, int], dilation: Tuple[int, int] = (1, 1)):
+    return (
+        ((kernel_size[0] - 1) * dilation[0]) // 2,
+        ((kernel_size[1] - 1) * dilation[1]) // 2,
+    )


 class DiscriminatorP(nn.Module):
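The two helpers moved in from the deleted `utils.py` compute "same"-style padding for stride-1 (dilated) convolutions. Worked examples below, with the definitions repeated so the snippet runs standalone; note that the moved `get_2d_padding` still needs `Tuple` in scope, which the deleted `utils.py` imported from `typing`:

```python
from typing import Tuple

def get_padding(kernel_size, dilation=1) -> int:
    return int((kernel_size * dilation - dilation) / 2)

def get_2d_padding(kernel_size: Tuple[int, int], dilation: Tuple[int, int] = (1, 1)):
    return (
        ((kernel_size[0] - 1) * dilation[0]) // 2,
        ((kernel_size[1] - 1) * dilation[1]) // 2,
    )

assert get_padding(5) == 2               # (5*1 - 1) // 2
assert get_padding(7, dilation=3) == 9   # (7*3 - 3) // 2
assert get_2d_padding((3, 9)) == (1, 4)
assert get_2d_padding((3, 9), dilation=(1, 2)) == (1, 8)
```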

View File

@@ -80,6 +80,13 @@ class LibriTTSCodecDataModule:
             "augmentations, etc.",
         )

+        group.add_argument(
+            "--full-libri",
+            type=str2bool,
+            default=True,
+            help="""When enabled, use the entire LibriTTS training set.
+            Otherwise, use the clean-100 subset.""",
+        )
         group.add_argument(
             "--manifest-dir",
             type=Path,
@@ -210,8 +217,8 @@ class LibriTTSCodecDataModule:
             validate = SpeechSynthesisDataset(
                 return_text=False,
-                return_tokens=True,
-                return_spk_ids=True,
+                return_tokens=False,
+                return_spk_ids=False,
                 feature_input_strategy=eval(self.args.input_strategy)(),
                 return_cuts=self.args.return_cuts,
             )
@@ -236,8 +243,8 @@ class LibriTTSCodecDataModule:
             test = SpeechSynthesisDataset(
                 return_text=False,
-                return_tokens=True,
-                return_spk_ids=True,
+                return_tokens=False,
+                return_spk_ids=False,
                 feature_input_strategy=eval(self.args.input_strategy)(),
                 return_cuts=self.args.return_cuts,
             )
@@ -256,16 +263,60 @@ class LibriTTSCodecDataModule:
         return test_dl

     @lru_cache()
-    def train_cuts(self) -> CutSet:
-        logging.info("About to get train cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_train.jsonl.gz")
+    def train_clean_100_cuts(self) -> CutSet:
+        logging.info("About to get train-clean-100 cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "libritts_cuts_train-clean-100.jsonl.gz"
+        )

     @lru_cache()
-    def valid_cuts(self) -> CutSet:
-        logging.info("About to get validation cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_valid.jsonl.gz")
+    def train_clean_360_cuts(self) -> CutSet:
+        logging.info("About to get train-clean-360 cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "libritts_cuts_train-clean-360.jsonl.gz"
+        )

     @lru_cache()
-    def test_cuts(self) -> CutSet:
-        logging.info("About to get test cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_test.jsonl.gz")
+    def train_other_500_cuts(self) -> CutSet:
+        logging.info("About to get train-other-500 cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "libritts_cuts_train-other-500.jsonl.gz"
+        )
+
+    @lru_cache()
+    def train_all_shuf_cuts(self) -> CutSet:
+        logging.info(
+            "About to get the shuffled train-clean-100, "
+            "train-clean-360 and train-other-500 cuts"
+        )
+        return load_manifest_lazy(
+            self.args.manifest_dir / "libritts_cuts_train-all-shuf.jsonl.gz"
+        )
+
+    @lru_cache()
+    def dev_clean_cuts(self) -> CutSet:
+        logging.info("About to get dev-clean cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "libritts_cuts_dev-clean.jsonl.gz"
+        )
+
+    @lru_cache()
+    def dev_other_cuts(self) -> CutSet:
+        logging.info("About to get dev-other cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "libritts_cuts_dev-other.jsonl.gz"
+        )
+
+    @lru_cache()
+    def test_clean_cuts(self) -> CutSet:
+        logging.info("About to get test-clean cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "libritts_cuts_test-clean.jsonl.gz"
+        )
+
+    @lru_cache()
+    def test_other_cuts(self) -> CutSet:
+        logging.info("About to get test-other cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "libritts_cuts_test-other.jsonl.gz"
+        )
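A sketch of how these accessors pair with the new `--full-libri` flag; this mirrors the selection logic added to `train.py` later in this commit (`args` is assumed to have been parsed with the data module's argument group):

```python
libritts = LibriTTSCodecDataModule(args)

if args.full_libri:
    # train-clean-100 + train-clean-360 + train-other-500, pre-shuffled by
    # prepare.sh into libritts_cuts_train-all-shuf.jsonl.gz
    train_cuts = libritts.train_all_shuf_cuts()
else:
    train_cuts = libritts.train_clean_100_cuts()

valid_cuts = libritts.dev_clean_cuts()
```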

View File

@@ -1,8 +1,8 @@
-from typing import List, Tuple
+from typing import List

 import torch
 import torch.nn as nn
-from models.discriminators import DiscriminatorP, DiscriminatorS, DiscriminatorSTFT
+from base_discriminators import DiscriminatorP, DiscriminatorS, DiscriminatorSTFT
 from torch.nn import AvgPool1d

@@ -81,7 +81,7 @@ class MultiScaleSTFTDiscriminator(nn.Module):
     def __init__(
         self,
-        filters: int,
+        n_filters: int,
         in_channels: int = 1,
         out_channels: int = 1,
         n_ffts: List[int] = [1024, 2048, 512, 256, 128],
@@ -94,7 +94,7 @@ class MultiScaleSTFTDiscriminator(nn.Module):
         self.discriminators = nn.ModuleList(
             [
                 DiscriminatorSTFT(
-                    filters,
+                    n_filters,
                    in_channels=in_channels,
                     out_channels=out_channels,
                     n_fft=n_ffts[i],
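With `filters` renamed to `n_filters`, the discriminator is now constructed by keyword at its call site in `train.py` (shown further down). A hedged example, assuming the comprehension above builds one `DiscriminatorSTFT` per entry of the default `n_ffts` list:

```python
msd = MultiScaleSTFTDiscriminator(n_filters=32)
# One sub-discriminator per STFT resolution: [1024, 2048, 512, 256, 128]
assert len(msd.discriminators) == 5
```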

View File

@@ -12,7 +12,7 @@ from torch.cuda.amp import autocast
 class Encodec(nn.Module):
     def __init__(
         self,
-        sample_rate: int,
+        sampling_rate: int,
         target_bandwidths: List[float],
         params: dict,
         encoder: nn.Module,
@@ -21,21 +21,21 @@ class Encodec(nn.Module):
         multi_scale_discriminator: nn.Module,
         multi_period_discriminator: nn.Module,
         multi_scale_stft_discriminator: nn.Module,
-        cache_generator_outputs: bool = True,
+        cache_generator_outputs: bool = False,
     ):
         super(Encodec, self).__init__()

         self.params = params

         # setup the generator
-        self.sample_rate = sample_rate
+        self.sampling_rate = sampling_rate
         self.encoder = encoder
         self.quantizer = quantizer
         self.decoder = decoder
         self.ratios = encoder.ratios
         self.hop_length = np.prod(self.ratios)
-        self.frame_rate = math.ceil(self.sample_rate / np.prod(self.ratios))
+        self.frame_rate = math.ceil(self.sampling_rate / np.prod(self.ratios))
         self.target_bandwidths = target_bandwidths

         # discriminators
@@ -133,10 +133,10 @@ class Encodec(nn.Module):
         if return_sample:
             stats["returned_sample"] = (
-                speech_hat[0].data.cpu().numpy(),
-                speech[0].data.cpu().numpy(),
-                fmap_hat[0][0].data.cpu().numpy(),
-                fmap[0][0].data.cpu().numpy(),
+                speech_hat.cpu(),
+                speech.cpu(),
+                fmap_hat[0][0].data.cpu(),
+                fmap[0][0].data.cpu(),
             )

         # reset cache
@@ -259,3 +259,11 @@ class Encodec(nn.Module):
         quantized = self.quantizer.decode(codes)
         o = self.decoder(quantized)
         return o
+
+    def inference(self, x, target_bw=None, st=None):
+        # setup
+        x = x.unsqueeze(1)
+        codes = self.encode(x, target_bw, st)
+        o = self.decode(codes)
+        return o
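A hedged usage sketch for the new `inference` helper: `model` is assumed to be an `Encodec` built by `get_model` in `train.py`, and `target_bw=7.5` matches the `inference_params` added there. `x` is a batch of raw waveforms of shape `(B, T)`; `unsqueeze(1)` makes it `(B, 1, T)` for the SEANet encoder:

```python
import torch

model.eval()
with torch.no_grad():
    audio = torch.randn(1, 24000)  # one second of (dummy) 24 kHz audio
    audio_hat = model.inference(x=audio, target_bw=7.5)
print(audio_hat.shape)  # (1, 1, T'): the reconstructed waveform
```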

View File

@@ -59,9 +59,9 @@ def sim_loss(y_disc_r, y_disc_gen):

 # return torch.sum(loss) / x.shape[0]


-def reconstruction_loss(x, G_x, args, eps=1e-7):
+def reconstruction_loss(x, x_hat, args, eps=1e-7):
     # NOTE (lsx): hard-coded now
-    L = args.lambda_wav * F.mse_loss(x, G_x)  # wav L1 loss
+    L = args.lambda_wav * F.mse_loss(x, x_hat)  # wav L1 loss
     # loss_sisnr = sisnr_loss(G_x, x) #
     # L += 0.01*loss_sisnr
     # 2^6=64 -> 2^10=1024
@@ -70,15 +70,15 @@ def reconstruction_loss(x, G_x, args, eps=1e-7):
     # for i in range(5, 12):  # Encodec setting
         s = 2**i
         melspec = MelSpectrogram(
-            sample_rate=args.sr,
+            sample_rate=args.sampling_rate,
             n_fft=max(s, 512),
             win_length=s,
             hop_length=s // 4,
             n_mels=64,
-            wkwargs={"device": args.device},
-        ).to(args.device)
+            wkwargs={"device": x_hat.device},
+        ).to(x_hat.device)
         S_x = melspec(x)
-        S_G_x = melspec(G_x)
+        S_G_x = melspec(x_hat)
         l1_loss = (S_x - S_G_x).abs().mean()
         l2_loss = (
             ((torch.log(S_x.abs() + eps) - torch.log(S_G_x.abs() + eps)) ** 2).mean(
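Two notes on the hunk above: despite the `# wav L1 loss` comment, `F.mse_loss` is a squared-error (L2) objective, and tying the mel transforms to `x_hat.device` removes the dependence on `args.device`. For reference, a standalone sketch of the multi-scale mel term, with the window ladder `2**6=64 … 2**10=1024` taken from the comment; this is simplified (constants hard-coded, no `wkwargs`, plain L1 + log-L2 averaging), not the recipe's exact implementation:

```python
import torch
from torchaudio.transforms import MelSpectrogram

def multi_scale_mel_loss(x, x_hat, sampling_rate=24000, eps=1e-7):
    loss = 0.0
    for i in range(6, 11):  # windows of 64, 128, 256, 512, 1024 samples
        s = 2**i
        melspec = MelSpectrogram(
            sample_rate=sampling_rate,
            n_fft=max(s, 512),
            win_length=s,
            hop_length=s // 4,
            n_mels=64,
        ).to(x_hat.device)
        S_x, S_x_hat = melspec(x), melspec(x_hat)
        l1 = (S_x - S_x_hat).abs().mean()
        l2 = ((torch.log(S_x.abs() + eps) - torch.log(S_x_hat.abs() + eps)) ** 2).mean()
        loss = loss + l1 + l2
    return loss
```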

View File

@@ -1,12 +0,0 @@
-from typing import Tuple
-
-
-def get_padding(kernel_size, dilation=1) -> int:
-    return int((kernel_size * dilation - dilation) / 2)
-
-
-def get_2d_padding(kernel_size: Tuple[int, int], dilation: Tuple[int, int] = (1, 1)):
-    return (
-        ((kernel_size[0] - 1) * dilation[0]) // 2,
-        ((kernel_size[1] - 1) * dilation[1]) // 2,
-    )

View File

@@ -2,6 +2,7 @@ import argparse
 import itertools
 import logging
 import math
+import random
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, Optional, Tuple, Union
@@ -10,6 +11,7 @@ import numpy as np
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
+from codec_datamodule import LibriTTSCodecDataModule
 from encodec import Encodec
 from lhotse.cut import Cut
 from lhotse.utils import fix_random_seed
@@ -76,7 +78,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="vits/exp",
+        default="encodec/exp",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
@@ -127,6 +129,12 @@ def get_parser():
         default=False,
         help="Whether to use half precision training.",
     )

+    parser.add_argument(
+        "--chunk-size",
+        type=int,
+        default=1,
+        help="The chunk size for the dataset (in seconds).",
+    )
+
     return parser
@@ -249,23 +257,32 @@ def get_model(params: AttributeDict) -> nn.Module:
     }
     discriminator_params = {
         "stft_discriminator_n_filters": 32,
+        "discriminator_iter_start": 500,
+    }
+    inference_params = {
+        "target_bw": 7.5,
     }

     params.update(generator_params)
     params.update(discriminator_params)
+    params.update(inference_params)

     hop_length = np.prod(params.ratios)
     n_q = int(
         1000
         * params.target_bandwidths[-1]
-        // (math.ceil(params.sample_rate / hop_length) * 10)
+        // (math.ceil(params.sampling_rate / hop_length) * 10)
     )

     encoder = SEANetEncoder(
-        n_filters=params.n_filters, dimension=params.dimension, ratios=params.ratios
+        n_filters=params.generator_n_filters,
+        dimension=params.dimension,
+        ratios=params.ratios,
     )
     decoder = SEANetDecoder(
-        n_filters=params.n_filters, dimension=params.dimension, ratios=params.ratios
+        n_filters=params.generator_n_filters,
+        dimension=params.dimension,
+        ratios=params.ratios,
     )
     quantizer = ResidualVectorQuantizer(
         dimension=params.dimension, n_q=n_q, bins=params.bins
@@ -273,21 +290,25 @@ def get_model(params: AttributeDict) -> nn.Module:

     model = Encodec(
         params=params,
-        sample_rate=params.sampling_rate,
+        sampling_rate=params.sampling_rate,
         target_bandwidths=params.target_bandwidths,
         encoder=encoder,
         quantizer=quantizer,
         decoder=decoder,
         multi_scale_discriminator=MultiScaleDiscriminator(),
         multi_period_discriminator=MultiPeriodDiscriminator(),
-        multi_scale_stft_discriminator=MultiScaleSTFTDiscriminator(),
+        multi_scale_stft_discriminator=MultiScaleSTFTDiscriminator(
+            n_filters=params.stft_discriminator_n_filters
+        ),
     )
     return model
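A worked example of the `n_q` computation above, under assumed EnCodec-style defaults that are not shown in this diff (24 kHz audio, `ratios = [8, 5, 4, 2]`, top bandwidth 24 kbps, 1024-bin codebooks):

```python
import math
import numpy as np

sampling_rate = 24000
ratios = [8, 5, 4, 2]
target_bandwidths = [1.5, 3.0, 6.0, 12.0, 24.0]

hop_length = np.prod(ratios)                                  # 320 samples
frame_rate = math.ceil(sampling_rate / hop_length)            # 75 frames/s
n_q = int(1000 * target_bandwidths[-1] // (frame_rate * 10))  # 32 quantizers
# Each 1024-bin RVQ codebook costs 10 bits per frame, so
# 32 quantizers * 10 bits * 75 frames/s = 24000 bit/s = 24 kbps.
```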
 def prepare_input(
+    params: AttributeDict,
     batch: dict,
     device: torch.device,
+    is_training: bool = True,
 ):
     """Parse batch data"""
     audio = batch["audio"].to(device, memory_format=torch.contiguous_format)
@@ -295,6 +316,18 @@ def prepare_input(
     audio_lens = batch["audio_lens"].to(device)
     features_lens = batch["features_lens"].to(device)

+    if is_training:
+        audio_dims = audio.size(-1)
+        start_idx = random.randint(
+            0, max(0, audio_dims - params.chunk_size * params.sampling_rate)
+        )
+        audio = audio[:, start_idx : params.sampling_rate + start_idx]
+    else:
+        # NOTE: a very coarse setup
+        audio = audio[
+            :, params.sampling_rate : params.sampling_rate + params.sampling_rate
+        ]
+
     return audio, audio_lens, features, features_lens
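What the new training-time branch does, as a self-contained sketch: pick a random start index so that a chunk fits, then slice. Note the slice width is `sampling_rate` samples (one second), which equals `chunk_size` seconds only at the default `chunk_size=1`:

```python
import random
import torch

sampling_rate, chunk_size = 24000, 1
audio = torch.randn(8, 3 * sampling_rate)  # a batch of 3-second clips

start_idx = random.randint(
    0, max(0, audio.size(-1) - chunk_size * sampling_rate)
)
chunk = audio[:, start_idx : sampling_rate + start_idx]
print(chunk.shape)  # torch.Size([8, 24000])
```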
@@ -371,13 +404,13 @@ def train_one_epoch(
     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1

-        batch_size = len(batch["tokens"])
+        batch_size = len(batch["audio"])
         (
             audio,
             audio_lens,
             _,
             _,
-        ) = prepare_input(batch, device)
+        ) = prepare_input(params, batch, device)

         loss_info = MetricsTracker()
         loss_info["samples"] = batch_size
@@ -476,31 +509,38 @@ def train_one_epoch(
                     "train/grad_scale", cur_grad_scale, params.batch_idx_train
                 )
             if "returned_sample" in stats_g:
-                speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"]
+                # speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"]
+                speech_hat_, speech_, _, _ = stats_g["returned_sample"]
+                speech_hat_i = speech_hat_[0]
+                speech_i = speech_[0]
+                if speech_hat_i.dim() > 1:
+                    speech_hat_i = speech_hat_i.squeeze(0)
+                    speech_i = speech_i.squeeze(0)
                 tb_writer.add_audio(
-                    "train/speech_hat_",
-                    speech_hat_,
+                    f"train/speech_hat_",
+                    speech_hat_i,
                     params.batch_idx_train,
                     params.sampling_rate,
                 )
                 tb_writer.add_audio(
-                    "train/speech_",
-                    speech_,
+                    f"train/speech_",
+                    speech_i,
                     params.batch_idx_train,
                     params.sampling_rate,
                 )
-                tb_writer.add_image(
-                    "train/mel_hat_",
-                    plot_feature(mel_hat_),
-                    params.batch_idx_train,
-                    dataformats="HWC",
-                )
-                tb_writer.add_image(
-                    "train/mel_",
-                    plot_feature(mel_),
-                    params.batch_idx_train,
-                    dataformats="HWC",
-                )
+                # tb_writer.add_image(
+                #     "train/mel_hat_",
+                #     plot_feature(mel_hat_),
+                #     params.batch_idx_train,
+                #     dataformats="HWC",
+                # )
+                # tb_writer.add_image(
+                #     "train/mel_",
+                #     plot_feature(mel_),
+                #     params.batch_idx_train,
+                #     dataformats="HWC",
+                # )
             if (
                 params.batch_idx_train % params.valid_interval == 0
@@ -522,15 +562,20 @@ def train_one_epoch(
                     valid_info.write_summary(
                         tb_writer, "train/valid_", params.batch_idx_train
                     )
+                    speech_hat_i = speech_hat[0]
+                    speech_i = speech[0]
+                    if speech_hat_i.dim() > 1:
+                        speech_hat_i = speech_hat_i.squeeze(0)
+                        speech_i = speech_i.squeeze(0)
                     tb_writer.add_audio(
                         "train/valdi_speech_hat",
-                        speech_hat,
+                        speech_hat_i,
                         params.batch_idx_train,
                         params.sampling_rate,
                     )
                     tb_writer.add_audio(
                         "train/valdi_speech",
-                        speech,
+                        speech_i,
                         params.batch_idx_train,
                         params.sampling_rate,
                     )
@@ -559,13 +604,13 @@ def compute_validation_loss(
     with torch.no_grad():
         for batch_idx, batch in enumerate(valid_dl):
-            batch_size = len(batch["tokens"])
+            batch_size = len(batch["audio"])
             (
                 audio,
                 audio_lens,
                 _,
                 _,
-            ) = prepare_input(batch, device)
+            ) = prepare_input(params, batch, device, is_training=False)

             loss_info = MetricsTracker()
             loss_info["samples"] = batch_size
@@ -588,7 +633,7 @@ def compute_validation_loss(
                 speech_lengths=audio_lens,
                 global_step=params.batch_idx_train,
                 forward_generator=True,
-                return_sample=batch_idx == 0,
+                return_sample=False,
             )
             assert loss_g.requires_grad is False
             for k, v in stats_g.items():
@@ -599,9 +644,9 @@ def compute_validation_loss(

             # infer for first batch:
             if batch_idx == 0 and rank == 0:
-                speech_hat_, speech_, _, _ = stats_g["returned_sample"]
-
-                returned_sample = (speech_hat_, speech_)
+                inner_model = model.module if isinstance(model, DDP) else model
+                audio_pred = inner_model.inference(x=audio, target_bw=params.target_bw)
+                returned_sample = (audio_pred, audio)

     if world_size > 1:
         tot_loss.reduce(device)
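Why validation unwraps the model first: `DistributedDataParallel` only proxies `forward`, so a custom method such as the new `inference` has to be called on the wrapped module:

```python
from torch.nn.parallel import DistributedDataParallel as DDP

inner_model = model.module if isinstance(model, DDP) else model
audio_pred = inner_model.inference(x=audio, target_bw=params.target_bw)
```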
@@ -635,7 +680,7 @@ def scan_pessimistic_batches_for_oom(
             audio_lens,
             _,
             _,
-        ) = prepare_input(batch, device)
+        ) = prepare_input(params, batch, device)
         try:
             # for discriminator
             with autocast(enabled=params.use_fp16):
@@ -706,9 +751,12 @@ def run(rank, world_size, args):
         device = torch.device("cuda", rank)
     logging.info(f"Device: {device}")

-    vctk = VctkTtsDataModule(args)
+    libritts = LibriTTSCodecDataModule(args)

-    train_cuts = vctk.train_cuts()
+    if params.full_libri:
+        train_cuts = libritts.train_all_shuf_cuts()
+    else:
+        train_cuts = libritts.train_clean_100_cuts()

     logging.info(params)
@@ -798,19 +846,19 @@ def run(rank, world_size, args):
     if params.inf_check:
         register_inf_check_hooks(model)

-    train_dl = vctk.train_dataloaders(train_cuts)
+    train_dl = libritts.train_dataloaders(train_cuts)

-    valid_cuts = vctk.valid_cuts()
-    valid_dl = vctk.valid_dataloaders(valid_cuts)
+    valid_cuts = libritts.dev_clean_cuts()
+    valid_dl = libritts.valid_dataloaders(valid_cuts)

-    if not params.print_diagnostics:
-        scan_pessimistic_batches_for_oom(
-            model=model,
-            train_dl=train_dl,
-            optimizer_g=optimizer_g,
-            optimizer_d=optimizer_d,
-            params=params,
-        )
+    # if not params.print_diagnostics:
+    #     scan_pessimistic_batches_for_oom(
+    #         model=model,
+    #         train_dl=train_dl,
+    #         optimizer_g=optimizer_g,
+    #         optimizer_d=optimizer_d,
+    #         params=params,
+    #     )

     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
@@ -883,7 +931,7 @@ def run(rank, world_size, args):
 def main():
     parser = get_parser()
-    VctkTtsDataModule.add_arguments(parser)
+    LibriTTSCodecDataModule.add_arguments(parser)
     args = parser.parse_args()

     args.exp_dir = Path(args.exp_dir)