Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-06 15:44:17 +00:00)

commit 3c1b465d37
parent 269cc3b66a

isort formatted
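All of the hunks below are import reorderings produced by isort: imports are regrouped (standard library, third-party, local) and sorted alphabetically within each group. As a rough illustration only — not part of this commit — the same kind of reordering can be reproduced with isort's Python API; the sample input below and the way local modules such as `flow` are classified are assumptions, since the exact grouping depends on the repository's isort configuration.

# Minimal sketch (assumed setup, not from this commit): apply isort to a string
# of source code and print the reordered imports.
import isort

messy = (
    "from typing import Optional\n"
    "\n"
    "import torch\n"
    "import torch.nn.functional as F\n"
    "\n"
    "from flow import ConvFlow\n"
)

# isort.code() returns the source with imports grouped and sorted
# alphabetically within each group; local-module placement depends on settings.
print(isort.code(messy))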
@@ -32,9 +32,9 @@ from pathlib import Path
 import torch
 from lhotse import (
     CutSet,
+    LilcomChunkyWriter,
     Spectrogram,
     SpectrogramConfig,
-    LilcomChunkyWriter,
     load_manifest,
 )
 from lhotse.audio import RecordingSet
@@ -14,7 +14,6 @@ from typing import Optional
 
 import torch
 import torch.nn.functional as F
-
 from flow import (
     ConvFlow,
     DilatedDepthSeparableConv,
@@ -13,7 +13,6 @@ import math
 from typing import Optional, Tuple, Union
 
 import torch
-
 from transform import piecewise_rational_quadratic_transform
 
 
@@ -16,9 +16,6 @@ from typing import List, Optional, Tuple
 import numpy as np
 import torch
 import torch.nn.functional as F
-
-from icefall.utils import make_pad_mask
-
 from duration_predictor import StochasticDurationPredictor
 from hifigan import HiFiGANGenerator
 from posterior_encoder import PosteriorEncoder
@@ -26,6 +23,8 @@ from residual_coupling import ResidualAffineCouplingBlock
 from text_encoder import TextEncoder
 from utils import get_random_segments
 
+from icefall.utils import make_pad_mask
+
 
 class VITSGenerator(torch.nn.Module):
     """Generator module in VITS, `Conditional Variational Autoencoder
@@ -36,13 +36,12 @@ import k2
 import torch
 import torch.nn as nn
 import torchaudio
-
-from train import get_model, get_params
 from tokenizer import Tokenizer
+from train import get_model, get_params
+from tts_datamodule import LJSpeechTtsDataModule
 
 from icefall.checkpoint import load_checkpoint
 from icefall.utils import AttributeDict, setup_logger
-from tts_datamodule import LJSpeechTtsDataModule
 
 
 def get_parser():
@@ -14,7 +14,6 @@ from typing import List, Tuple, Union
 import torch
 import torch.distributions as D
 import torch.nn.functional as F
-
 from lhotse.features.kaldi import Wav2LogFilterBank
 
 
@@ -12,9 +12,9 @@ This code is based on https://github.com/jaywalnut310/vits.
 from typing import Optional, Tuple
 
 import torch
+from wavenet import Conv1d, WaveNet
 
 from icefall.utils import make_pad_mask
-from wavenet import WaveNet, Conv1d
 
 
 class PosteriorEncoder(torch.nn.Module):
@@ -12,7 +12,6 @@ This code is based on https://github.com/jaywalnut310/vits.
 from typing import Optional, Tuple, Union
 
 import torch
-
 from flow import FlipFlow
 from wavenet import WaveNet
 
@@ -18,21 +18,25 @@
 
 import argparse
 import logging
-import numpy as np
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, Optional, Tuple, Union
 
 import k2
+import numpy as np
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from lhotse.cut import Cut
 from lhotse.utils import fix_random_seed
-from torch.optim import Optimizer
+from tokenizer import Tokenizer
 from torch.cuda.amp import GradScaler, autocast
 from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim import Optimizer
 from torch.utils.tensorboard import SummaryWriter
+from tts_datamodule import LJSpeechTtsDataModule
+from utils import MetricsTracker, plot_feature, save_checkpoint
+from vits import VITS
 
 from icefall import diagnostics
 from icefall.checkpoint import load_checkpoint
@@ -41,11 +45,6 @@ from icefall.env import get_env_info
 from icefall.hooks import register_inf_check_hooks
 from icefall.utils import AttributeDict, setup_logger, str2bool
 
-from tokenizer import Tokenizer
-from tts_datamodule import LJSpeechTtsDataModule
-from utils import MetricsTracker, plot_feature, save_checkpoint
-from vits import VITS
-
 LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
 
 
@@ -296,6 +295,7 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
     audio_lens = batch["audio_lens"].to(device)
     features_lens = batch["features_lens"].to(device)
     text = batch["text"]
+    speakers = batch["speakers"]
 
     tokens = tokenizer.texts_to_token_ids(text)
     tokens = k2.RaggedTensor(tokens)
@@ -306,7 +306,7 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
     # a tensor of shape (B, T)
     tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
 
-    return audio, audio_lens, features, features_lens, tokens, tokens_lens
+    return audio, audio_lens, features, features_lens, tokens, tokens_lens, speakers
 
 
 def train_one_epoch(
@@ -385,9 +385,15 @@ def train_one_epoch(
         params.batch_idx_train += 1
 
         batch_size = len(batch["text"])
-        audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
-            batch, tokenizer, device
-        )
+        (
+            audio,
+            audio_lens,
+            features,
+            features_lens,
+            tokens,
+            tokens_lens,
+            speakers,
+        ) = prepare_input(batch, tokenizer, device)
 
         loss_info = MetricsTracker()
         loss_info["samples"] = batch_size
@@ -29,10 +29,10 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
     CutConcatenate,
     CutMix,
     DynamicBucketingSampler,
-    SpeechSynthesisDataset,
     PrecomputedFeatures,
     SimpleCutSampler,
     SpecAugment,
+    SpeechSynthesisDataset,
 )
 from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
     AudioSamples,
@@ -14,15 +14,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, List, Optional, Tuple, Union
 import collections
 import logging
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
-import torch.nn as nn
 import torch.distributed as dist
+import torch.nn as nn
 from lhotse.dataset.sampling.base import CutSampler
-from pathlib import Path
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
@@ -9,8 +9,7 @@ from typing import Any, Dict, Optional, Tuple
 
 import torch
 import torch.nn as nn
-from torch.cuda.amp import autocast
-
+from generator import VITSGenerator
 from hifigan import (
     HiFiGANMultiPeriodDiscriminator,
     HiFiGANMultiScaleDiscriminator,
@@ -25,9 +24,8 @@ from loss import (
     KLDivergenceLoss,
     MelSpectrogramLoss,
 )
+from torch.cuda.amp import autocast
 from utils import get_segments
-from generator import VITSGenerator
-
 
 AVAILABLE_GENERATERS = {
     "vits_generator": VITSGenerator,
@@ -9,9 +9,8 @@ This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.
 
 """
 
-import math
 import logging
+import math
 from typing import Optional, Tuple
 
 import torch
[remainder of this hunk did not load on the source page]