mirror of https://github.com/k2-fsa/icefall.git

isort formatted

commit 3c1b465d37
parent 269cc3b66a
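All of the import changes below are mechanical re-ordering produced by isort. The repository's isort settings are not visible in this diff, so the configuration in the sketch below (the "black" profile) is an assumption; it only illustrates how this kind of re-ordering can be reproduced with isort's Python API.

# Minimal sketch (not part of the commit): reproduce isort-style import
# re-ordering programmatically. The "black" profile is an assumed setting.
import isort

messy = (
    "import torch\n"
    "from lhotse import SpectrogramConfig, LilcomChunkyWriter, CutSet\n"
)

# isort.code() returns the source with imports grouped and alphabetized,
# the same kind of reshuffling shown in the hunks below.
print(isort.code(messy, profile="black"))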
@@ -32,9 +32,9 @@ from pathlib import Path
 import torch
 from lhotse import (
     CutSet,
+    LilcomChunkyWriter,
     Spectrogram,
     SpectrogramConfig,
-    LilcomChunkyWriter,
     load_manifest,
 )
 from lhotse.audio import RecordingSet
@@ -14,7 +14,6 @@ from typing import Optional
 
 import torch
 import torch.nn.functional as F
-
 from flow import (
     ConvFlow,
     DilatedDepthSeparableConv,
@@ -13,7 +13,6 @@ import math
 from typing import Optional, Tuple, Union
 
 import torch
-
 from transform import piecewise_rational_quadratic_transform
 
 
@@ -16,9 +16,6 @@ from typing import List, Optional, Tuple
 import numpy as np
 import torch
 import torch.nn.functional as F
-
-from icefall.utils import make_pad_mask
-
 from duration_predictor import StochasticDurationPredictor
 from hifigan import HiFiGANGenerator
 from posterior_encoder import PosteriorEncoder
@@ -26,6 +23,8 @@ from residual_coupling import ResidualAffineCouplingBlock
 from text_encoder import TextEncoder
 from utils import get_random_segments
 
+from icefall.utils import make_pad_mask
+
 
 class VITSGenerator(torch.nn.Module):
     """Generator module in VITS, `Conditional Variational Autoencoder
@@ -36,13 +36,12 @@ import k2
 import torch
 import torch.nn as nn
 import torchaudio
-
-from train import get_model, get_params
 from tokenizer import Tokenizer
+from train import get_model, get_params
+from tts_datamodule import LJSpeechTtsDataModule
 
 from icefall.checkpoint import load_checkpoint
 from icefall.utils import AttributeDict, setup_logger
-from tts_datamodule import LJSpeechTtsDataModule
 
 
 def get_parser():
@@ -14,7 +14,6 @@ from typing import List, Tuple, Union
 import torch
 import torch.distributions as D
 import torch.nn.functional as F
-
 from lhotse.features.kaldi import Wav2LogFilterBank
 
 
@@ -12,9 +12,9 @@ This code is based on https://github.com/jaywalnut310/vits.
 from typing import Optional, Tuple
 
 import torch
+from wavenet import Conv1d, WaveNet
 
 from icefall.utils import make_pad_mask
-from wavenet import WaveNet, Conv1d
 
 
 class PosteriorEncoder(torch.nn.Module):
@@ -12,7 +12,6 @@ This code is based on https://github.com/jaywalnut310/vits.
 from typing import Optional, Tuple, Union
 
 import torch
-
 from flow import FlipFlow
 from wavenet import WaveNet
 
@@ -18,21 +18,25 @@
 
 import argparse
 import logging
-import numpy as np
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, Optional, Tuple, Union
 
 import k2
+import numpy as np
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from lhotse.cut import Cut
 from lhotse.utils import fix_random_seed
-from torch.optim import Optimizer
+from tokenizer import Tokenizer
 from torch.cuda.amp import GradScaler, autocast
 from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim import Optimizer
 from torch.utils.tensorboard import SummaryWriter
+from tts_datamodule import LJSpeechTtsDataModule
+from utils import MetricsTracker, plot_feature, save_checkpoint
+from vits import VITS
 
 from icefall import diagnostics
 from icefall.checkpoint import load_checkpoint
@@ -41,11 +45,6 @@ from icefall.env import get_env_info
 from icefall.hooks import register_inf_check_hooks
 from icefall.utils import AttributeDict, setup_logger, str2bool
 
-from tokenizer import Tokenizer
-from tts_datamodule import LJSpeechTtsDataModule
-from utils import MetricsTracker, plot_feature, save_checkpoint
-from vits import VITS
-
 LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
 
 
@@ -296,6 +295,7 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
     audio_lens = batch["audio_lens"].to(device)
     features_lens = batch["features_lens"].to(device)
     text = batch["text"]
+    speakers = batch["speakers"]
 
     tokens = tokenizer.texts_to_token_ids(text)
     tokens = k2.RaggedTensor(tokens)
@@ -306,7 +306,7 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
     # a tensor of shape (B, T)
     tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
 
-    return audio, audio_lens, features, features_lens, tokens, tokens_lens
+    return audio, audio_lens, features, features_lens, tokens, tokens_lens, speakers
 
 
 def train_one_epoch(
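Taken together, the two hunks above mean that prepare_input now also returns the per-cut speaker labels. A hedged reconstruction of the updated function is sketched below: only the lines that appear in the diff are certain; the audio/feature handling and the token-length computation are assumptions filled in for completeness.

# Hedged reconstruction of prepare_input; lines marked "assumed" do not
# appear in the diff above.
import k2
import torch


def prepare_input(batch: dict, tokenizer, device: torch.device):
    """Move batch tensors to the device and tokenize the transcripts."""
    audio = batch["audio"].to(device)  # assumed field name
    audio_lens = batch["audio_lens"].to(device)
    features = batch["features"].to(device)  # assumed field name
    features_lens = batch["features_lens"].to(device)
    text = batch["text"]
    speakers = batch["speakers"]  # newly returned by this change

    token_ids = tokenizer.texts_to_token_ids(text)
    tokens_lens = torch.tensor([len(ids) for ids in token_ids]).to(device)  # assumed
    tokens = k2.RaggedTensor(token_ids).to(device)
    # a tensor of shape (B, T)
    tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)

    return audio, audio_lens, features, features_lens, tokens, tokens_lens, speakers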
@@ -385,9 +385,15 @@ def train_one_epoch(
         params.batch_idx_train += 1
 
         batch_size = len(batch["text"])
-        audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
-            batch, tokenizer, device
-        )
+        (
+            audio,
+            audio_lens,
+            features,
+            features_lens,
+            tokens,
+            tokens_lens,
+            speakers,
+        ) = prepare_input(batch, tokenizer, device)
 
         loss_info = MetricsTracker()
         loss_info["samples"] = batch_size
@@ -29,10 +29,10 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
     CutConcatenate,
     CutMix,
     DynamicBucketingSampler,
-    SpeechSynthesisDataset,
     PrecomputedFeatures,
     SimpleCutSampler,
     SpecAugment,
+    SpeechSynthesisDataset,
 )
 from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
     AudioSamples,
@@ -14,15 +14,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, List, Optional, Tuple, Union
 import collections
 import logging
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
-import torch.nn as nn
 import torch.distributed as dist
+import torch.nn as nn
 from lhotse.dataset.sampling.base import CutSampler
-from pathlib import Path
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
@@ -9,8 +9,7 @@ from typing import Any, Dict, Optional, Tuple
 
 import torch
 import torch.nn as nn
-from torch.cuda.amp import autocast
-
+from generator import VITSGenerator
 from hifigan import (
     HiFiGANMultiPeriodDiscriminator,
     HiFiGANMultiScaleDiscriminator,
@@ -25,9 +24,8 @@ from loss import (
     KLDivergenceLoss,
     MelSpectrogramLoss,
 )
+from torch.cuda.amp import autocast
 from utils import get_segments
-from generator import VITSGenerator
-
 
 AVAILABLE_GENERATERS = {
     "vits_generator": VITSGenerator,
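The end of this hunk shows the AVAILABLE_GENERATERS registry. As a brief, hedged illustration (the key and constructor arguments below are hypothetical, not taken from this diff), such a registry is typically consumed by looking up the class for a configured generator type:

# Hedged sketch: consuming a registry dict such as AVAILABLE_GENERATERS.
# The values below are hypothetical and not taken from this diff.
generator_type = "vits_generator"  # assumed config value
generator_params = {}  # hypothetical keyword arguments for VITSGenerator
generator_class = AVAILABLE_GENERATERS[generator_type]
generator = generator_class(**generator_params)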
@@ -9,9 +9,8 @@ This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.
 
 """
 
-import math
 import logging
-
+import math
 from typing import Optional, Tuple
 
 import torch