Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-06 15:44:17 +00:00

fixed formatting issue

This commit is contained in:
  parent 8c75259723
  commit cf7ad8131d
@@ -14,7 +14,6 @@ from typing import Optional
 
 import torch
 import torch.nn.functional as F
-
 from flow import (
     ConvFlow,
     DilatedDepthSeparableConv,
@@ -180,7 +180,13 @@ def export_model_onnx(
         model_filename,
         verbose=False,
         opset_version=opset_version,
-        input_names=["tokens", "tokens_lens", "noise_scale", "noise_scale_dur", "alpha"],
+        input_names=[
+            "tokens",
+            "tokens_lens",
+            "noise_scale",
+            "noise_scale_dur",
+            "alpha",
+        ],
         output_names=["audio"],
         dynamic_axes={
             "tokens": {0: "N", 1: "T"},
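
Note: the input_names/dynamic_axes arguments reformatted above are passed to a torch.onnx.export call. A minimal sketch of such an export, assuming a wrapper module whose forward() takes exactly these five inputs and returns the waveform (the dummy shapes and the extra dynamic-axis entries below are illustrative assumptions, not part of this commit):

import torch

def export_vits_onnx(model: torch.nn.Module, model_filename: str, opset_version: int = 13) -> None:
    # Dummy example inputs; real shapes come from the tokenizer.
    tokens = torch.zeros(1, 10, dtype=torch.int64)        # (N, T) token ids
    tokens_lens = torch.tensor([10], dtype=torch.int64)   # (N,)
    noise_scale = torch.tensor([0.667], dtype=torch.float32)
    noise_scale_dur = torch.tensor([0.8], dtype=torch.float32)
    alpha = torch.tensor([1.0], dtype=torch.float32)      # speaking-rate control

    torch.onnx.export(
        model,
        (tokens, tokens_lens, noise_scale, noise_scale_dur, alpha),
        model_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=[
            "tokens",
            "tokens_lens",
            "noise_scale",
            "noise_scale_dur",
            "alpha",
        ],
        output_names=["audio"],
        dynamic_axes={
            "tokens": {0: "N", 1: "T"},
            "tokens_lens": {0: "N"},
            "audio": {0: "N", 1: "T"},
        },
    )

The dynamic_axes entry shown in the hunk marks both the batch and token-length dimensions of "tokens" as symbolic, so one exported graph serves any batch size and utterance length.
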
@@ -13,7 +13,6 @@ import math
 from typing import Optional, Tuple, Union
 
 import torch
-
 from transform import piecewise_rational_quadratic_transform
 
 
@@ -16,9 +16,6 @@ from typing import List, Optional, Tuple
 import numpy as np
 import torch
 import torch.nn.functional as F
-
-from icefall.utils import make_pad_mask
-
 from duration_predictor import StochasticDurationPredictor
 from hifigan import HiFiGANGenerator
 from posterior_encoder import PosteriorEncoder
@@ -26,6 +23,8 @@ from residual_coupling import ResidualAffineCouplingBlock
 from text_encoder import TextEncoder
 from utils import get_random_segments
 
+from icefall.utils import make_pad_mask
+
 
 class VITSGenerator(torch.nn.Module):
     """Generator module in VITS, `Conditional Variational Autoencoder
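
The make_pad_mask import moved here (removed above, re-added in this hunk) is the icefall utility that turns a tensor of sequence lengths into a boolean padding mask with True at padded positions. A simplified re-implementation, for illustration only:

import torch

def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    # Behaves like icefall.utils.make_pad_mask: row i is True where the
    # position index is >= lengths[i], i.e. at padding positions.
    max_len = int(lengths.max())
    positions = torch.arange(max_len, device=lengths.device)
    return positions.unsqueeze(0) >= lengths.unsqueeze(1)

print(make_pad_mask(torch.tensor([3, 1])))
# tensor([[False, False, False],
#         [False,  True,  True]])
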
@@ -36,13 +36,12 @@ import k2
 import torch
 import torch.nn as nn
 import torchaudio
-
-from train import get_model, get_params
 from tokenizer import Tokenizer
+from train import get_model, get_params
+from tts_datamodule import LJSpeechTtsDataModule
 
 from icefall.checkpoint import load_checkpoint
 from icefall.utils import AttributeDict, setup_logger
-from tts_datamodule import LJSpeechTtsDataModule
 
 
 def get_parser():
@@ -107,12 +106,12 @@ def infer_dataset(
         for i in range(batch_size):
             torchaudio.save(
                 str(params.save_wav_dir / f"{cut_ids[i]}_gt.wav"),
-                audio[i:i + 1, :audio_lens[i]],
+                audio[i : i + 1, : audio_lens[i]],
                 sample_rate=params.sampling_rate,
             )
             torchaudio.save(
                 str(params.save_wav_dir / f"{cut_ids[i]}_pred.wav"),
-                audio_pred[i:i + 1, :audio_lens_pred[i]],
+                audio_pred[i : i + 1, : audio_lens_pred[i]],
                 sample_rate=params.sampling_rate,
             )
 
@@ -144,14 +143,24 @@ def infer_dataset(
             audio_lens = batch["audio_lens"].tolist()
             cut_ids = [cut.id for cut in batch["cut"]]
 
-            audio_pred, _, durations = model.inference_batch(text=tokens, text_lengths=tokens_lens)
+            audio_pred, _, durations = model.inference_batch(
+                text=tokens, text_lengths=tokens_lens
+            )
             audio_pred = audio_pred.detach().cpu()
             # convert to samples
-            audio_lens_pred = (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist()
+            audio_lens_pred = (
+                (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist()
+            )
 
             futures.append(
                 executor.submit(
-                    _save_worker, batch_size, cut_ids, audio, audio_pred, audio_lens, audio_lens_pred
+                    _save_worker,
+                    batch_size,
+                    cut_ids,
+                    audio,
+                    audio_pred,
+                    audio_lens,
+                    audio_lens_pred,
                 )
             )
 
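
The audio_lens_pred expression reformatted above converts predicted per-token frame counts into waveform sample counts by multiplying by params.frame_shift (the hop size in samples). A toy illustration with made-up numbers:

import torch

durations = torch.tensor([[2.0, 3.0, 1.0]])  # predicted frames per token, shape (N, T)
frame_shift = 256                             # assumed hop size in samples
audio_lens_pred = (durations.sum(1) * frame_shift).to(dtype=torch.int64).tolist()
print(audio_lens_pred)  # [1536] == 6 frames * 256 samples per frame
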
@@ -160,7 +169,9 @@ def infer_dataset(
             if batch_idx % log_interval == 0:
                 batch_str = f"{batch_idx}/{num_batches}"
 
-                logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+                logging.info(
+                    f"batch {batch_str}, cuts processed until now is {num_cuts}"
+                )
         # return results
         for f in futures:
             f.result()
@@ -14,7 +14,6 @@ from typing import List, Tuple, Union
 import torch
 import torch.distributions as D
 import torch.nn.functional as F
-
 from lhotse.features.kaldi import Wav2LogFilterBank
 
 
@@ -12,9 +12,9 @@ This code is based on https://github.com/jaywalnut310/vits.
 from typing import Optional, Tuple
 
 import torch
+from wavenet import Conv1d, WaveNet
 
 from icefall.utils import make_pad_mask
-from wavenet import WaveNet, Conv1d
 
 
 class PosteriorEncoder(torch.nn.Module):
@@ -12,7 +12,6 @@ This code is based on https://github.com/jaywalnut310/vits.
 from typing import Optional, Tuple, Union
 
 import torch
-
 from flow import FlipFlow
 from wavenet import WaveNet
 
@@ -28,10 +28,10 @@ Use the onnx model to generate a wav:
 
 import argparse
 import logging
+
 import onnxruntime as ort
 import torch
 import torchaudio
-
 from tokenizer import Tokenizer
 
 
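
This hunk comes from the script that runs the exported ONNX model. A minimal onnxruntime sketch using the input/output names from the export hunk above (the model path, token ids, scale values, and the 22050 Hz sampling rate are placeholder assumptions, not taken from this diff):

import numpy as np
import onnxruntime as ort
import torch
import torchaudio

session = ort.InferenceSession("vits-epoch-1000.onnx")  # hypothetical filename

tokens = np.array([[1, 5, 9, 3]], dtype=np.int64)           # (N, T) token ids
tokens_lens = np.array([tokens.shape[1]], dtype=np.int64)   # (N,)

(audio,) = session.run(
    ["audio"],
    {
        "tokens": tokens,
        "tokens_lens": tokens_lens,
        "noise_scale": np.array([0.667], dtype=np.float32),
        "noise_scale_dur": np.array([0.8], dtype=np.float32),
        "alpha": np.array([1.0], dtype=np.float32),
    },
)

torchaudio.save("test_onnx.wav", torch.from_numpy(audio), sample_rate=22050)
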
@@ -169,9 +169,7 @@ class Transformer(nn.Module):
         x, pos_emb = self.encoder_pos(x)
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
 
-        x = self.encoder(
-            x, pos_emb, key_padding_mask=key_padding_mask
-        )  # (T, N, C)
+        x = self.encoder(x, pos_emb, key_padding_mask=key_padding_mask)  # (T, N, C)
 
         x = self.after_norm(x)
 
@@ -207,7 +205,9 @@ class TransformerEncoderLayer(nn.Module):
             nn.Linear(dim_feedforward, d_model),
         )
 
-        self.self_attn = RelPositionMultiheadAttention(d_model, num_heads, dropout=dropout)
+        self.self_attn = RelPositionMultiheadAttention(
+            d_model, num_heads, dropout=dropout
+        )
 
         self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
 
@@ -242,7 +242,9 @@ class TransformerEncoderLayer(nn.Module):
           key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len)
         """
         # macaron style feed-forward module
-        src = src + self.ff_scale * self.dropout(self.feed_forward_macaron(self.norm_ff_macaron(src)))
+        src = src + self.ff_scale * self.dropout(
+            self.feed_forward_macaron(self.norm_ff_macaron(src))
+        )
 
         # multi-head self-attention module
         src_attn = self.self_attn(
@@ -490,11 +492,17 @@ class RelPositionMultiheadAttention(nn.Module):
 
         q = q.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim)
         k = k.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim)
-        v = v.contiguous().view(seq_len, batch_size * self.num_heads, self.head_dim).transpose(0, 1)
+        v = (
+            v.contiguous()
+            .view(seq_len, batch_size * self.num_heads, self.head_dim)
+            .transpose(0, 1)
+        )
 
         q = q.transpose(0, 1)  # (batch_size, seq_len, num_head, head_dim)
 
-        p = self.linear_pos(pos_emb).view(pos_emb.size(0), -1, self.num_heads, self.head_dim)
+        p = self.linear_pos(pos_emb).view(
+            pos_emb.size(0), -1, self.num_heads, self.head_dim
+        )
         # (1, 2*seq_len, num_head, head_dim) -> (1, num_head, head_dim, 2*seq_len-1)
         p = p.permute(0, 2, 3, 1)
 
@@ -506,15 +514,23 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch_size, num_head, head_dim, seq_len)
-        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch_size, num_head, seq_len, seq_len)
+        matrix_ac = torch.matmul(
+            q_with_bias_u, k
+        )  # (batch_size, num_head, seq_len, seq_len)
 
         # compute matrix b and matrix d
-        matrix_bd = torch.matmul(q_with_bias_v, p)  # (batch_size, num_head, seq_len, 2*seq_len-1)
-        matrix_bd = self.rel_shift(matrix_bd)  # (batch_size, num_head, seq_len, seq_len)
+        matrix_bd = torch.matmul(
+            q_with_bias_v, p
+        )  # (batch_size, num_head, seq_len, 2*seq_len-1)
+        matrix_bd = self.rel_shift(
+            matrix_bd
+        )  # (batch_size, num_head, seq_len, seq_len)
 
         # (batch_size, num_head, seq_len, seq_len)
         attn_output_weights = (matrix_ac + matrix_bd) * scaling
-        attn_output_weights = attn_output_weights.view(batch_size * self.num_heads, seq_len, seq_len)
+        attn_output_weights = attn_output_weights.view(
+            batch_size * self.num_heads, seq_len, seq_len
+        )
 
         if key_padding_mask is not None:
             assert key_padding_mask.shape == (batch_size, seq_len)
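
For readers following the matrix_ac/matrix_bd terms above: they are the content-based and position-based attention scores of Transformer-XL, and rel_shift realigns the position scores from relative offsets to absolute key positions. A standalone sketch of the usual shift (the standard Transformer-XL-style formulation, not copied from this file; it assumes the relative axis runs from +(t-1) down to -(t-1)):

import torch

def rel_shift(x: torch.Tensor) -> torch.Tensor:
    # x: (batch, heads, t, 2*t - 1) relative scores -> (batch, heads, t, t)
    b, h, t, n = x.shape  # n == 2 * t - 1
    zero_pad = torch.zeros((b, h, t, 1), dtype=x.dtype, device=x.device)
    x_padded = torch.cat([zero_pad, x], dim=-1)   # (b, h, t, 2*t)
    x_padded = x_padded.view(b, h, 2 * t, t)      # rows slide by one column per step
    return x_padded[:, :, 1:].reshape(b, h, t, n)[:, :, :, :t]
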
@@ -536,10 +552,16 @@ class RelPositionMultiheadAttention(nn.Module):
 
         # (batch_size * num_head, seq_len, head_dim)
         attn_output = torch.bmm(attn_output_weights, v)
-        assert attn_output.shape == (batch_size * self.num_heads, seq_len, self.head_dim)
+        assert attn_output.shape == (
+            batch_size * self.num_heads,
+            seq_len,
+            self.head_dim,
+        )
 
         attn_output = (
-            attn_output.transpose(0, 1).contiguous().view(seq_len, batch_size, self.embed_dim)
+            attn_output.transpose(0, 1)
+            .contiguous()
+            .view(seq_len, batch_size, self.embed_dim)
         )
         # (seq_len, batch_size, embed_dim)
         attn_output = self.out_proj(attn_output)
@@ -78,7 +78,9 @@ class Tokenizer(object):
 
         return token_ids_list
 
-    def tokens_to_token_ids(self, tokens_list: List[str], intersperse_blank: bool = True):
+    def tokens_to_token_ids(
+        self, tokens_list: List[str], intersperse_blank: bool = True
+    ):
         """
         Args:
           tokens_list:
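
The intersperse_blank flag in the reformatted signature refers to the usual VITS preprocessing step of inserting a blank id between consecutive token ids. A standalone sketch (the blank id of 0 is an assumption for illustration):

from typing import List

def intersperse(token_ids: List[int], blank_id: int = 0) -> List[int]:
    # [a, b, c] -> [blank, a, blank, b, blank, c, blank]
    result = [blank_id] * (2 * len(token_ids) + 1)
    result[1::2] = token_ids
    return result

print(intersperse([5, 9, 3]))  # [0, 5, 0, 9, 0, 3, 0]
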
@@ -18,21 +18,25 @@
 
 import argparse
 import logging
-import numpy as np
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, Optional, Tuple, Union
 
 import k2
+import numpy as np
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from lhotse.cut import Cut
 from lhotse.utils import fix_random_seed
-from torch.optim import Optimizer
+from tokenizer import Tokenizer
 from torch.cuda.amp import GradScaler, autocast
 from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim import Optimizer
 from torch.utils.tensorboard import SummaryWriter
+from tts_datamodule import LJSpeechTtsDataModule
+from utils import MetricsTracker, plot_feature, save_checkpoint
+from vits import VITS
 
 from icefall import diagnostics
 from icefall.checkpoint import load_checkpoint
@@ -41,11 +45,6 @@ from icefall.env import get_env_info
 from icefall.hooks import register_inf_check_hooks
 from icefall.utils import AttributeDict, setup_logger, str2bool
 
-from tokenizer import Tokenizer
-from tts_datamodule import LJSpeechTtsDataModule
-from utils import MetricsTracker, plot_feature, save_checkpoint
-from vits import VITS
-
 LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
 
 
@@ -385,11 +384,12 @@ def train_one_epoch(
         params.batch_idx_train += 1
 
         batch_size = len(batch["tokens"])
-        audio, audio_lens, features, features_lens, tokens, tokens_lens = \
-            prepare_input(batch, tokenizer, device)
+        audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
+            batch, tokenizer, device
+        )
 
         loss_info = MetricsTracker()
-        loss_info['samples'] = batch_size
+        loss_info["samples"] = batch_size
 
         try:
             with autocast(enabled=params.use_fp16):
@@ -446,7 +446,9 @@ def train_one_epoch(
             # behavior depending on the current grad scale.
             cur_grad_scale = scaler._scale.item()
 
-            if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0):
+            if cur_grad_scale < 8.0 or (
+                cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0
+            ):
                 scaler.update(cur_grad_scale * 2.0)
             if cur_grad_scale < 0.01:
                 if not saved_bad_model:
|||||||
loss_info.write_summary(
|
loss_info.write_summary(
|
||||||
tb_writer, "train/current_", params.batch_idx_train
|
tb_writer, "train/current_", params.batch_idx_train
|
||||||
)
|
)
|
||||||
tot_loss.write_summary(
|
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
|
||||||
tb_writer, "train/tot_", params.batch_idx_train
|
|
||||||
)
|
|
||||||
if params.use_fp16:
|
if params.use_fp16:
|
||||||
tb_writer.add_scalar(
|
tb_writer.add_scalar(
|
||||||
"train/grad_scale", cur_grad_scale, params.batch_idx_train
|
"train/grad_scale", cur_grad_scale, params.batch_idx_train
|
||||||
@@ -492,19 +492,34 @@ def train_one_epoch(
                 if "returned_sample" in stats_g:
                     speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"]
                     tb_writer.add_audio(
-                        "train/speech_hat_", speech_hat_, params.batch_idx_train, params.sampling_rate
+                        "train/speech_hat_",
+                        speech_hat_,
+                        params.batch_idx_train,
+                        params.sampling_rate,
                     )
                     tb_writer.add_audio(
-                        "train/speech_", speech_, params.batch_idx_train, params.sampling_rate
+                        "train/speech_",
+                        speech_,
+                        params.batch_idx_train,
+                        params.sampling_rate,
                     )
                     tb_writer.add_image(
-                        "train/mel_hat_", plot_feature(mel_hat_), params.batch_idx_train, dataformats='HWC'
+                        "train/mel_hat_",
+                        plot_feature(mel_hat_),
+                        params.batch_idx_train,
+                        dataformats="HWC",
                    )
                     tb_writer.add_image(
-                        "train/mel_", plot_feature(mel_), params.batch_idx_train, dataformats='HWC'
+                        "train/mel_",
+                        plot_feature(mel_),
+                        params.batch_idx_train,
+                        dataformats="HWC",
                     )
 
-        if params.batch_idx_train % params.valid_interval == 0 and not params.print_diagnostics:
+        if (
+            params.batch_idx_train % params.valid_interval == 0
+            and not params.print_diagnostics
+        ):
             logging.info("Computing validation loss")
             valid_info, (speech_hat, speech) = compute_validation_loss(
                 params=params,
|
|||||||
tb_writer, "train/valid_", params.batch_idx_train
|
tb_writer, "train/valid_", params.batch_idx_train
|
||||||
)
|
)
|
||||||
tb_writer.add_audio(
|
tb_writer.add_audio(
|
||||||
"train/valdi_speech_hat", speech_hat, params.batch_idx_train, params.sampling_rate
|
"train/valdi_speech_hat",
|
||||||
|
speech_hat,
|
||||||
|
params.batch_idx_train,
|
||||||
|
params.sampling_rate,
|
||||||
)
|
)
|
||||||
tb_writer.add_audio(
|
tb_writer.add_audio(
|
||||||
"train/valdi_speech", speech, params.batch_idx_train, params.sampling_rate
|
"train/valdi_speech",
|
||||||
|
speech,
|
||||||
|
params.batch_idx_train,
|
||||||
|
params.sampling_rate,
|
||||||
)
|
)
|
||||||
|
|
||||||
loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
|
loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
|
||||||
@ -555,11 +576,17 @@ def compute_validation_loss(
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for batch_idx, batch in enumerate(valid_dl):
|
for batch_idx, batch in enumerate(valid_dl):
|
||||||
batch_size = len(batch["tokens"])
|
batch_size = len(batch["tokens"])
|
||||||
audio, audio_lens, features, features_lens, tokens, tokens_lens = \
|
(
|
||||||
prepare_input(batch, tokenizer, device)
|
audio,
|
||||||
|
audio_lens,
|
||||||
|
features,
|
||||||
|
features_lens,
|
||||||
|
tokens,
|
||||||
|
tokens_lens,
|
||||||
|
) = prepare_input(batch, tokenizer, device)
|
||||||
|
|
||||||
loss_info = MetricsTracker()
|
loss_info = MetricsTracker()
|
||||||
loss_info['samples'] = batch_size
|
loss_info["samples"] = batch_size
|
||||||
|
|
||||||
# forward discriminator
|
# forward discriminator
|
||||||
loss_d, stats_d = model(
|
loss_d, stats_d = model(
|
||||||
@ -596,12 +623,17 @@ def compute_validation_loss(
|
|||||||
if batch_idx == 0 and rank == 0:
|
if batch_idx == 0 and rank == 0:
|
||||||
inner_model = model.module if isinstance(model, DDP) else model
|
inner_model = model.module if isinstance(model, DDP) else model
|
||||||
audio_pred, _, duration = inner_model.inference(
|
audio_pred, _, duration = inner_model.inference(
|
||||||
text=tokens[0, :tokens_lens[0].item()]
|
text=tokens[0, : tokens_lens[0].item()]
|
||||||
)
|
)
|
||||||
audio_pred = audio_pred.data.cpu().numpy()
|
audio_pred = audio_pred.data.cpu().numpy()
|
||||||
audio_len_pred = (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item()
|
audio_len_pred = (
|
||||||
assert audio_len_pred == len(audio_pred), (audio_len_pred, len(audio_pred))
|
(duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item()
|
||||||
audio_gt = audio[0, :audio_lens[0].item()].data.cpu().numpy()
|
)
|
||||||
|
assert audio_len_pred == len(audio_pred), (
|
||||||
|
audio_len_pred,
|
||||||
|
len(audio_pred),
|
||||||
|
)
|
||||||
|
audio_gt = audio[0, : audio_lens[0].item()].data.cpu().numpy()
|
||||||
returned_sample = (audio_pred, audio_gt)
|
returned_sample = (audio_pred, audio_gt)
|
||||||
|
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
@ -632,8 +664,9 @@ def scan_pessimistic_batches_for_oom(
|
|||||||
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
|
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
|
||||||
for criterion, cuts in batches.items():
|
for criterion, cuts in batches.items():
|
||||||
batch = train_dl.dataset[cuts]
|
batch = train_dl.dataset[cuts]
|
||||||
audio, audio_lens, features, features_lens, tokens, tokens_lens = \
|
audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
|
||||||
prepare_input(batch, tokenizer, device)
|
batch, tokenizer, device
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
# for discriminator
|
# for discriminator
|
||||||
with autocast(enabled=params.use_fp16):
|
with autocast(enabled=params.use_fp16):
|
||||||
|
@@ -29,10 +29,10 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
     CutConcatenate,
     CutMix,
     DynamicBucketingSampler,
-    SpeechSynthesisDataset,
     PrecomputedFeatures,
     SimpleCutSampler,
     SpecAugment,
+    SpeechSynthesisDataset,
 )
 from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
     AudioSamples,
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
||||||
import collections
|
import collections
|
||||||
import logging
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
import torch.nn as nn
|
||||||
from lhotse.dataset.sampling.base import CutSampler
|
from lhotse.dataset.sampling.base import CutSampler
|
||||||
from pathlib import Path
|
|
||||||
from torch.cuda.amp import GradScaler
|
from torch.cuda.amp import GradScaler
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
@ -97,23 +97,23 @@ def plot_feature(spectrogram):
|
|||||||
global MATPLOTLIB_FLAG
|
global MATPLOTLIB_FLAG
|
||||||
if not MATPLOTLIB_FLAG:
|
if not MATPLOTLIB_FLAG:
|
||||||
import matplotlib
|
import matplotlib
|
||||||
|
|
||||||
matplotlib.use("Agg")
|
matplotlib.use("Agg")
|
||||||
MATPLOTLIB_FLAG = True
|
MATPLOTLIB_FLAG = True
|
||||||
mpl_logger = logging.getLogger('matplotlib')
|
mpl_logger = logging.getLogger("matplotlib")
|
||||||
mpl_logger.setLevel(logging.WARNING)
|
mpl_logger.setLevel(logging.WARNING)
|
||||||
import matplotlib.pylab as plt
|
import matplotlib.pylab as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
fig, ax = plt.subplots(figsize=(10, 2))
|
fig, ax = plt.subplots(figsize=(10, 2))
|
||||||
im = ax.imshow(spectrogram, aspect="auto", origin="lower",
|
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
|
||||||
interpolation='none')
|
|
||||||
plt.colorbar(im, ax=ax)
|
plt.colorbar(im, ax=ax)
|
||||||
plt.xlabel("Frames")
|
plt.xlabel("Frames")
|
||||||
plt.ylabel("Channels")
|
plt.ylabel("Channels")
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
|
|
||||||
fig.canvas.draw()
|
fig.canvas.draw()
|
||||||
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
|
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
|
||||||
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
||||||
plt.close()
|
plt.close()
|
||||||
return data
|
return data
|
||||||
|
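
One aside on the plot_feature hunk: np.fromstring is deprecated for binary input in recent NumPy releases, and newer Matplotlib versions also retire tostring_rgb. A hedged sketch of an equivalent figure-to-array conversion (not part of this commit; assumes the Agg backend):

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots(figsize=(10, 2))
ax.imshow(np.random.rand(80, 200), aspect="auto", origin="lower", interpolation="none")
fig.canvas.draw()
data = np.asarray(fig.canvas.buffer_rgba())[..., :3]  # (H, W, 3) uint8, no deprecated calls
plt.close(fig)
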
@@ -9,8 +9,7 @@ from typing import Any, Dict, Optional, Tuple
 
 import torch
 import torch.nn as nn
-from torch.cuda.amp import autocast
-
+from generator import VITSGenerator
 from hifigan import (
     HiFiGANMultiPeriodDiscriminator,
     HiFiGANMultiScaleDiscriminator,
@@ -25,9 +24,8 @@ from loss import (
     KLDivergenceLoss,
     MelSpectrogramLoss,
 )
+from torch.cuda.amp import autocast
 from utils import get_segments
-from generator import VITSGenerator
-
 
 AVAILABLE_GENERATERS = {
     "vits_generator": VITSGenerator,
@@ -42,8 +40,7 @@ AVAILABLE_DISCRIMINATORS = {
 
 
 class VITS(nn.Module):
-    """Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`
-    """
+    """Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`"""
 
     def __init__(
         self,
@@ -9,9 +9,8 @@ This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.
 
 """
 
-import math
 import logging
-
+import math
 from typing import Optional, Tuple
 
 import torch