mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-04 06:34:20 +00:00
fixed formatting issue
This commit is contained in:
parent 8c75259723
commit cf7ad8131d
@@ -14,7 +14,6 @@ from typing import Optional
 
 import torch
 import torch.nn.functional as F
-
 from flow import (
     ConvFlow,
     DilatedDepthSeparableConv,
@@ -180,7 +180,13 @@ def export_model_onnx(
        model_filename,
        verbose=False,
        opset_version=opset_version,
-        input_names=["tokens", "tokens_lens", "noise_scale", "noise_scale_dur", "alpha"],
+        input_names=[
+            "tokens",
+            "tokens_lens",
+            "noise_scale",
+            "noise_scale_dur",
+            "alpha",
+        ],
        output_names=["audio"],
        dynamic_axes={
            "tokens": {0: "N", 1: "T"},
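
Note on the hunk above: it only re-wraps an argument list, but the arguments belong to what is presumably a `torch.onnx.export` call. A minimal, self-contained sketch of how `input_names`, `output_names`, and `dynamic_axes` fit together; the toy model and file name are hypothetical, not the recipe's code:

```python
import torch


class Toy(torch.nn.Module):
    """Stand-in for the real VITS generator; shapes only."""

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return tokens.float().mean(dim=1, keepdim=True)


tokens = torch.zeros(1, 13, dtype=torch.int64)  # dummy (N, T) input
torch.onnx.export(
    Toy(),
    (tokens,),
    "toy.onnx",
    opset_version=13,
    input_names=["tokens"],
    output_names=["audio"],
    # "N" and "T" are symbolic axis names: the exported graph then accepts
    # any batch size and any sequence length at inference time.
    dynamic_axes={"tokens": {0: "N", 1: "T"}, "audio": {0: "N"}},
)
```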
@@ -13,7 +13,6 @@ import math
 from typing import Optional, Tuple, Union
 
 import torch
-
 from transform import piecewise_rational_quadratic_transform
 
 
@@ -16,9 +16,6 @@ from typing import List, Optional, Tuple
 import numpy as np
 import torch
 import torch.nn.functional as F
-
-from icefall.utils import make_pad_mask
-
 from duration_predictor import StochasticDurationPredictor
 from hifigan import HiFiGANGenerator
 from posterior_encoder import PosteriorEncoder
@@ -26,6 +23,8 @@ from residual_coupling import ResidualAffineCouplingBlock
 from text_encoder import TextEncoder
 from utils import get_random_segments
 
+from icefall.utils import make_pad_mask
+
 
 class VITSGenerator(torch.nn.Module):
     """Generator module in VITS, `Conditional Variational Autoencoder
@@ -36,13 +36,12 @@ import k2
 import torch
 import torch.nn as nn
 import torchaudio
-
-from train import get_model, get_params
 from tokenizer import Tokenizer
+from train import get_model, get_params
+from tts_datamodule import LJSpeechTtsDataModule
 
 from icefall.checkpoint import load_checkpoint
 from icefall.utils import AttributeDict, setup_logger
-from tts_datamodule import LJSpeechTtsDataModule
 
 
 def get_parser():
@@ -107,12 +106,12 @@ def infer_dataset(
    for i in range(batch_size):
        torchaudio.save(
            str(params.save_wav_dir / f"{cut_ids[i]}_gt.wav"),
-            audio[i:i + 1, :audio_lens[i]],
+            audio[i : i + 1, : audio_lens[i]],
            sample_rate=params.sampling_rate,
        )
        torchaudio.save(
            str(params.save_wav_dir / f"{cut_ids[i]}_pred.wav"),
-            audio_pred[i:i + 1, :audio_lens_pred[i]],
+            audio_pred[i : i + 1, : audio_lens_pred[i]],
            sample_rate=params.sampling_rate,
        )
 
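
The only change in that hunk is PEP 8 slice spacing, which black enforces: when a slice bound is more than a plain name or literal, the colon gets a space on each side. An illustrative example with made-up values:

```python
import torch

audio = torch.randn(4, 16000)  # (batch, samples)
audio_lens = [16000, 12000, 8000, 4000]
i = 0

clip = audio[i : i + 1, : audio_lens[i]]  # complex bounds: spaces around ":"
head = audio[0:2]                         # simple bounds: no spaces
```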
@@ -144,14 +143,24 @@ def infer_dataset(
         audio_lens = batch["audio_lens"].tolist()
         cut_ids = [cut.id for cut in batch["cut"]]
 
-        audio_pred, _, durations = model.inference_batch(text=tokens, text_lengths=tokens_lens)
+        audio_pred, _, durations = model.inference_batch(
+            text=tokens, text_lengths=tokens_lens
+        )
         audio_pred = audio_pred.detach().cpu()
         # convert to samples
-        audio_lens_pred = (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist()
+        audio_lens_pred = (
+            (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist()
+        )
 
         futures.append(
             executor.submit(
-                _save_worker, batch_size, cut_ids, audio, audio_pred, audio_lens, audio_lens_pred
+                _save_worker,
+                batch_size,
+                cut_ids,
+                audio,
+                audio_pred,
+                audio_lens,
+                audio_lens_pred,
             )
         )
 
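
For context: `durations.sum(1)` gives each utterance's predicted length in feature frames, and multiplying by `params.frame_shift` (the hop size in samples) converts that to a waveform length in samples. A tiny worked example with assumed values:

```python
import torch

frame_shift = 256  # hop size in samples; an assumed value, not the recipe's
durations = torch.tensor([[2.0, 3.0, 5.0]])  # per-token frames, shape (N, T)

audio_lens_pred = (
    (durations.sum(1) * frame_shift).to(dtype=torch.int64).tolist()
)
assert audio_lens_pred == [2560]  # (2 + 3 + 5) frames * 256 samples/frame
```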
@@ -160,7 +169,9 @@ def infer_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     # return results
     for f in futures:
         f.result()
@@ -14,7 +14,6 @@ from typing import List, Tuple, Union
 import torch
 import torch.distributions as D
 import torch.nn.functional as F
-
 from lhotse.features.kaldi import Wav2LogFilterBank
 
 
@@ -12,9 +12,9 @@ This code is based on https://github.com/jaywalnut310/vits.
 from typing import Optional, Tuple
 
 import torch
+from wavenet import Conv1d, WaveNet
 
 from icefall.utils import make_pad_mask
-from wavenet import WaveNet, Conv1d
 
 
 class PosteriorEncoder(torch.nn.Module):
@@ -12,7 +12,6 @@ This code is based on https://github.com/jaywalnut310/vits.
 from typing import Optional, Tuple, Union
 
 import torch
-
 from flow import FlipFlow
 from wavenet import WaveNet
 
@@ -28,10 +28,10 @@ Use the onnx model to generate a wav:
 
 import argparse
 import logging
 
+import onnxruntime as ort
 import torch
 import torchaudio
-import onnxruntime as ort
 
 from tokenizer import Tokenizer
 
@@ -169,9 +169,7 @@ class Transformer(nn.Module):
         x, pos_emb = self.encoder_pos(x)
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
 
-        x = self.encoder(
-            x, pos_emb, key_padding_mask=key_padding_mask
-        )  # (T, N, C)
+        x = self.encoder(x, pos_emb, key_padding_mask=key_padding_mask)  # (T, N, C)
 
         x = self.after_norm(x)
 
@@ -207,7 +205,9 @@ class TransformerEncoderLayer(nn.Module):
             nn.Linear(dim_feedforward, d_model),
         )
 
-        self.self_attn = RelPositionMultiheadAttention(d_model, num_heads, dropout=dropout)
+        self.self_attn = RelPositionMultiheadAttention(
+            d_model, num_heads, dropout=dropout
+        )
 
         self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
 
@@ -242,7 +242,9 @@ class TransformerEncoderLayer(nn.Module):
           key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len)
         """
         # macaron style feed-forward module
-        src = src + self.ff_scale * self.dropout(self.feed_forward_macaron(self.norm_ff_macaron(src)))
+        src = src + self.ff_scale * self.dropout(
+            self.feed_forward_macaron(self.norm_ff_macaron(src))
+        )
 
         # multi-head self-attention module
         src_attn = self.self_attn(
@@ -490,11 +492,17 @@ class RelPositionMultiheadAttention(nn.Module):
 
         q = q.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim)
         k = k.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim)
-        v = v.contiguous().view(seq_len, batch_size * self.num_heads, self.head_dim).transpose(0, 1)
+        v = (
+            v.contiguous()
+            .view(seq_len, batch_size * self.num_heads, self.head_dim)
+            .transpose(0, 1)
+        )
 
         q = q.transpose(0, 1)  # (batch_size, seq_len, num_head, head_dim)
 
-        p = self.linear_pos(pos_emb).view(pos_emb.size(0), -1, self.num_heads, self.head_dim)
+        p = self.linear_pos(pos_emb).view(
+            pos_emb.size(0), -1, self.num_heads, self.head_dim
+        )
         # (1, 2*seq_len, num_head, head_dim) -> (1, num_head, head_dim, 2*seq_len-1)
         p = p.permute(0, 2, 3, 1)
 
@@ -506,15 +514,23 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch_size, num_head, head_dim, seq_len)
-        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch_size, num_head, seq_len, seq_len)
+        matrix_ac = torch.matmul(
+            q_with_bias_u, k
+        )  # (batch_size, num_head, seq_len, seq_len)
 
         # compute matrix b and matrix d
-        matrix_bd = torch.matmul(q_with_bias_v, p)  # (batch_size, num_head, seq_len, 2*seq_len-1)
-        matrix_bd = self.rel_shift(matrix_bd)  # (batch_size, num_head, seq_len, seq_len)
+        matrix_bd = torch.matmul(
+            q_with_bias_v, p
+        )  # (batch_size, num_head, seq_len, 2*seq_len-1)
+        matrix_bd = self.rel_shift(
+            matrix_bd
+        )  # (batch_size, num_head, seq_len, seq_len)
 
         # (batch_size, num_head, seq_len, seq_len)
         attn_output_weights = (matrix_ac + matrix_bd) * scaling
-        attn_output_weights = attn_output_weights.view(batch_size * self.num_heads, seq_len, seq_len)
+        attn_output_weights = attn_output_weights.view(
+            batch_size * self.num_heads, seq_len, seq_len
+        )
 
         if key_padding_mask is not None:
             assert key_padding_mask.shape == (batch_size, seq_len)
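
For context: `matrix_ac` and `matrix_bd` are the content and position terms of relative positional attention from Transformer-XL (Section 3.3), and `rel_shift` realigns the (seq_len, 2*seq_len-1) position scores into a (seq_len, seq_len) matrix. A self-contained sketch of the computation; the shapes mirror the comments in the diff, but this `rel_shift` is a common illustrative implementation, not necessarily the one in this file:

```python
import torch
import torch.nn.functional as F

B, H, T, D = 2, 4, 5, 8                  # batch, heads, time, head_dim
q_with_bias_u = torch.randn(B, H, T, D)  # query plus content bias u
q_with_bias_v = torch.randn(B, H, T, D)  # query plus position bias v
k = torch.randn(B, H, D, T)              # keys, pre-transposed
p = torch.randn(B, H, D, 2 * T - 1)      # projected relative embeddings


def rel_shift(x: torch.Tensor) -> torch.Tensor:
    """Shift (B, H, T, 2T-1) relative scores to (B, H, T, T) absolute ones."""
    b, h, t, n = x.shape          # n == 2 * t - 1
    x = F.pad(x, (1, 0))          # prepend a zero column: (B, H, T, 2T)
    x = x.view(b, h, n + 1, t)    # the pad realigns the diagonals
    return x[:, :, 1:].reshape(b, h, t, n)[..., :t]


matrix_ac = torch.matmul(q_with_bias_u, k)             # content term
matrix_bd = rel_shift(torch.matmul(q_with_bias_v, p))  # position term
scores = (matrix_ac + matrix_bd) * D**-0.5             # (B, H, T, T)
```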
@@ -536,10 +552,16 @@ class RelPositionMultiheadAttention(nn.Module):
 
         # (batch_size * num_head, seq_len, head_dim)
         attn_output = torch.bmm(attn_output_weights, v)
-        assert attn_output.shape == (batch_size * self.num_heads, seq_len, self.head_dim)
+        assert attn_output.shape == (
+            batch_size * self.num_heads,
+            seq_len,
+            self.head_dim,
+        )
 
         attn_output = (
-            attn_output.transpose(0, 1).contiguous().view(seq_len, batch_size, self.embed_dim)
+            attn_output.transpose(0, 1)
+            .contiguous()
+            .view(seq_len, batch_size, self.embed_dim)
         )
         # (seq_len, batch_size, embed_dim)
         attn_output = self.out_proj(attn_output)
@@ -78,7 +78,9 @@ class Tokenizer(object):
 
         return token_ids_list
 
-    def tokens_to_token_ids(self, tokens_list: List[str], intersperse_blank: bool = True):
+    def tokens_to_token_ids(
+        self, tokens_list: List[str], intersperse_blank: bool = True
+    ):
         """
         Args:
           tokens_list:
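
Aside on `intersperse_blank`: following the original VITS convention, a blank id is inserted between (and around) the real token ids so the model can absorb inter-phone transitions. A hedged sketch of the usual behavior; the helper name and blank id are assumptions, not this Tokenizer's actual internals:

```python
from typing import List


def intersperse(token_ids: List[int], blank_id: int = 0) -> List[int]:
    """[5, 7, 9] -> [0, 5, 0, 7, 0, 9, 0] for blank_id=0."""
    result = [blank_id] * (2 * len(token_ids) + 1)
    result[1::2] = token_ids  # real tokens land on the odd positions
    return result


assert intersperse([5, 7, 9]) == [0, 5, 0, 7, 0, 9, 0]
```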
@@ -18,21 +18,25 @@
 
 import argparse
 import logging
-import numpy as np
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, Optional, Tuple, Union
 
 import k2
+import numpy as np
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from lhotse.cut import Cut
 from lhotse.utils import fix_random_seed
-from torch.optim import Optimizer
+from tokenizer import Tokenizer
 from torch.cuda.amp import GradScaler, autocast
 from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim import Optimizer
 from torch.utils.tensorboard import SummaryWriter
+from tts_datamodule import LJSpeechTtsDataModule
+from utils import MetricsTracker, plot_feature, save_checkpoint
+from vits import VITS
 
 from icefall import diagnostics
 from icefall.checkpoint import load_checkpoint
@@ -41,11 +45,6 @@ from icefall.env import get_env_info
 from icefall.hooks import register_inf_check_hooks
 from icefall.utils import AttributeDict, setup_logger, str2bool
 
-from tokenizer import Tokenizer
-from tts_datamodule import LJSpeechTtsDataModule
-from utils import MetricsTracker, plot_feature, save_checkpoint
-from vits import VITS
-
 LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
 
 
@@ -385,11 +384,12 @@ def train_one_epoch(
         params.batch_idx_train += 1
 
         batch_size = len(batch["tokens"])
-        audio, audio_lens, features, features_lens, tokens, tokens_lens = \
-            prepare_input(batch, tokenizer, device)
+        audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
+            batch, tokenizer, device
+        )
 
         loss_info = MetricsTracker()
-        loss_info['samples'] = batch_size
+        loss_info["samples"] = batch_size
 
         try:
             with autocast(enabled=params.use_fp16):
@@ -446,7 +446,9 @@ def train_one_epoch(
            # behavior depending on the current grad scale.
            cur_grad_scale = scaler._scale.item()
 
-            if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0):
+            if cur_grad_scale < 8.0 or (
+                cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0
+            ):
                scaler.update(cur_grad_scale * 2.0)
            if cur_grad_scale < 0.01:
                if not saved_bad_model:
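
The reformatted condition implements a recovery heuristic for AMP training: if the dynamic grad scale has decayed below 8.0 (or below 32.0, checked every 400 batches), it is doubled; elsewhere the loop aborts if it collapses below 0.01. A condensed sketch of that logic; the function name is assumed, and `scaler._scale` is a private attribute, as in the diff:

```python
from torch.cuda.amp import GradScaler


def maybe_raise_grad_scale(scaler: GradScaler, batch_idx_train: int) -> None:
    """Nudge an AMP grad scale back up if overflow handling drove it too low."""
    cur_grad_scale = scaler._scale.item()  # only set once scaling has started
    if cur_grad_scale < 8.0 or (
        cur_grad_scale < 32.0 and batch_idx_train % 400 == 0
    ):
        scaler.update(cur_grad_scale * 2.0)  # update() accepts a new scale
    if cur_grad_scale < 0.01:
        raise RuntimeError(f"grad_scale is too small: {cur_grad_scale}")
```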
@@ -482,9 +484,7 @@ def train_one_epoch(
                loss_info.write_summary(
                    tb_writer, "train/current_", params.batch_idx_train
                )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                if params.use_fp16:
                    tb_writer.add_scalar(
                        "train/grad_scale", cur_grad_scale, params.batch_idx_train
@@ -492,19 +492,34 @@ def train_one_epoch(
                if "returned_sample" in stats_g:
                    speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"]
                    tb_writer.add_audio(
-                        "train/speech_hat_", speech_hat_, params.batch_idx_train, params.sampling_rate
+                        "train/speech_hat_",
+                        speech_hat_,
+                        params.batch_idx_train,
+                        params.sampling_rate,
                    )
                    tb_writer.add_audio(
-                        "train/speech_", speech_, params.batch_idx_train, params.sampling_rate
+                        "train/speech_",
+                        speech_,
+                        params.batch_idx_train,
+                        params.sampling_rate,
                    )
                    tb_writer.add_image(
-                        "train/mel_hat_", plot_feature(mel_hat_), params.batch_idx_train, dataformats='HWC'
+                        "train/mel_hat_",
+                        plot_feature(mel_hat_),
+                        params.batch_idx_train,
+                        dataformats="HWC",
                    )
                    tb_writer.add_image(
-                        "train/mel_", plot_feature(mel_), params.batch_idx_train, dataformats='HWC'
+                        "train/mel_",
+                        plot_feature(mel_),
+                        params.batch_idx_train,
+                        dataformats="HWC",
                    )
 
-        if params.batch_idx_train % params.valid_interval == 0 and not params.print_diagnostics:
+        if (
+            params.batch_idx_train % params.valid_interval == 0
+            and not params.print_diagnostics
+        ):
            logging.info("Computing validation loss")
            valid_info, (speech_hat, speech) = compute_validation_loss(
                params=params,
@@ -523,10 +538,16 @@ def train_one_epoch(
                    tb_writer, "train/valid_", params.batch_idx_train
                )
                tb_writer.add_audio(
-                    "train/valdi_speech_hat", speech_hat, params.batch_idx_train, params.sampling_rate
+                    "train/valdi_speech_hat",
+                    speech_hat,
+                    params.batch_idx_train,
+                    params.sampling_rate,
                )
                tb_writer.add_audio(
-                    "train/valdi_speech", speech, params.batch_idx_train, params.sampling_rate
+                    "train/valdi_speech",
+                    speech,
+                    params.batch_idx_train,
+                    params.sampling_rate,
                )
 
        loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
@@ -555,11 +576,17 @@ def compute_validation_loss(
    with torch.no_grad():
        for batch_idx, batch in enumerate(valid_dl):
            batch_size = len(batch["tokens"])
-            audio, audio_lens, features, features_lens, tokens, tokens_lens = \
-                prepare_input(batch, tokenizer, device)
+            (
+                audio,
+                audio_lens,
+                features,
+                features_lens,
+                tokens,
+                tokens_lens,
+            ) = prepare_input(batch, tokenizer, device)
 
            loss_info = MetricsTracker()
-            loss_info['samples'] = batch_size
+            loss_info["samples"] = batch_size
 
            # forward discriminator
            loss_d, stats_d = model(
@@ -596,12 +623,17 @@ def compute_validation_loss(
            if batch_idx == 0 and rank == 0:
                inner_model = model.module if isinstance(model, DDP) else model
                audio_pred, _, duration = inner_model.inference(
-                    text=tokens[0, :tokens_lens[0].item()]
+                    text=tokens[0, : tokens_lens[0].item()]
                )
                audio_pred = audio_pred.data.cpu().numpy()
-                audio_len_pred = (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item()
-                assert audio_len_pred == len(audio_pred), (audio_len_pred, len(audio_pred))
-                audio_gt = audio[0, :audio_lens[0].item()].data.cpu().numpy()
+                audio_len_pred = (
+                    (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item()
+                )
+                assert audio_len_pred == len(audio_pred), (
+                    audio_len_pred,
+                    len(audio_pred),
+                )
+                audio_gt = audio[0, : audio_lens[0].item()].data.cpu().numpy()
                returned_sample = (audio_pred, audio_gt)
 
        if world_size > 1:
@@ -632,8 +664,9 @@ def scan_pessimistic_batches_for_oom(
    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
    for criterion, cuts in batches.items():
        batch = train_dl.dataset[cuts]
-        audio, audio_lens, features, features_lens, tokens, tokens_lens = \
-            prepare_input(batch, tokenizer, device)
+        audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
+            batch, tokenizer, device
+        )
        try:
            # for discriminator
            with autocast(enabled=params.use_fp16):
@@ -29,10 +29,10 @@ from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
     CutConcatenate,
     CutMix,
     DynamicBucketingSampler,
-    SpeechSynthesisDataset,
     PrecomputedFeatures,
     SimpleCutSampler,
     SpecAugment,
+    SpeechSynthesisDataset,
 )
 from lhotse.dataset.input_strategies import (  # noqa F401 For AudioSamples
     AudioSamples,
@@ -14,15 +14,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, List, Optional, Tuple, Union
 import collections
 import logging
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
-import torch.nn as nn
 import torch.distributed as dist
+import torch.nn as nn
 from lhotse.dataset.sampling.base import CutSampler
-from pathlib import Path
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
@@ -97,23 +97,23 @@ def plot_feature(spectrogram):
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib
+
        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
-        mpl_logger = logging.getLogger('matplotlib')
+        mpl_logger = logging.getLogger("matplotlib")
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt
    import numpy as np
 
    fig, ax = plt.subplots(figsize=(10, 2))
-    im = ax.imshow(spectrogram, aspect="auto", origin="lower",
-                   interpolation='none')
+    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()
 
    fig.canvas.draw()
-    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
+    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close()
    return data
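
Unrelated to the formatting change: `np.fromstring` is deprecated in modern NumPy, and `tostring_rgb` is deprecated in recent Matplotlib. An equivalent rendering path, offered as a sketch rather than part of this commit, uses `np.frombuffer` with `buffer_rgba`:

```python
import matplotlib

matplotlib.use("Agg")
import matplotlib.pylab as plt
import numpy as np

fig, ax = plt.subplots(figsize=(10, 2))
ax.imshow(np.random.rand(80, 100), aspect="auto", origin="lower", interpolation="none")
fig.canvas.draw()
# frombuffer + buffer_rgba replace the deprecated fromstring/tostring_rgb pair.
data = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
data = data.reshape(fig.canvas.get_width_height()[::-1] + (4,))[..., :3]  # drop alpha
plt.close(fig)
```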
@@ -9,8 +9,7 @@ from typing import Any, Dict, Optional, Tuple
 
 import torch
 import torch.nn as nn
-from torch.cuda.amp import autocast
-
+from generator import VITSGenerator
 from hifigan import (
     HiFiGANMultiPeriodDiscriminator,
     HiFiGANMultiScaleDiscriminator,
@@ -25,9 +24,8 @@ from loss import (
     KLDivergenceLoss,
     MelSpectrogramLoss,
 )
+from torch.cuda.amp import autocast
 from utils import get_segments
-from generator import VITSGenerator
-
 
 AVAILABLE_GENERATERS = {
     "vits_generator": VITSGenerator,
@@ -42,8 +40,7 @@ AVAILABLE_DISCRIMINATORS = {
 
 
 class VITS(nn.Module):
-    """Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`
-    """
+    """Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`"""
 
     def __init__(
         self,
@@ -9,9 +9,8 @@ This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.
 
 """
 
-import math
 import logging
-
+import math
 from typing import Optional, Tuple
 
 import torch