From 6e4a9ea85a50bfaf67cdf4f8e65ba73f07496991 Mon Sep 17 00:00:00 2001 From: JinZr Date: Thu, 5 Sep 2024 22:30:07 +0800 Subject: [PATCH] a little bit coarse commit --- .../ASR/local/compute_spectrogram_libritts.py | 102 +++++++++---- egs/libritts/ASR/prepare.sh | 11 +- ...scriminators.py => base_discriminators.py} | 13 +- .../CODEC/encodec/codec_datamodule.py | 77 ++++++++-- egs/libritts/CODEC/encodec/discriminators.py | 8 +- egs/libritts/CODEC/encodec/encodec.py | 24 ++- egs/libritts/CODEC/encodec/loss.py | 12 +- egs/libritts/CODEC/encodec/models/utils.py | 12 -- egs/libritts/CODEC/encodec/train.py | 144 ++++++++++++------ 9 files changed, 273 insertions(+), 130 deletions(-) rename egs/libritts/CODEC/encodec/{models/discriminators.py => base_discriminators.py} (95%) delete mode 100644 egs/libritts/CODEC/encodec/models/utils.py diff --git a/egs/libritts/ASR/local/compute_spectrogram_libritts.py b/egs/libritts/ASR/local/compute_spectrogram_libritts.py index 181353fdd..6cdc55bc8 100755 --- a/egs/libritts/ASR/local/compute_spectrogram_libritts.py +++ b/egs/libritts/ASR/local/compute_spectrogram_libritts.py @@ -25,19 +25,16 @@ It looks for manifests in the directory data/manifests. The generated fbank features are saved in data/spectrogram. """ +import argparse import logging import os from pathlib import Path +from typing import Optional import torch -from lhotse import ( - CutSet, - LilcomChunkyWriter, - Spectrogram, - SpectrogramConfig, - load_manifest, -) +from lhotse import CutSet, LilcomChunkyWriter, Spectrogram, SpectrogramConfig from lhotse.audio import RecordingSet +from lhotse.recipes.utils import read_manifests_if_cached from lhotse.supervision import SupervisionSet from icefall.utils import get_executor @@ -49,26 +46,62 @@ from icefall.utils import get_executor torch.set_num_threads(1) torch.set_num_interop_threads(1) +def get_args(): + parser = argparse.ArgumentParser() -def compute_spectrogram_libritts(): + parser.add_argument( + "--dataset", + type=str, + help="""Dataset parts to compute fbank. 
If None, we will use all""", + ) + parser.add_argument( + "--sampling-rate", + type=int, + default=24000, + help="""Sampling rate of the audio for computing fbank, the default value for LibriTTS is 24000, audio files will be resampled if a different sample rate is provided""", + ) + + return parser.parse_args() + + +def compute_spectrogram_libritts(dataset: Optional[str] = None, sampling_rate: int = 24000,): src_dir = Path("data/manifests") output_dir = Path("data/spectrogram") num_jobs = min(32, os.cpu_count()) - sampling_rate = 24000 + frame_length = 1024 / sampling_rate # (in second) frame_shift = 256 / sampling_rate # (in second) use_fft_mag = True prefix = "libritts" suffix = "jsonl.gz" - partition = "all" + if dataset is None: + dataset_parts = ( + "dev-clean", + "dev-other", + "test-clean", + "test-other", + "train-clean-100", + "train-clean-360", + "train-other-500", + ) + else: + dataset_parts = dataset.split(" ", -1) - recordings = load_manifest( - src_dir / f"{prefix}_recordings_{partition}.jsonl.gz", RecordingSet - ).resample(sampling_rate=sampling_rate) - supervisions = load_manifest( - src_dir / f"{prefix}_supervisions_{partition}.jsonl.gz", SupervisionSet + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=src_dir, + prefix=prefix, + suffix=suffix, + ) + assert manifests is not None + + assert len(manifests) == len(dataset_parts), ( + len(manifests), + len(dataset_parts), + list(manifests.keys()), + dataset_parts, ) config = SpectrogramConfig( @@ -80,24 +113,29 @@ def compute_spectrogram_libritts(): extractor = Spectrogram(config) with get_executor() as ex: # Initialize the executor only once. - cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" - if (output_dir / cuts_filename).is_file(): - logging.info(f"{partition} already exists - skipping.") - return - logging.info(f"Processing {partition}") - cut_set = CutSet.from_manifests( - recordings=recordings, supervisions=supervisions - ) + for partition, m in manifests.items(): + cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" + if (output_dir / cuts_filename).is_file(): + logging.info(f"{partition} already exists - skipping.") + return + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ) + if sampling_rate != 24000: + logging.info(f"Resampling audio to {sampling_rate}") + cut_set = cut_set.resample(sampling_rate) - cut_set = cut_set.compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir}/{prefix}_feats_{partition}", - # when an executor is specified, make more partitions - num_jobs=num_jobs if ex is None else 80, - executor=ex, - storage_type=LilcomChunkyWriter, - ) - cut_set.to_file(output_dir / cuts_filename) + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set.to_file(output_dir / cuts_filename) if __name__ == "__main__": diff --git a/egs/libritts/ASR/prepare.sh b/egs/libritts/ASR/prepare.sh index 77c3c3842..f3a78bdb8 100755 --- a/egs/libritts/ASR/prepare.sh +++ b/egs/libritts/ASR/prepare.sh @@ -8,6 +8,7 @@ set -eou pipefail stage=0 stop_stage=100 sampling_rate=24000 +nj=32 perturb_speed=true dl_dir=$PWD/download @@ -54,7 +55,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then # to $dl_dir/LibriTTS mkdir -p data/manifests if [ ! 
-e data/manifests/.libritts.done ]; then - lhotse prepare libritts $dl_dir/LibriTTS data/manifests + lhotse prepare libritts --num-jobs 32 $dl_dir/LibriTTS data/manifests touch data/manifests/.libritts.done fi fi @@ -84,10 +85,10 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then # Here we shuffle and combine the train-clean-100, train-clean-360 and # train-other-500 together to form the training set. if [ ! -f data/fbank/libritts_cuts_train-all-shuf.jsonl.gz ]; then - cat <(gunzip -c data/fbank/libritts_cuts_train-clean-100.jsonl.gz) \ - <(gunzip -c data/fbank/libritts_cuts_train-clean-360.jsonl.gz) \ - <(gunzip -c data/fbank/libritts_cuts_train-other-500.jsonl.gz) | \ - shuf | gzip -c > data/fbank/libritts_cuts_train-all-shuf.jsonl.gz + cat <(gunzip -c ./libritts_cuts_train-clean-100.jsonl.gz) \ + <(gunzip -c ./libritts_cuts_train-clean-360.jsonl.gz) \ + <(gunzip -c ./libritts_cuts_train-other-500.jsonl.gz) | \ + shuf | gzip -c > ./libritts_cuts_train-all-shuf.jsonl.gz fi if [ ! -e data/fbank/.libritts-validated.done ]; then diff --git a/egs/libritts/CODEC/encodec/models/discriminators.py b/egs/libritts/CODEC/encodec/base_discriminators.py similarity index 95% rename from egs/libritts/CODEC/encodec/models/discriminators.py rename to egs/libritts/CODEC/encodec/base_discriminators.py index 900349b55..e112436e5 100644 --- a/egs/libritts/CODEC/encodec/models/discriminators.py +++ b/egs/libritts/CODEC/encodec/base_discriminators.py @@ -5,9 +5,18 @@ import torch.nn as nn import torch.nn.functional as F import torchaudio from einops import rearrange -from utils import get_2d_padding, get_padding +from modules.conv import NormConv1d, NormConv2d -from ..modules import NormConv1d, NormConv2d + +def get_padding(kernel_size, dilation=1) -> int: + return int((kernel_size * dilation - dilation) / 2) + + +def get_2d_padding(kernel_size: Tuple[int, int], dilation: Tuple[int, int] = (1, 1)): + return ( + ((kernel_size[0] - 1) * dilation[0]) // 2, + ((kernel_size[1] - 1) * dilation[1]) // 2, + ) class DiscriminatorP(nn.Module): diff --git a/egs/libritts/CODEC/encodec/codec_datamodule.py b/egs/libritts/CODEC/encodec/codec_datamodule.py index 996569d21..b547e8513 100644 --- a/egs/libritts/CODEC/encodec/codec_datamodule.py +++ b/egs/libritts/CODEC/encodec/codec_datamodule.py @@ -80,6 +80,13 @@ class LibriTTSCodecDataModule: "augmentations, etc.", ) + group.add_argument( + "--full-libri", + type=str2bool, + default=True, + help="""When enabled, use the entire LibriTTS training set. 
+ Otherwise, use the clean-100 subset.""", + ) group.add_argument( "--manifest-dir", type=Path, @@ -210,8 +217,8 @@ class LibriTTSCodecDataModule: validate = SpeechSynthesisDataset( return_text=False, - return_tokens=True, - return_spk_ids=True, + return_tokens=False, + return_spk_ids=False, feature_input_strategy=eval(self.args.input_strategy)(), return_cuts=self.args.return_cuts, ) @@ -236,8 +243,8 @@ class LibriTTSCodecDataModule: test = SpeechSynthesisDataset( return_text=False, - return_tokens=True, - return_spk_ids=True, + return_tokens=False, + return_spk_ids=False, feature_input_strategy=eval(self.args.input_strategy)(), return_cuts=self.args.return_cuts, ) @@ -256,16 +263,60 @@ class LibriTTSCodecDataModule: return test_dl @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get train cuts") - return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_train.jsonl.gz") + def train_clean_100_cuts(self) -> CutSet: + logging.info("About to get train-clean-100 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_train-clean-100.jsonl.gz" + ) @lru_cache() - def valid_cuts(self) -> CutSet: - logging.info("About to get validation cuts") - return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_valid.jsonl.gz") + def train_clean_360_cuts(self) -> CutSet: + logging.info("About to get train-clean-360 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_train-clean-360.jsonl.gz" + ) @lru_cache() - def test_cuts(self) -> CutSet: - logging.info("About to get test cuts") - return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_test.jsonl.gz") + def train_other_500_cuts(self) -> CutSet: + logging.info("About to get train-other-500 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_train-other-500.jsonl.gz" + ) + + @lru_cache() + def train_all_shuf_cuts(self) -> CutSet: + logging.info( + "About to get the shuffled train-clean-100, \ + train-clean-360 and train-other-500 cuts" + ) + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_train-all-shuf.jsonl.gz" + ) + + @lru_cache() + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_dev-clean.jsonl.gz" + ) + + @lru_cache() + def dev_other_cuts(self) -> CutSet: + logging.info("About to get dev-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_dev-other.jsonl.gz" + ) + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_test-clean.jsonl.gz" + ) + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_test-other.jsonl.gz" + ) diff --git a/egs/libritts/CODEC/encodec/discriminators.py b/egs/libritts/CODEC/encodec/discriminators.py index 484f1ee43..471aa9244 100644 --- a/egs/libritts/CODEC/encodec/discriminators.py +++ b/egs/libritts/CODEC/encodec/discriminators.py @@ -1,8 +1,8 @@ -from typing import List, Tuple +from typing import List import torch import torch.nn as nn -from models.discriminators import DiscriminatorP, DiscriminatorS, DiscriminatorSTFT +from base_discriminators import DiscriminatorP, DiscriminatorS, DiscriminatorSTFT from torch.nn import AvgPool1d @@ -81,7 +81,7 @@ class MultiScaleSTFTDiscriminator(nn.Module): def __init__( self, - filters: int, 
+ n_filters: int, in_channels: int = 1, out_channels: int = 1, n_ffts: List[int] = [1024, 2048, 512, 256, 128], @@ -94,7 +94,7 @@ class MultiScaleSTFTDiscriminator(nn.Module): self.discriminators = nn.ModuleList( [ DiscriminatorSTFT( - filters, + n_filters, in_channels=in_channels, out_channels=out_channels, n_fft=n_ffts[i], diff --git a/egs/libritts/CODEC/encodec/encodec.py b/egs/libritts/CODEC/encodec/encodec.py index e7c5ad590..071dc19ba 100644 --- a/egs/libritts/CODEC/encodec/encodec.py +++ b/egs/libritts/CODEC/encodec/encodec.py @@ -12,7 +12,7 @@ from torch.cuda.amp import autocast class Encodec(nn.Module): def __init__( self, - sample_rate: int, + sampling_rate: int, target_bandwidths: List[float], params: dict, encoder: nn.Module, @@ -21,21 +21,21 @@ class Encodec(nn.Module): multi_scale_discriminator: nn.Module, multi_period_discriminator: nn.Module, multi_scale_stft_discriminator: nn.Module, - cache_generator_outputs: bool = True, + cache_generator_outputs: bool = False, ): super(Encodec, self).__init__() self.params = params # setup the generator - self.sample_rate = sample_rate + self.sampling_rate = sampling_rate self.encoder = encoder self.quantizer = quantizer self.decoder = decoder self.ratios = encoder.ratios self.hop_length = np.prod(self.ratios) - self.frame_rate = math.ceil(self.sample_rate / np.prod(self.ratios)) + self.frame_rate = math.ceil(self.sampling_rate / np.prod(self.ratios)) self.target_bandwidths = target_bandwidths # discriminators @@ -133,10 +133,10 @@ class Encodec(nn.Module): if return_sample: stats["returned_sample"] = ( - speech_hat[0].data.cpu().numpy(), - speech[0].data.cpu().numpy(), - fmap_hat[0][0].data.cpu().numpy(), - fmap[0][0].data.cpu().numpy(), + speech_hat.cpu(), + speech.cpu(), + fmap_hat[0][0].data.cpu(), + fmap[0][0].data.cpu(), ) # reset cache @@ -259,3 +259,11 @@ class Encodec(nn.Module): quantized = self.quantizer.decode(codes) o = self.decoder(quantized) return o + + def inference(self, x, target_bw=None, st=None): + # setup + x = x.unsqueeze(1) + + codes = self.encode(x, target_bw, st) + o = self.decode(codes) + return o diff --git a/egs/libritts/CODEC/encodec/loss.py b/egs/libritts/CODEC/encodec/loss.py index 1bb78f283..9ec80f536 100644 --- a/egs/libritts/CODEC/encodec/loss.py +++ b/egs/libritts/CODEC/encodec/loss.py @@ -59,9 +59,9 @@ def sim_loss(y_disc_r, y_disc_gen): # return torch.sum(loss) / x.shape[0] -def reconstruction_loss(x, G_x, args, eps=1e-7): +def reconstruction_loss(x, x_hat, args, eps=1e-7): # NOTE (lsx): hard-coded now - L = args.lambda_wav * F.mse_loss(x, G_x) # wav L1 loss + L = args.lambda_wav * F.mse_loss(x, x_hat) # wav L1 loss # loss_sisnr = sisnr_loss(G_x, x) # # L += 0.01*loss_sisnr # 2^6=64 -> 2^10=1024 @@ -70,15 +70,15 @@ def reconstruction_loss(x, G_x, args, eps=1e-7): # for i in range(5, 12): # Encodec setting s = 2**i melspec = MelSpectrogram( - sample_rate=args.sr, + sample_rate=args.sampling_rate, n_fft=max(s, 512), win_length=s, hop_length=s // 4, n_mels=64, - wkwargs={"device": args.device}, - ).to(args.device) + wkwargs={"device": x_hat.device}, + ).to(x_hat.device) S_x = melspec(x) - S_G_x = melspec(G_x) + S_G_x = melspec(x_hat) l1_loss = (S_x - S_G_x).abs().mean() l2_loss = ( ((torch.log(S_x.abs() + eps) - torch.log(S_G_x.abs() + eps)) ** 2).mean( diff --git a/egs/libritts/CODEC/encodec/models/utils.py b/egs/libritts/CODEC/encodec/models/utils.py deleted file mode 100644 index 2be73a312..000000000 --- a/egs/libritts/CODEC/encodec/models/utils.py +++ /dev/null @@ -1,12 +0,0 @@ -from typing import 
Tuple - - -def get_padding(kernel_size, dilation=1) -> int: - return int((kernel_size * dilation - dilation) / 2) - - -def get_2d_padding(kernel_size: Tuple[int, int], dilation: Tuple[int, int] = (1, 1)): - return ( - ((kernel_size[0] - 1) * dilation[0]) // 2, - ((kernel_size[1] - 1) * dilation[1]) // 2, - ) diff --git a/egs/libritts/CODEC/encodec/train.py b/egs/libritts/CODEC/encodec/train.py index 0d08a2e24..6057ba2ab 100644 --- a/egs/libritts/CODEC/encodec/train.py +++ b/egs/libritts/CODEC/encodec/train.py @@ -2,6 +2,7 @@ import argparse import itertools import logging import math +import random from pathlib import Path from shutil import copyfile from typing import Any, Dict, Optional, Tuple, Union @@ -10,6 +11,7 @@ import numpy as np import torch import torch.multiprocessing as mp import torch.nn as nn +from codec_datamodule import LibriTTSCodecDataModule from encodec import Encodec from lhotse.cut import Cut from lhotse.utils import fix_random_seed @@ -76,7 +78,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="vits/exp", + default="encodec/exp", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved @@ -127,6 +129,12 @@ def get_parser(): default=False, help="Whether to use half precision training.", ) + parser.add_argument( + "--chunk-size", + type=int, + default=1, + help="The chunk size for the dataset (in second).", + ) return parser @@ -249,23 +257,32 @@ def get_model(params: AttributeDict) -> nn.Module: } discriminator_params = { "stft_discriminator_n_filters": 32, + "discriminator_iter_start": 500, + } + inference_params = { + "target_bw": 7.5, } params.update(generator_params) params.update(discriminator_params) + params.update(inference_params) hop_length = np.prod(params.ratios) n_q = int( 1000 * params.target_bandwidths[-1] - // (math.ceil(params.sample_rate / hop_length) * 10) + // (math.ceil(params.sampling_rate / hop_length) * 10) ) encoder = SEANetEncoder( - n_filters=params.n_filters, dimension=params.dimension, ratios=params.ratios + n_filters=params.generator_n_filters, + dimension=params.dimension, + ratios=params.ratios, ) decoder = SEANetDecoder( - n_filters=params.n_filters, dimension=params.dimension, ratios=params.ratios + n_filters=params.generator_n_filters, + dimension=params.dimension, + ratios=params.ratios, ) quantizer = ResidualVectorQuantizer( dimension=params.dimension, n_q=n_q, bins=params.bins @@ -273,21 +290,25 @@ def get_model(params: AttributeDict) -> nn.Module: model = Encodec( params=params, - sample_rate=params.sampling_rate, + sampling_rate=params.sampling_rate, target_bandwidths=params.target_bandwidths, encoder=encoder, quantizer=quantizer, decoder=decoder, multi_scale_discriminator=MultiScaleDiscriminator(), multi_period_discriminator=MultiPeriodDiscriminator(), - multi_scale_stft_discriminator=MultiScaleSTFTDiscriminator(), + multi_scale_stft_discriminator=MultiScaleSTFTDiscriminator( + n_filters=params.stft_discriminator_n_filters + ), ) return model def prepare_input( + params: AttributeDict, batch: dict, device: torch.device, + is_training: bool = True, ): """Parse batch data""" audio = batch["audio"].to(device, memory_format=torch.contiguous_format) @@ -295,6 +316,18 @@ def prepare_input( audio_lens = batch["audio_lens"].to(device) features_lens = batch["features_lens"].to(device) + if is_training: + audio_dims = audio.size(-1) + start_idx = random.randint( + 0, max(0, audio_dims - params.chunk_size * params.sampling_rate) + ) + audio = 
audio[:, start_idx : params.sampling_rate + start_idx] + else: + # NOTE: a very coarse setup + audio = audio[ + :, params.sampling_rate : params.sampling_rate + params.sampling_rate + ] + return audio, audio_lens, features, features_lens @@ -371,13 +404,13 @@ def train_one_epoch( for batch_idx, batch in enumerate(train_dl): params.batch_idx_train += 1 - batch_size = len(batch["tokens"]) + batch_size = len(batch["audio"]) ( audio, audio_lens, _, _, - ) = prepare_input(batch, device) + ) = prepare_input(params, batch, device) loss_info = MetricsTracker() loss_info["samples"] = batch_size @@ -476,31 +509,38 @@ def train_one_epoch( "train/grad_scale", cur_grad_scale, params.batch_idx_train ) if "returned_sample" in stats_g: - speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"] + # speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"] + speech_hat_, speech_, _, _ = stats_g["returned_sample"] + + speech_hat_i = speech_hat_[0] + speech_i = speech_[0] + if speech_hat_i.dim() > 1: + speech_hat_i = speech_hat_i.squeeze(0) + speech_i = speech_i.squeeze(0) tb_writer.add_audio( - "train/speech_hat_", - speech_hat_, + f"train/speech_hat_", + speech_hat_i, params.batch_idx_train, params.sampling_rate, ) tb_writer.add_audio( - "train/speech_", - speech_, + f"train/speech_", + speech_i, params.batch_idx_train, params.sampling_rate, ) - tb_writer.add_image( - "train/mel_hat_", - plot_feature(mel_hat_), - params.batch_idx_train, - dataformats="HWC", - ) - tb_writer.add_image( - "train/mel_", - plot_feature(mel_), - params.batch_idx_train, - dataformats="HWC", - ) + # tb_writer.add_image( + # "train/mel_hat_", + # plot_feature(mel_hat_), + # params.batch_idx_train, + # dataformats="HWC", + # ) + # tb_writer.add_image( + # "train/mel_", + # plot_feature(mel_), + # params.batch_idx_train, + # dataformats="HWC", + # ) if ( params.batch_idx_train % params.valid_interval == 0 @@ -522,15 +562,20 @@ def train_one_epoch( valid_info.write_summary( tb_writer, "train/valid_", params.batch_idx_train ) + speech_hat_i = speech_hat[0] + speech_i = speech[0] + if speech_hat_i.dim() > 1: + speech_hat_i = speech_hat_i.squeeze(0) + speech_i = speech_i.squeeze(0) tb_writer.add_audio( "train/valdi_speech_hat", - speech_hat, + speech_hat_i, params.batch_idx_train, params.sampling_rate, ) tb_writer.add_audio( "train/valdi_speech", - speech, + speech_i, params.batch_idx_train, params.sampling_rate, ) @@ -559,13 +604,13 @@ def compute_validation_loss( with torch.no_grad(): for batch_idx, batch in enumerate(valid_dl): - batch_size = len(batch["tokens"]) + batch_size = len(batch["audio"]) ( audio, audio_lens, _, _, - ) = prepare_input(batch, device) + ) = prepare_input(params, batch, device, is_training=False) loss_info = MetricsTracker() loss_info["samples"] = batch_size @@ -588,7 +633,7 @@ def compute_validation_loss( speech_lengths=audio_lens, global_step=params.batch_idx_train, forward_generator=True, - return_sample=batch_idx == 0, + return_sample=False, ) assert loss_g.requires_grad is False for k, v in stats_g.items(): @@ -599,9 +644,9 @@ def compute_validation_loss( # infer for first batch: if batch_idx == 0 and rank == 0: - speech_hat_, speech_, _, _ = stats_g["returned_sample"] - - returned_sample = (speech_hat_, speech_) + inner_model = model.module if isinstance(model, DDP) else model + audio_pred = inner_model.inference(x=audio, target_bw=params.target_bw) + returned_sample = (audio_pred, audio) if world_size > 1: tot_loss.reduce(device) @@ -635,7 +680,7 @@ def scan_pessimistic_batches_for_oom( 
audio_lens, _, _, - ) = prepare_input(batch, device) + ) = prepare_input(params, batch, device) try: # for discriminator with autocast(enabled=params.use_fp16): @@ -706,9 +751,12 @@ def run(rank, world_size, args): device = torch.device("cuda", rank) logging.info(f"Device: {device}") - vctk = VctkTtsDataModule(args) + libritts = LibriTTSCodecDataModule(args) - train_cuts = vctk.train_cuts() + if params.full_libri: + train_cuts = libritts.train_all_shuf_cuts() + else: + train_cuts = libritts.train_clean_100_cuts() logging.info(params) @@ -798,19 +846,19 @@ def run(rank, world_size, args): if params.inf_check: register_inf_check_hooks(model) - train_dl = vctk.train_dataloaders(train_cuts) + train_dl = libritts.train_dataloaders(train_cuts) - valid_cuts = vctk.valid_cuts() - valid_dl = vctk.valid_dataloaders(valid_cuts) + valid_cuts = libritts.dev_clean_cuts() + valid_dl = libritts.valid_dataloaders(valid_cuts) - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer_g=optimizer_g, - optimizer_d=optimizer_d, - params=params, - ) + # if not params.print_diagnostics: + # scan_pessimistic_batches_for_oom( + # model=model, + # train_dl=train_dl, + # optimizer_g=optimizer_g, + # optimizer_d=optimizer_d, + # params=params, + # ) scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: @@ -883,7 +931,7 @@ def run(rank, world_size, args): def main(): parser = get_parser() - VctkTtsDataModule.add_arguments(parser) + LibriTTSCodecDataModule.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir)
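
Note: for reference, a possible invocation of the options this patch introduces. This is an illustrative sketch only; the working directories (egs/libritts/ASR and egs/libritts/CODEC), the chosen values, and the omitted flags are assumptions, not part of the patch itself.

  # from egs/libritts/ASR: compute spectrograms for selected parts only,
  # resampling to a non-default rate to exercise the new resample branch
  ./local/compute_spectrogram_libritts.py \
    --dataset "dev-clean test-clean" \
    --sampling-rate 16000

  # from egs/libritts/CODEC: train on train-clean-100 only (--full-libri 0),
  # cropping each training example to 1-second chunks; other required
  # flags (GPU setup, manifest dir, etc.) are omitted here
  ./encodec/train.py \
    --exp-dir encodec/exp \
    --full-libri 0 \
    --chunk-size 1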