diff --git a/egs/vctk/TTS/local/compute_spectrogram_vctk.py b/egs/vctk/TTS/local/compute_spectrogram_vctk.py index 0627281c3..440ac1245 100755 --- a/egs/vctk/TTS/local/compute_spectrogram_vctk.py +++ b/egs/vctk/TTS/local/compute_spectrogram_vctk.py @@ -32,9 +32,9 @@ from pathlib import Path import torch from lhotse import ( CutSet, + LilcomChunkyWriter, Spectrogram, SpectrogramConfig, - LilcomChunkyWriter, load_manifest, ) from lhotse.audio import RecordingSet diff --git a/egs/vctk/TTS/vits/duration_predictor.py b/egs/vctk/TTS/vits/duration_predictor.py index c29a28479..1a8190014 100644 --- a/egs/vctk/TTS/vits/duration_predictor.py +++ b/egs/vctk/TTS/vits/duration_predictor.py @@ -14,7 +14,6 @@ from typing import Optional import torch import torch.nn.functional as F - from flow import ( ConvFlow, DilatedDepthSeparableConv, diff --git a/egs/vctk/TTS/vits/flow.py b/egs/vctk/TTS/vits/flow.py index 206bd5e3e..2b84f6434 100644 --- a/egs/vctk/TTS/vits/flow.py +++ b/egs/vctk/TTS/vits/flow.py @@ -13,7 +13,6 @@ import math from typing import Optional, Tuple, Union import torch - from transform import piecewise_rational_quadratic_transform diff --git a/egs/vctk/TTS/vits/generator.py b/egs/vctk/TTS/vits/generator.py index 664d8064f..634b2061a 100644 --- a/egs/vctk/TTS/vits/generator.py +++ b/egs/vctk/TTS/vits/generator.py @@ -16,9 +16,6 @@ from typing import List, Optional, Tuple import numpy as np import torch import torch.nn.functional as F - -from icefall.utils import make_pad_mask - from duration_predictor import StochasticDurationPredictor from hifigan import HiFiGANGenerator from posterior_encoder import PosteriorEncoder @@ -26,6 +23,8 @@ from residual_coupling import ResidualAffineCouplingBlock from text_encoder import TextEncoder from utils import get_random_segments +from icefall.utils import make_pad_mask + class VITSGenerator(torch.nn.Module): """Generator module in VITS, `Conditional Variational Autoencoder diff --git a/egs/vctk/TTS/vits/infer.py b/egs/vctk/TTS/vits/infer.py index 9bc614ad2..f01c5bbc4 100755 --- a/egs/vctk/TTS/vits/infer.py +++ b/egs/vctk/TTS/vits/infer.py @@ -36,13 +36,12 @@ import k2 import torch import torch.nn as nn import torchaudio - -from train import get_model, get_params from tokenizer import Tokenizer +from train import get_model, get_params +from tts_datamodule import LJSpeechTtsDataModule from icefall.checkpoint import load_checkpoint from icefall.utils import AttributeDict, setup_logger -from tts_datamodule import LJSpeechTtsDataModule def get_parser(): diff --git a/egs/vctk/TTS/vits/loss.py b/egs/vctk/TTS/vits/loss.py index 21aaad6e7..2f4dc9bc0 100644 --- a/egs/vctk/TTS/vits/loss.py +++ b/egs/vctk/TTS/vits/loss.py @@ -14,7 +14,6 @@ from typing import List, Tuple, Union import torch import torch.distributions as D import torch.nn.functional as F - from lhotse.features.kaldi import Wav2LogFilterBank diff --git a/egs/vctk/TTS/vits/posterior_encoder.py b/egs/vctk/TTS/vits/posterior_encoder.py index 6b8a5be52..1104fb864 100644 --- a/egs/vctk/TTS/vits/posterior_encoder.py +++ b/egs/vctk/TTS/vits/posterior_encoder.py @@ -12,9 +12,9 @@ This code is based on https://github.com/jaywalnut310/vits. from typing import Optional, Tuple import torch +from wavenet import Conv1d, WaveNet from icefall.utils import make_pad_mask -from wavenet import WaveNet, Conv1d class PosteriorEncoder(torch.nn.Module): diff --git a/egs/vctk/TTS/vits/residual_coupling.py b/egs/vctk/TTS/vits/residual_coupling.py index 2d6807cb7..f9a2a3786 100644 --- a/egs/vctk/TTS/vits/residual_coupling.py +++ b/egs/vctk/TTS/vits/residual_coupling.py @@ -12,7 +12,6 @@ This code is based on https://github.com/jaywalnut310/vits. from typing import Optional, Tuple, Union import torch - from flow import FlipFlow from wavenet import WaveNet diff --git a/egs/vctk/TTS/vits/train.py b/egs/vctk/TTS/vits/train.py index d05b7c668..0c6ca1b4d 100755 --- a/egs/vctk/TTS/vits/train.py +++ b/egs/vctk/TTS/vits/train.py @@ -18,21 +18,25 @@ import argparse import logging -import numpy as np from pathlib import Path from shutil import copyfile from typing import Any, Dict, Optional, Tuple, Union import k2 +import numpy as np import torch import torch.multiprocessing as mp import torch.nn as nn from lhotse.cut import Cut from lhotse.utils import fix_random_seed -from torch.optim import Optimizer +from tokenizer import Tokenizer from torch.cuda.amp import GradScaler, autocast from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Optimizer from torch.utils.tensorboard import SummaryWriter +from tts_datamodule import LJSpeechTtsDataModule +from utils import MetricsTracker, plot_feature, save_checkpoint +from vits import VITS from icefall import diagnostics from icefall.checkpoint import load_checkpoint @@ -41,11 +45,6 @@ from icefall.env import get_env_info from icefall.hooks import register_inf_check_hooks from icefall.utils import AttributeDict, setup_logger, str2bool -from tokenizer import Tokenizer -from tts_datamodule import LJSpeechTtsDataModule -from utils import MetricsTracker, plot_feature, save_checkpoint -from vits import VITS - LRSchedulerType = torch.optim.lr_scheduler._LRScheduler @@ -296,6 +295,7 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device): audio_lens = batch["audio_lens"].to(device) features_lens = batch["features_lens"].to(device) text = batch["text"] + speakers = batch["speakers"] tokens = tokenizer.texts_to_token_ids(text) tokens = k2.RaggedTensor(tokens) @@ -306,7 +306,7 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device): # a tensor of shape (B, T) tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id) - return audio, audio_lens, features, features_lens, tokens, tokens_lens + return audio, audio_lens, features, features_lens, tokens, tokens_lens, speakers def train_one_epoch( @@ -385,9 +385,15 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["text"]) - audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input( - batch, tokenizer, device - ) + ( + audio, + audio_lens, + features, + features_lens, + tokens, + tokens_lens, + speakers, + ) = prepare_input(batch, tokenizer, device) loss_info = MetricsTracker() loss_info["samples"] = batch_size diff --git a/egs/vctk/TTS/vits/tts_datamodule.py b/egs/vctk/TTS/vits/tts_datamodule.py index 40e9c19dd..d2064c5e3 100644 --- a/egs/vctk/TTS/vits/tts_datamodule.py +++ b/egs/vctk/TTS/vits/tts_datamodule.py @@ -29,10 +29,10 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures CutConcatenate, CutMix, DynamicBucketingSampler, - SpeechSynthesisDataset, PrecomputedFeatures, SimpleCutSampler, SpecAugment, + SpeechSynthesisDataset, ) from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples AudioSamples, diff --git a/egs/vctk/TTS/vits/utils.py b/egs/vctk/TTS/vits/utils.py index 12b2d6b81..6a067f596 100644 --- a/egs/vctk/TTS/vits/utils.py +++ b/egs/vctk/TTS/vits/utils.py @@ -14,15 +14,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Tuple, Union import collections import logging +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union import torch -import torch.nn as nn import torch.distributed as dist +import torch.nn as nn from lhotse.dataset.sampling.base import CutSampler -from pathlib import Path from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer diff --git a/egs/vctk/TTS/vits/vits.py b/egs/vctk/TTS/vits/vits.py index c8bb38f30..2c38d5d37 100644 --- a/egs/vctk/TTS/vits/vits.py +++ b/egs/vctk/TTS/vits/vits.py @@ -9,8 +9,7 @@ from typing import Any, Dict, Optional, Tuple import torch import torch.nn as nn -from torch.cuda.amp import autocast - +from generator import VITSGenerator from hifigan import ( HiFiGANMultiPeriodDiscriminator, HiFiGANMultiScaleDiscriminator, @@ -25,9 +24,8 @@ from loss import ( KLDivergenceLoss, MelSpectrogramLoss, ) +from torch.cuda.amp import autocast from utils import get_segments -from generator import VITSGenerator - AVAILABLE_GENERATERS = { "vits_generator": VITSGenerator, diff --git a/egs/vctk/TTS/vits/wavenet.py b/egs/vctk/TTS/vits/wavenet.py index fbe1be52b..5db461d5c 100644 --- a/egs/vctk/TTS/vits/wavenet.py +++ b/egs/vctk/TTS/vits/wavenet.py @@ -9,9 +9,8 @@ This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN. """ -import math import logging - +import math from typing import Optional, Tuple import torch