isort formatted

This commit is contained in:
jinzr 2023-11-10 11:04:46 +08:00
parent 269cc3b66a
commit 3c1b465d37
13 changed files with 30 additions and 33 deletions

View File

@@ -32,9 +32,9 @@ from pathlib import Path
import torch
from lhotse import (
CutSet,
LilcomChunkyWriter,
Spectrogram,
SpectrogramConfig,
LilcomChunkyWriter,
load_manifest,
)
from lhotse.audio import RecordingSet

View File

@@ -14,7 +14,6 @@ from typing import Optional
import torch
import torch.nn.functional as F
from flow import (
ConvFlow,
DilatedDepthSeparableConv,

View File

@@ -13,7 +13,6 @@ import math
from typing import Optional, Tuple, Union
import torch
from transform import piecewise_rational_quadratic_transform

View File

@@ -16,9 +16,6 @@ from typing import List, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from icefall.utils import make_pad_mask
from duration_predictor import StochasticDurationPredictor
from hifigan import HiFiGANGenerator
from posterior_encoder import PosteriorEncoder
@@ -26,6 +23,8 @@ from residual_coupling import ResidualAffineCouplingBlock
from text_encoder import TextEncoder
from utils import get_random_segments
from icefall.utils import make_pad_mask
class VITSGenerator(torch.nn.Module):
"""Generator module in VITS, `Conditional Variational Autoencoder

View File

@@ -36,13 +36,12 @@ import k2
import torch
import torch.nn as nn
import torchaudio
from train import get_model, get_params
from tokenizer import Tokenizer
from train import get_model, get_params
from tts_datamodule import LJSpeechTtsDataModule
from icefall.checkpoint import load_checkpoint
from icefall.utils import AttributeDict, setup_logger
from tts_datamodule import LJSpeechTtsDataModule
def get_parser():

View File

@@ -14,7 +14,6 @@ from typing import List, Tuple, Union
import torch
import torch.distributions as D
import torch.nn.functional as F
from lhotse.features.kaldi import Wav2LogFilterBank

View File

@@ -12,9 +12,9 @@ This code is based on https://github.com/jaywalnut310/vits.
from typing import Optional, Tuple
import torch
from wavenet import Conv1d, WaveNet
from icefall.utils import make_pad_mask
from wavenet import WaveNet, Conv1d
class PosteriorEncoder(torch.nn.Module):

View File

@@ -12,7 +12,6 @@ This code is based on https://github.com/jaywalnut310/vits.
from typing import Optional, Tuple, Union
import torch
from flow import FlipFlow
from wavenet import WaveNet

View File

@@ -18,21 +18,25 @@
import argparse
import logging
import numpy as np
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union
import k2
import numpy as np
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from lhotse.cut import Cut
from lhotse.utils import fix_random_seed
from torch.optim import Optimizer
from tokenizer import Tokenizer
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter
from tts_datamodule import LJSpeechTtsDataModule
from utils import MetricsTracker, plot_feature, save_checkpoint
from vits import VITS
from icefall import diagnostics
from icefall.checkpoint import load_checkpoint
@@ -41,11 +45,6 @@ from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, setup_logger, str2bool
from tokenizer import Tokenizer
from tts_datamodule import LJSpeechTtsDataModule
from utils import MetricsTracker, plot_feature, save_checkpoint
from vits import VITS
LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
@@ -296,6 +295,7 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
audio_lens = batch["audio_lens"].to(device)
features_lens = batch["features_lens"].to(device)
text = batch["text"]
speakers = batch["speakers"]
tokens = tokenizer.texts_to_token_ids(text)
tokens = k2.RaggedTensor(tokens)
@@ -306,7 +306,7 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
# a tensor of shape (B, T)
tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
return audio, audio_lens, features, features_lens, tokens, tokens_lens
return audio, audio_lens, features, features_lens, tokens, tokens_lens, speakers
def train_one_epoch(
@@ -385,9 +385,15 @@ def train_one_epoch(
params.batch_idx_train += 1
batch_size = len(batch["text"])
audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
batch, tokenizer, device
)
(
audio,
audio_lens,
features,
features_lens,
tokens,
tokens_lens,
speakers,
) = prepare_input(batch, tokenizer, device)
loss_info = MetricsTracker()
loss_info["samples"] = batch_size

View File

@@ -29,10 +29,10 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
CutMix,
DynamicBucketingSampler,
SpeechSynthesisDataset,
PrecomputedFeatures,
SimpleCutSampler,
SpecAugment,
SpeechSynthesisDataset,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,

View File

@@ -14,15 +14,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple, Union
import collections
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.nn as nn
from lhotse.dataset.sampling.base import CutSampler
from pathlib import Path
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer

View File

@@ -9,8 +9,7 @@ from typing import Any, Dict, Optional, Tuple
import torch
import torch.nn as nn
from torch.cuda.amp import autocast
from generator import VITSGenerator
from hifigan import (
HiFiGANMultiPeriodDiscriminator,
HiFiGANMultiScaleDiscriminator,
@@ -25,9 +24,8 @@ from loss import (
KLDivergenceLoss,
MelSpectrogramLoss,
)
from torch.cuda.amp import autocast
from utils import get_segments
from generator import VITSGenerator
AVAILABLE_GENERATERS = {
"vits_generator": VITSGenerator,

View File

@@ -9,9 +9,8 @@ This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.
"""
import math
import logging
import math
from typing import Optional, Tuple
import torch