isort formatted

jinzr 2023-11-10 11:04:46 +08:00
parent 269cc3b66a
commit 3c1b465d37
13 changed files with 30 additions and 33 deletions
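
The hunks below mostly reorder import statements (train.py additionally threads a speakers field through prepare_input). As a hedged aside, and not part of the commit itself: this ordering is what isort produces, via its CLI (typically "isort <path>", often paired with --profile black to match black-style line wrapping, though the profile used here is not visible in this diff) or via its Python API. A minimal illustrative sketch:

    # Illustrative only -- not taken from this commit.
    import isort

    messy = "import torchaudio\nimport torch\nimport torch.nn as nn\n"

    # isort.code() returns the snippet with imports sorted within their
    # section: torch, torch.nn, torchaudio.
    print(isort.code(messy))

    # isort.check_code() only reports whether the snippet is already sorted,
    # which is the kind of check a pre-commit hook or CI job would run.
    print(isort.check_code(messy))

isort.file() applies the same rewrite to a file in place; any path passed to it here would be hypothetical, since the changed filenames are not preserved in this capture.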

View File

@@ -32,9 +32,9 @@ from pathlib import Path
 import torch
 from lhotse import (
     CutSet,
+    LilcomChunkyWriter,
     Spectrogram,
     SpectrogramConfig,
-    LilcomChunkyWriter,
     load_manifest,
 )
 from lhotse.audio import RecordingSet

View File

@@ -14,7 +14,6 @@ from typing import Optional
 import torch
 import torch.nn.functional as F
 from flow import (
     ConvFlow,
     DilatedDepthSeparableConv,

View File

@@ -13,7 +13,6 @@ import math
 from typing import Optional, Tuple, Union
 import torch
 from transform import piecewise_rational_quadratic_transform

View File

@@ -16,9 +16,6 @@ from typing import List, Optional, Tuple
 import numpy as np
 import torch
 import torch.nn.functional as F
-from icefall.utils import make_pad_mask
 from duration_predictor import StochasticDurationPredictor
 from hifigan import HiFiGANGenerator
 from posterior_encoder import PosteriorEncoder
@@ -26,6 +23,8 @@ from residual_coupling import ResidualAffineCouplingBlock
 from text_encoder import TextEncoder
 from utils import get_random_segments
+from icefall.utils import make_pad_mask
 class VITSGenerator(torch.nn.Module):
     """Generator module in VITS, `Conditional Variational Autoencoder

View File

@@ -36,13 +36,12 @@ import k2
 import torch
 import torch.nn as nn
 import torchaudio
-from train import get_model, get_params
 from tokenizer import Tokenizer
+from train import get_model, get_params
+from tts_datamodule import LJSpeechTtsDataModule
 from icefall.checkpoint import load_checkpoint
 from icefall.utils import AttributeDict, setup_logger
-from tts_datamodule import LJSpeechTtsDataModule
 def get_parser():

View File

@@ -14,7 +14,6 @@ from typing import List, Tuple, Union
 import torch
 import torch.distributions as D
 import torch.nn.functional as F
 from lhotse.features.kaldi import Wav2LogFilterBank

View File

@@ -12,9 +12,9 @@ This code is based on https://github.com/jaywalnut310/vits.
 from typing import Optional, Tuple
 import torch
+from wavenet import Conv1d, WaveNet
 from icefall.utils import make_pad_mask
-from wavenet import WaveNet, Conv1d
 class PosteriorEncoder(torch.nn.Module):

View File

@@ -12,7 +12,6 @@ This code is based on https://github.com/jaywalnut310/vits.
 from typing import Optional, Tuple, Union
 import torch
 from flow import FlipFlow
 from wavenet import WaveNet

View File

@@ -18,21 +18,25 @@
 import argparse
 import logging
-import numpy as np
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, Optional, Tuple, Union
 import k2
+import numpy as np
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from lhotse.cut import Cut
 from lhotse.utils import fix_random_seed
-from torch.optim import Optimizer
+from tokenizer import Tokenizer
 from torch.cuda.amp import GradScaler, autocast
 from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim import Optimizer
 from torch.utils.tensorboard import SummaryWriter
+from tts_datamodule import LJSpeechTtsDataModule
+from utils import MetricsTracker, plot_feature, save_checkpoint
+from vits import VITS
 from icefall import diagnostics
 from icefall.checkpoint import load_checkpoint
@@ -41,11 +45,6 @@ from icefall.env import get_env_info
 from icefall.hooks import register_inf_check_hooks
 from icefall.utils import AttributeDict, setup_logger, str2bool
-from tokenizer import Tokenizer
-from tts_datamodule import LJSpeechTtsDataModule
-from utils import MetricsTracker, plot_feature, save_checkpoint
-from vits import VITS
 LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
@@ -296,6 +295,7 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
     audio_lens = batch["audio_lens"].to(device)
     features_lens = batch["features_lens"].to(device)
     text = batch["text"]
+    speakers = batch["speakers"]
     tokens = tokenizer.texts_to_token_ids(text)
     tokens = k2.RaggedTensor(tokens)
@@ -306,7 +306,7 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
     # a tensor of shape (B, T)
     tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
-    return audio, audio_lens, features, features_lens, tokens, tokens_lens
+    return audio, audio_lens, features, features_lens, tokens, tokens_lens, speakers
 def train_one_epoch(
@@ -385,9 +385,15 @@ def train_one_epoch
         params.batch_idx_train += 1
         batch_size = len(batch["text"])
-        audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
-            batch, tokenizer, device
-        )
+        (
+            audio,
+            audio_lens,
+            features,
+            features_lens,
+            tokens,
+            tokens_lens,
+            speakers,
+        ) = prepare_input(batch, tokenizer, device)
         loss_info = MetricsTracker()
         loss_info["samples"] = batch_size

View File

@@ -29,10 +29,10 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
     CutConcatenate,
     CutMix,
     DynamicBucketingSampler,
-    SpeechSynthesisDataset,
     PrecomputedFeatures,
     SimpleCutSampler,
     SpecAugment,
+    SpeechSynthesisDataset,
 )
 from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
     AudioSamples,

View File

@@ -14,15 +14,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, List, Optional, Tuple, Union
 import collections
 import logging
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple, Union
 import torch
-import torch.nn as nn
 import torch.distributed as dist
+import torch.nn as nn
 from lhotse.dataset.sampling.base import CutSampler
-from pathlib import Path
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer

View File

@@ -9,8 +9,7 @@ from typing import Any, Dict, Optional, Tuple
 import torch
 import torch.nn as nn
-from torch.cuda.amp import autocast
+from generator import VITSGenerator
 from hifigan import (
     HiFiGANMultiPeriodDiscriminator,
     HiFiGANMultiScaleDiscriminator,
@@ -25,9 +24,8 @@ from loss import (
     KLDivergenceLoss,
     MelSpectrogramLoss,
 )
+from torch.cuda.amp import autocast
 from utils import get_segments
-from generator import VITSGenerator
 AVAILABLE_GENERATERS = {
     "vits_generator": VITSGenerator,

View File

@@ -9,9 +9,8 @@ This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.
 """
-import math
 import logging
+import math
 from typing import Optional, Tuple
 import torch