fixed formatting issue

jinzr 2023-12-01 00:08:16 +08:00
parent 8c75259723
commit cf7ad8131d
16 changed files with 144 additions and 79 deletions

View File

@@ -14,7 +14,6 @@ from typing import Optional
 import torch
 import torch.nn.functional as F
 from flow import (
     ConvFlow,
     DilatedDepthSeparableConv,

View File

@@ -180,7 +180,13 @@ def export_model_onnx(
         model_filename,
         verbose=False,
         opset_version=opset_version,
-        input_names=["tokens", "tokens_lens", "noise_scale", "noise_scale_dur", "alpha"],
+        input_names=[
+            "tokens",
+            "tokens_lens",
+            "noise_scale",
+            "noise_scale_dur",
+            "alpha",
+        ],
         output_names=["audio"],
         dynamic_axes={
             "tokens": {0: "N", 1: "T"},

View File

@@ -13,7 +13,6 @@ import math
 from typing import Optional, Tuple, Union
 import torch
 from transform import piecewise_rational_quadratic_transform

View File

@@ -16,9 +16,6 @@ from typing import List, Optional, Tuple
 import numpy as np
 import torch
 import torch.nn.functional as F
-from icefall.utils import make_pad_mask
 from duration_predictor import StochasticDurationPredictor
 from hifigan import HiFiGANGenerator
 from posterior_encoder import PosteriorEncoder
@@ -26,6 +23,8 @@ from residual_coupling import ResidualAffineCouplingBlock
 from text_encoder import TextEncoder
 from utils import get_random_segments
+from icefall.utils import make_pad_mask

 class VITSGenerator(torch.nn.Module):
     """Generator module in VITS, `Conditional Variational Autoencoder

View File

@@ -36,13 +36,12 @@ import k2
 import torch
 import torch.nn as nn
 import torchaudio
-from train import get_model, get_params
 from tokenizer import Tokenizer
+from train import get_model, get_params
+from tts_datamodule import LJSpeechTtsDataModule

 from icefall.checkpoint import load_checkpoint
 from icefall.utils import AttributeDict, setup_logger
-from tts_datamodule import LJSpeechTtsDataModule

 def get_parser():
@@ -107,12 +106,12 @@ def infer_dataset(
         for i in range(batch_size):
             torchaudio.save(
                 str(params.save_wav_dir / f"{cut_ids[i]}_gt.wav"),
-                audio[i:i + 1, :audio_lens[i]],
+                audio[i : i + 1, : audio_lens[i]],
                 sample_rate=params.sampling_rate,
             )
             torchaudio.save(
                 str(params.save_wav_dir / f"{cut_ids[i]}_pred.wav"),
-                audio_pred[i:i + 1, :audio_lens_pred[i]],
+                audio_pred[i : i + 1, : audio_lens_pred[i]],
                 sample_rate=params.sampling_rate,
             )
@@ -144,14 +143,24 @@ def infer_dataset(
             audio_lens = batch["audio_lens"].tolist()
             cut_ids = [cut.id for cut in batch["cut"]]
-            audio_pred, _, durations = model.inference_batch(text=tokens, text_lengths=tokens_lens)
+            audio_pred, _, durations = model.inference_batch(
+                text=tokens, text_lengths=tokens_lens
+            )
             audio_pred = audio_pred.detach().cpu()
             # convert to samples
-            audio_lens_pred = (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist()
+            audio_lens_pred = (
+                (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist()
+            )

             futures.append(
                 executor.submit(
-                    _save_worker, batch_size, cut_ids, audio, audio_pred, audio_lens, audio_lens_pred
+                    _save_worker,
+                    batch_size,
+                    cut_ids,
+                    audio,
+                    audio_pred,
+                    audio_lens,
+                    audio_lens_pred,
                 )
             )
@@ -160,7 +169,9 @@ def infer_dataset(
             if batch_idx % log_interval == 0:
                 batch_str = f"{batch_idx}/{num_batches}"
-                logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+                logging.info(
+                    f"batch {batch_str}, cuts processed until now is {num_cuts}"
+                )

     # return results
     for f in futures:
         f.result()

View File

@@ -14,7 +14,6 @@ from typing import List, Tuple, Union
 import torch
 import torch.distributions as D
 import torch.nn.functional as F
 from lhotse.features.kaldi import Wav2LogFilterBank

View File

@@ -12,9 +12,9 @@ This code is based on https://github.com/jaywalnut310/vits.
 from typing import Optional, Tuple

 import torch
+from wavenet import Conv1d, WaveNet

 from icefall.utils import make_pad_mask
-from wavenet import WaveNet, Conv1d

 class PosteriorEncoder(torch.nn.Module):

View File

@@ -12,7 +12,6 @@ This code is based on https://github.com/jaywalnut310/vits.
 from typing import Optional, Tuple, Union
 import torch
 from flow import FlipFlow
 from wavenet import WaveNet

View File

@@ -28,10 +28,10 @@ Use the onnx model to generate a wav:
 import argparse
 import logging

 import onnxruntime as ort
 import torch
 import torchaudio
 from tokenizer import Tokenizer

View File

@@ -169,9 +169,7 @@ class Transformer(nn.Module):
         x, pos_emb = self.encoder_pos(x)
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
-        x = self.encoder(
-            x, pos_emb, key_padding_mask=key_padding_mask
-        )  # (T, N, C)
+        x = self.encoder(x, pos_emb, key_padding_mask=key_padding_mask)  # (T, N, C)

         x = self.after_norm(x)
@@ -207,7 +205,9 @@ class TransformerEncoderLayer(nn.Module):
             nn.Linear(dim_feedforward, d_model),
         )

-        self.self_attn = RelPositionMultiheadAttention(d_model, num_heads, dropout=dropout)
+        self.self_attn = RelPositionMultiheadAttention(
+            d_model, num_heads, dropout=dropout
+        )

         self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
@@ -242,7 +242,9 @@ class TransformerEncoderLayer(nn.Module):
           key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len)
         """
         # macaron style feed-forward module
-        src = src + self.ff_scale * self.dropout(self.feed_forward_macaron(self.norm_ff_macaron(src)))
+        src = src + self.ff_scale * self.dropout(
+            self.feed_forward_macaron(self.norm_ff_macaron(src))
+        )

         # multi-head self-attention module
         src_attn = self.self_attn(
@@ -490,11 +492,17 @@ class RelPositionMultiheadAttention(nn.Module):
         q = q.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim)
         k = k.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim)
-        v = v.contiguous().view(seq_len, batch_size * self.num_heads, self.head_dim).transpose(0, 1)
+        v = (
+            v.contiguous()
+            .view(seq_len, batch_size * self.num_heads, self.head_dim)
+            .transpose(0, 1)
+        )

         q = q.transpose(0, 1)  # (batch_size, seq_len, num_head, head_dim)

-        p = self.linear_pos(pos_emb).view(pos_emb.size(0), -1, self.num_heads, self.head_dim)
+        p = self.linear_pos(pos_emb).view(
+            pos_emb.size(0), -1, self.num_heads, self.head_dim
+        )
         # (1, 2*seq_len, num_head, head_dim) -> (1, num_head, head_dim, 2*seq_len-1)
         p = p.permute(0, 2, 3, 1)
@@ -506,15 +514,23 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch_size, num_head, head_dim, seq_len)
-        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch_size, num_head, seq_len, seq_len)
+        matrix_ac = torch.matmul(
+            q_with_bias_u, k
+        )  # (batch_size, num_head, seq_len, seq_len)

         # compute matrix b and matrix d
-        matrix_bd = torch.matmul(q_with_bias_v, p)  # (batch_size, num_head, seq_len, 2*seq_len-1)
-        matrix_bd = self.rel_shift(matrix_bd)  # (batch_size, num_head, seq_len, seq_len)
+        matrix_bd = torch.matmul(
+            q_with_bias_v, p
+        )  # (batch_size, num_head, seq_len, 2*seq_len-1)
+        matrix_bd = self.rel_shift(
+            matrix_bd
+        )  # (batch_size, num_head, seq_len, seq_len)

         # (batch_size, num_head, seq_len, seq_len)
         attn_output_weights = (matrix_ac + matrix_bd) * scaling
-        attn_output_weights = attn_output_weights.view(batch_size * self.num_heads, seq_len, seq_len)
+        attn_output_weights = attn_output_weights.view(
+            batch_size * self.num_heads, seq_len, seq_len
+        )

         if key_padding_mask is not None:
             assert key_padding_mask.shape == (batch_size, seq_len)
@@ -536,10 +552,16 @@ class RelPositionMultiheadAttention(nn.Module):
         # (batch_size * num_head, seq_len, head_dim)
         attn_output = torch.bmm(attn_output_weights, v)
-        assert attn_output.shape == (batch_size * self.num_heads, seq_len, self.head_dim)
+        assert attn_output.shape == (
+            batch_size * self.num_heads,
+            seq_len,
+            self.head_dim,
+        )
         attn_output = (
-            attn_output.transpose(0, 1).contiguous().view(seq_len, batch_size, self.embed_dim)
+            attn_output.transpose(0, 1)
+            .contiguous()
+            .view(seq_len, batch_size, self.embed_dim)
         )
         # (seq_len, batch_size, embed_dim)
         attn_output = self.out_proj(attn_output)

View File

@@ -78,7 +78,9 @@ class Tokenizer(object):
         return token_ids_list

-    def tokens_to_token_ids(self, tokens_list: List[str], intersperse_blank: bool = True):
+    def tokens_to_token_ids(
+        self, tokens_list: List[str], intersperse_blank: bool = True
+    ):
         """
         Args:
           tokens_list:

View File

@@ -18,21 +18,25 @@
 import argparse
 import logging
-import numpy as np
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, Optional, Tuple, Union

 import k2
+import numpy as np
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from lhotse.cut import Cut
 from lhotse.utils import fix_random_seed
-from torch.optim import Optimizer
+from tokenizer import Tokenizer
 from torch.cuda.amp import GradScaler, autocast
 from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim import Optimizer
 from torch.utils.tensorboard import SummaryWriter
+from tts_datamodule import LJSpeechTtsDataModule
+from utils import MetricsTracker, plot_feature, save_checkpoint
+from vits import VITS

 from icefall import diagnostics
 from icefall.checkpoint import load_checkpoint
@@ -41,11 +45,6 @@ from icefall.env import get_env_info
 from icefall.hooks import register_inf_check_hooks
 from icefall.utils import AttributeDict, setup_logger, str2bool

-from tokenizer import Tokenizer
-from tts_datamodule import LJSpeechTtsDataModule
-from utils import MetricsTracker, plot_feature, save_checkpoint
-from vits import VITS

 LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
@@ -385,11 +384,12 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["tokens"])
-        audio, audio_lens, features, features_lens, tokens, tokens_lens = \
-            prepare_input(batch, tokenizer, device)
+        audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
+            batch, tokenizer, device
+        )

         loss_info = MetricsTracker()
-        loss_info['samples'] = batch_size
+        loss_info["samples"] = batch_size

         try:
             with autocast(enabled=params.use_fp16):
@@ -446,7 +446,9 @@ def train_one_epoch(
             # behavior depending on the current grad scale.
             cur_grad_scale = scaler._scale.item()
-            if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0):
+            if cur_grad_scale < 8.0 or (
+                cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0
+            ):
                 scaler.update(cur_grad_scale * 2.0)
             if cur_grad_scale < 0.01:
                 if not saved_bad_model:
@@ -482,9 +484,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                 if params.use_fp16:
                     tb_writer.add_scalar(
                         "train/grad_scale", cur_grad_scale, params.batch_idx_train
@@ -492,19 +492,34 @@ def train_one_epoch(
                 if "returned_sample" in stats_g:
                     speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"]
                     tb_writer.add_audio(
-                        "train/speech_hat_", speech_hat_, params.batch_idx_train, params.sampling_rate
+                        "train/speech_hat_",
+                        speech_hat_,
+                        params.batch_idx_train,
+                        params.sampling_rate,
                     )
                     tb_writer.add_audio(
-                        "train/speech_", speech_, params.batch_idx_train, params.sampling_rate
+                        "train/speech_",
+                        speech_,
+                        params.batch_idx_train,
+                        params.sampling_rate,
                     )
                     tb_writer.add_image(
-                        "train/mel_hat_", plot_feature(mel_hat_), params.batch_idx_train, dataformats='HWC'
+                        "train/mel_hat_",
+                        plot_feature(mel_hat_),
+                        params.batch_idx_train,
+                        dataformats="HWC",
                     )
                     tb_writer.add_image(
-                        "train/mel_", plot_feature(mel_), params.batch_idx_train, dataformats='HWC'
+                        "train/mel_",
+                        plot_feature(mel_),
+                        params.batch_idx_train,
+                        dataformats="HWC",
                     )

-        if params.batch_idx_train % params.valid_interval == 0 and not params.print_diagnostics:
+        if (
+            params.batch_idx_train % params.valid_interval == 0
+            and not params.print_diagnostics
+        ):
             logging.info("Computing validation loss")
             valid_info, (speech_hat, speech) = compute_validation_loss(
                 params=params,
@@ -523,10 +538,16 @@ def train_one_epoch(
                     tb_writer, "train/valid_", params.batch_idx_train
                 )
                 tb_writer.add_audio(
-                    "train/valdi_speech_hat", speech_hat, params.batch_idx_train, params.sampling_rate
+                    "train/valdi_speech_hat",
+                    speech_hat,
+                    params.batch_idx_train,
+                    params.sampling_rate,
                 )
                 tb_writer.add_audio(
-                    "train/valdi_speech", speech, params.batch_idx_train, params.sampling_rate
+                    "train/valdi_speech",
+                    speech,
+                    params.batch_idx_train,
+                    params.sampling_rate,
                 )

         loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
@@ -555,11 +576,17 @@ def compute_validation_loss(
     with torch.no_grad():
         for batch_idx, batch in enumerate(valid_dl):
             batch_size = len(batch["tokens"])
-            audio, audio_lens, features, features_lens, tokens, tokens_lens = \
-                prepare_input(batch, tokenizer, device)
+            (
+                audio,
+                audio_lens,
+                features,
+                features_lens,
+                tokens,
+                tokens_lens,
+            ) = prepare_input(batch, tokenizer, device)

             loss_info = MetricsTracker()
-            loss_info['samples'] = batch_size
+            loss_info["samples"] = batch_size

             # forward discriminator
             loss_d, stats_d = model(
@@ -596,12 +623,17 @@ def compute_validation_loss(
             if batch_idx == 0 and rank == 0:
                 inner_model = model.module if isinstance(model, DDP) else model
                 audio_pred, _, duration = inner_model.inference(
-                    text=tokens[0, :tokens_lens[0].item()]
+                    text=tokens[0, : tokens_lens[0].item()]
                 )
                 audio_pred = audio_pred.data.cpu().numpy()
-                audio_len_pred = (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item()
-                assert audio_len_pred == len(audio_pred), (audio_len_pred, len(audio_pred))
-                audio_gt = audio[0, :audio_lens[0].item()].data.cpu().numpy()
+                audio_len_pred = (
+                    (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item()
+                )
+                assert audio_len_pred == len(audio_pred), (
+                    audio_len_pred,
+                    len(audio_pred),
+                )
+                audio_gt = audio[0, : audio_lens[0].item()].data.cpu().numpy()
                 returned_sample = (audio_pred, audio_gt)

     if world_size > 1:
@@ -632,8 +664,9 @@ def scan_pessimistic_batches_for_oom(
     batches, crit_values = find_pessimistic_batches(train_dl.sampler)
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
-        audio, audio_lens, features, features_lens, tokens, tokens_lens = \
-            prepare_input(batch, tokenizer, device)
+        audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
+            batch, tokenizer, device
+        )
         try:
             # for discriminator
             with autocast(enabled=params.use_fp16):

View File

@@ -29,10 +29,10 @@ from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
     CutConcatenate,
     CutMix,
     DynamicBucketingSampler,
-    SpeechSynthesisDataset,
     PrecomputedFeatures,
     SimpleCutSampler,
     SpecAugment,
+    SpeechSynthesisDataset,
 )
 from lhotse.dataset.input_strategies import (  # noqa F401 For AudioSamples
     AudioSamples,

View File

@@ -14,15 +14,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any, Dict, List, Optional, Tuple, Union
 import collections
 import logging
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
-import torch.nn as nn
 import torch.distributed as dist
+import torch.nn as nn
 from lhotse.dataset.sampling.base import CutSampler
-from pathlib import Path
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
@@ -97,23 +97,23 @@ def plot_feature(spectrogram):
     global MATPLOTLIB_FLAG
     if not MATPLOTLIB_FLAG:
         import matplotlib
         matplotlib.use("Agg")
         MATPLOTLIB_FLAG = True
-        mpl_logger = logging.getLogger('matplotlib')
+        mpl_logger = logging.getLogger("matplotlib")
         mpl_logger.setLevel(logging.WARNING)
     import matplotlib.pylab as plt
     import numpy as np

     fig, ax = plt.subplots(figsize=(10, 2))
-    im = ax.imshow(spectrogram, aspect="auto", origin="lower",
-                   interpolation='none')
+    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
     plt.colorbar(im, ax=ax)
     plt.xlabel("Frames")
     plt.ylabel("Channels")
     plt.tight_layout()
     fig.canvas.draw()
-    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
+    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
     plt.close()
     return data

View File

@@ -9,8 +9,7 @@ from typing import Any, Dict, Optional, Tuple
 import torch
 import torch.nn as nn
-from torch.cuda.amp import autocast
+from generator import VITSGenerator
 from hifigan import (
     HiFiGANMultiPeriodDiscriminator,
     HiFiGANMultiScaleDiscriminator,
@@ -25,9 +24,8 @@ from loss import (
     KLDivergenceLoss,
     MelSpectrogramLoss,
 )
+from torch.cuda.amp import autocast
 from utils import get_segments
-from generator import VITSGenerator

 AVAILABLE_GENERATERS = {
     "vits_generator": VITSGenerator,
@@ -42,8 +40,7 @@ AVAILABLE_DISCRIMINATORS = {
 class VITS(nn.Module):
-    """Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`
-    """
+    """Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`"""

     def __init__(
         self,

View File

@@ -9,9 +9,8 @@ This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.
 """

-import math
 import logging
+import math
 from typing import Optional, Tuple

 import torch