mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 10:16:14 +00:00)
fix style issues

This commit is contained in:
parent 3a986335d7
commit a531c92711
.github/scripts/ljspeech/TTS/run-matcha.sh (vendored) | 1 line changed
@@ -101,7 +101,6 @@ function export_onnx() {
 
   ls -lh *.onnx
 
-
   python3 ./matcha/onnx_pretrained.py \
     --acoustic-model ./model-steps-6.onnx \
     --vocoder ./hifigan_v1.onnx \
.github/workflows/ljspeech.yml (vendored) | 6 lines changed
@@ -84,12 +84,6 @@ jobs:
           name: generated-test-files-${{ matrix.python-version }}-${{ matrix.torch-version }}
           path: ./*.wav
 
-      - uses: actions/upload-artifact@v4
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
-        with:
-          name: generated-models-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}
-          path: ./*.wav
-
       - name: Release exported onnx models
         if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
         uses: svenstaro/upload-release-action@v2
@@ -7,13 +7,18 @@ import torch
 class Denoiser(torch.nn.Module):
     """Removes model bias from audio produced with waveglow"""
 
-    def __init__(self, vocoder, filter_length=1024, n_overlap=4, win_length=1024, mode="zeros"):
+    def __init__(
+        self, vocoder, filter_length=1024, n_overlap=4, win_length=1024, mode="zeros"
+    ):
         super().__init__()
         self.filter_length = filter_length
         self.hop_length = int(filter_length / n_overlap)
         self.win_length = win_length
 
-        dtype, device = next(vocoder.parameters()).dtype, next(vocoder.parameters()).device
+        dtype, device = (
+            next(vocoder.parameters()).dtype,
+            next(vocoder.parameters()).device,
+        )
         self.device = device
         if mode == "zeros":
             mel_input = torch.zeros((1, 80, 88), dtype=dtype, device=device)
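For context on the hunk above: the `mode == "zeros"` branch synthesizes an all-zero mel so the denoiser can measure the vocoder's bias spectrum; generated audio is later cleaned by subtracting a scaled copy of that spectrum. A minimal sketch of the subtraction step, assuming magnitude spectra and an illustrative strength value (the names below are hypothetical, not from this diff):

    import torch

    def denoise_sketch(audio_mag: torch.Tensor, bias_mag: torch.Tensor,
                       strength: float = 0.0005) -> torch.Tensor:
        # Subtract the scaled bias magnitude and clamp at zero; the phase of
        # the input audio is kept unchanged when resynthesizing.
        return torch.clamp(audio_mag - strength * bias_mag, min=0.0)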
@@ -32,7 +37,9 @@ class Denoiser(torch.nn.Module):
                 return_complex=True,
             )
             spec = torch.view_as_real(spec)
-            return torch.sqrt(spec.pow(2).sum(-1)), torch.atan2(spec[..., -1], spec[..., 0])
+            return torch.sqrt(spec.pow(2).sum(-1)), torch.atan2(
+                spec[..., -1], spec[..., 0]
+            )
 
         self.stft = lambda x: stft_fn(
             audio=x,
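The rewrapped `return` above splits a complex STFT into magnitude and phase; a self-contained sketch of the same decomposition (the input tensor is illustrative):

    import torch

    z = torch.fft.fft(torch.randn(8))                 # any complex tensor
    spec = torch.view_as_real(z)                      # (..., 2): real and imaginary parts
    mag = torch.sqrt(spec.pow(2).sum(-1))             # |z| = sqrt(re**2 + im**2)
    phase = torch.atan2(spec[..., -1], spec[..., 0])  # angle(z) = atan2(im, re)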
@@ -49,7 +49,9 @@ mel_basis = {}
 hann_window = {}
 
 
-def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
+def mel_spectrogram(
+    y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
+):
     if torch.min(y) < -1.0:
         print("min value is ", torch.min(y))
     if torch.max(y) > 1.0:
@@ -58,11 +60,15 @@ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin,
     global mel_basis, hann_window  # pylint: disable=global-statement
     if fmax not in mel_basis:
         mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
-        mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
+        mel_basis[str(fmax) + "_" + str(y.device)] = (
+            torch.from_numpy(mel).float().to(y.device)
+        )
         hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
 
     y = torch.nn.functional.pad(
-        y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
+        y.unsqueeze(1),
+        (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
+        mode="reflect",
     )
     y = y.squeeze(1)
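The reflection padding rewrapped above is what keeps the frame count at `len(y) // hop_size` when `center=False`; a small worked check under assumed values (n_fft, hop_size, and the signal length are illustrative, not from this diff):

    # torch.stft with center=False yields 1 + (L - n_fft) // hop_size frames.
    n_fft, hop_size = 1024, 256
    length = 8192                                    # assumed signal length
    padded = length + 2 * ((n_fft - hop_size) // 2)  # 8960 after reflect-padding
    frames = 1 + (padded - n_fft) // hop_size        # 32
    assert frames == length // hop_size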
@@ -92,12 +98,16 @@ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin,
 def get_dataset_filelist(a):
     with open(a.input_training_file, encoding="utf-8") as fi:
         training_files = [
-            os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
+            os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav")
+            for x in fi.read().split("\n")
+            if len(x) > 0
         ]
 
     with open(a.input_validation_file, encoding="utf-8") as fi:
         validation_files = [
-            os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
+            os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav")
+            for x in fi.read().split("\n")
+            if len(x) > 0
         ]
     return training_files, validation_files
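The comprehensions rewrapped above assume LJSpeech-style metadata where each non-empty line reads `<basename>|<transcript>` and only the basename is kept; a tiny illustration (the sample lines are hypothetical):

    import os

    lines = ["LJ001-0001|Printing, in the only sense considered here", ""]
    wavs = [
        os.path.join("wavs", x.split("|")[0] + ".wav")
        for x in lines
        if len(x) > 0
    ]
    assert wavs == [os.path.join("wavs", "LJ001-0001.wav")]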
@@ -152,7 +162,9 @@ class MelDataset(torch.utils.data.Dataset):
                 audio = normalize(audio) * 0.95
             self.cached_wav = audio
             if sampling_rate != self.sampling_rate:
-                raise ValueError(f"{sampling_rate} SR doesn't match target {self.sampling_rate} SR")
+                raise ValueError(
+                    f"{sampling_rate} SR doesn't match target {self.sampling_rate} SR"
+                )
             self._cache_ref_count = self.n_cache_reuse
         else:
             audio = self.cached_wav
@@ -168,7 +180,9 @@ class MelDataset(torch.utils.data.Dataset):
                 audio_start = random.randint(0, max_audio_start)
                 audio = audio[:, audio_start : audio_start + self.segment_size]
             else:
-                audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant")
+                audio = torch.nn.functional.pad(
+                    audio, (0, self.segment_size - audio.size(1)), "constant"
+                )
 
             mel = mel_spectrogram(
                 audio,
@@ -182,7 +196,12 @@ class MelDataset(torch.utils.data.Dataset):
                 center=False,
             )
         else:
-            mel = np.load(os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + ".npy"))
+            mel = np.load(
+                os.path.join(
+                    self.base_mels_path,
+                    os.path.splitext(os.path.split(filename)[-1])[0] + ".npy",
+                )
+            )
             mel = torch.from_numpy(mel)
 
             if len(mel.shape) < 3:
@@ -194,10 +213,19 @@ class MelDataset(torch.utils.data.Dataset):
             if audio.size(1) >= self.segment_size:
                 mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
                 mel = mel[:, :, mel_start : mel_start + frames_per_seg]
-                audio = audio[:, mel_start * self.hop_size : (mel_start + frames_per_seg) * self.hop_size]
+                audio = audio[
+                    :,
+                    mel_start
+                    * self.hop_size : (mel_start + frames_per_seg)
+                    * self.hop_size,
+                ]
             else:
-                mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), "constant")
-                audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant")
+                mel = torch.nn.functional.pad(
+                    mel, (0, frames_per_seg - mel.size(2)), "constant"
+                )
+                audio = torch.nn.functional.pad(
+                    audio, (0, self.segment_size - audio.size(1)), "constant"
+                )
 
             mel_loss = mel_spectrogram(
                 audio,
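The reindented slice above keeps the audio crop aligned with the mel crop: one mel frame spans `hop_size` samples, so `frames_per_seg` frames span exactly `segment_size` samples. A worked check under assumed config values (not taken from this diff):

    segment_size = 8192   # assumed config value
    hop_size = 256        # assumed config value
    frames_per_seg = segment_size // hop_size            # 32 mel frames

    mel_start = 10        # e.g. a randomly drawn frame index
    audio_start = mel_start * hop_size                   # 2560
    audio_end = (mel_start + frames_per_seg) * hop_size  # 10752
    assert audio_end - audio_start == segment_size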
@@ -151,7 +151,9 @@ class Generator(torch.nn.Module):
         self.h = h
         self.num_kernels = len(h.resblock_kernel_sizes)
         self.num_upsamples = len(h.upsample_rates)
-        self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3))
+        self.conv_pre = weight_norm(
+            Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)
+        )
         resblock = ResBlock1 if h.resblock == "1" else ResBlock2
 
         self.ups = nn.ModuleList()
@@ -171,7 +173,9 @@ class Generator(torch.nn.Module):
         self.resblocks = nn.ModuleList()
         for i in range(len(self.ups)):
             ch = h.upsample_initial_channel // (2 ** (i + 1))
-            for _, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
+            for _, (k, d) in enumerate(
+                zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
+            ):
                 self.resblocks.append(resblock(h, ch, k, d))
 
         self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
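In the loop rewrapped above, `ch` halves at every upsampling stage; with the standard HiFi-GAN v1 settings (assumed here, not shown in this diff) the per-stage channel counts work out as follows:

    upsample_initial_channel = 512   # assumed HiFi-GAN v1 config
    upsample_rates = [8, 8, 2, 2]    # assumed HiFi-GAN v1 config
    channels = [
        upsample_initial_channel // (2 ** (i + 1))
        for i in range(len(upsample_rates))
    ]
    assert channels == [256, 128, 64, 32]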
@@ -213,10 +217,42 @@ class DiscriminatorP(torch.nn.Module):
         norm_f = weight_norm if use_spectral_norm is False else spectral_norm
         self.convs = nn.ModuleList(
             [
-                norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
-                norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
-                norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
-                norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+                norm_f(
+                    Conv2d(
+                        1,
+                        32,
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(5, 1), 0),
+                    )
+                ),
+                norm_f(
+                    Conv2d(
+                        32,
+                        128,
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(5, 1), 0),
+                    )
+                ),
+                norm_f(
+                    Conv2d(
+                        128,
+                        512,
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(5, 1), 0),
+                    )
+                ),
+                norm_f(
+                    Conv2d(
+                        512,
+                        1024,
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(5, 1), 0),
+                    )
+                ),
                 norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
             ]
         )
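Each `padding=(get_padding(5, 1), 0)` above comes from HiFi-GAN's usual "same"-padding helper; a sketch of it (the definition below matches HiFi-GAN's and is assumed unchanged by this diff):

    def get_padding(kernel_size: int, dilation: int = 1) -> int:
        # "Same" padding for a dilated 1-D kernel.
        return int((kernel_size * dilation - dilation) / 2)

    assert get_padding(5, 1) == 2   # matches the (2, 0) padding of the last conv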
@@ -313,7 +349,9 @@ class MultiScaleDiscriminator(torch.nn.Module):
                 DiscriminatorS(),
             ]
         )
-        self.meanpools = nn.ModuleList([AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)])
+        self.meanpools = nn.ModuleList(
+            [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)]
+        )
 
     def forward(self, y, y_hat):
         y_d_rs = []