mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 10:16:14 +00:00
fix style issues
This commit is contained in:
parent 3a986335d7
commit a531c92711
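The hunks below are pure formatting fixes: long lines are wrapped and duplicated CI steps removed, in the style that black produces. A minimal sketch of how such a style pass could be reproduced locally — the formatter choice and target path are assumptions, not recorded in this commit:

# Hypothetical reproduction of the style pass; "black" and the
# target directory are assumptions, not taken from this commit.
pip install black
black egs/ljspeech/TTS/matcha/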
13 .github/scripts/ljspeech/TTS/run-matcha.sh (vendored)
@@ -101,13 +101,12 @@ function export_onnx() {
 
   ls -lh *.onnx
 
   python3 ./matcha/onnx_pretrained.py \
     --acoustic-model ./model-steps-6.onnx \
     --vocoder ./hifigan_v1.onnx \
     --tokens ./data/tokens.txt \
     --input-text "how are you doing?" \
     --output-wav /icefall/generated-matcha-tts-steps-6-v1.wav
-
 
   ls -lh /icefall/*.wav
   soxi /icefall/generated-matcha-tts-steps-6-v1.wav
6 .github/workflows/ljspeech.yml (vendored)
@@ -84,12 +84,6 @@ jobs:
           name: generated-test-files-${{ matrix.python-version }}-${{ matrix.torch-version }}
           path: ./*.wav
 
-      - uses: actions/upload-artifact@v4
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
-        with:
-          name: generated-models-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}
-          path: ./*.wav
-
       - name: Release exported onnx models
         if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
         uses: svenstaro/upload-release-action@v2
@@ -7,13 +7,18 @@ import torch
 class Denoiser(torch.nn.Module):
     """Removes model bias from audio produced with waveglow"""
 
-    def __init__(self, vocoder, filter_length=1024, n_overlap=4, win_length=1024, mode="zeros"):
+    def __init__(
+        self, vocoder, filter_length=1024, n_overlap=4, win_length=1024, mode="zeros"
+    ):
         super().__init__()
         self.filter_length = filter_length
         self.hop_length = int(filter_length / n_overlap)
         self.win_length = win_length
 
-        dtype, device = next(vocoder.parameters()).dtype, next(vocoder.parameters()).device
+        dtype, device = (
+            next(vocoder.parameters()).dtype,
+            next(vocoder.parameters()).device,
+        )
         self.device = device
         if mode == "zeros":
             mel_input = torch.zeros((1, 80, 88), dtype=dtype, device=device)
@@ -32,7 +37,9 @@ class Denoiser(torch.nn.Module):
                 return_complex=True,
             )
             spec = torch.view_as_real(spec)
-            return torch.sqrt(spec.pow(2).sum(-1)), torch.atan2(spec[..., -1], spec[..., 0])
+            return torch.sqrt(spec.pow(2).sum(-1)), torch.atan2(
+                spec[..., -1], spec[..., 0]
+            )
 
         self.stft = lambda x: stft_fn(
             audio=x,
@@ -49,7 +49,9 @@ mel_basis = {}
 hann_window = {}
 
 
-def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
+def mel_spectrogram(
+    y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
+):
     if torch.min(y) < -1.0:
         print("min value is ", torch.min(y))
     if torch.max(y) > 1.0:
@@ -58,11 +60,15 @@ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin,
     global mel_basis, hann_window  # pylint: disable=global-statement
     if fmax not in mel_basis:
         mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
-        mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
+        mel_basis[str(fmax) + "_" + str(y.device)] = (
+            torch.from_numpy(mel).float().to(y.device)
+        )
         hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
 
     y = torch.nn.functional.pad(
-        y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
+        y.unsqueeze(1),
+        (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
+        mode="reflect",
     )
     y = y.squeeze(1)
 
@@ -92,12 +98,16 @@ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin,
 def get_dataset_filelist(a):
     with open(a.input_training_file, encoding="utf-8") as fi:
         training_files = [
-            os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
+            os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav")
+            for x in fi.read().split("\n")
+            if len(x) > 0
         ]
 
     with open(a.input_validation_file, encoding="utf-8") as fi:
         validation_files = [
-            os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
+            os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav")
+            for x in fi.read().split("\n")
+            if len(x) > 0
         ]
     return training_files, validation_files
 
@@ -152,7 +162,9 @@ class MelDataset(torch.utils.data.Dataset):
                 audio = normalize(audio) * 0.95
             self.cached_wav = audio
             if sampling_rate != self.sampling_rate:
-                raise ValueError(f"{sampling_rate} SR doesn't match target {self.sampling_rate} SR")
+                raise ValueError(
+                    f"{sampling_rate} SR doesn't match target {self.sampling_rate} SR"
+                )
             self._cache_ref_count = self.n_cache_reuse
         else:
             audio = self.cached_wav
@@ -168,7 +180,9 @@ class MelDataset(torch.utils.data.Dataset):
                 audio_start = random.randint(0, max_audio_start)
                 audio = audio[:, audio_start : audio_start + self.segment_size]
             else:
-                audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant")
+                audio = torch.nn.functional.pad(
+                    audio, (0, self.segment_size - audio.size(1)), "constant"
+                )
 
             mel = mel_spectrogram(
                 audio,
@@ -182,7 +196,12 @@ class MelDataset(torch.utils.data.Dataset):
                 center=False,
             )
         else:
-            mel = np.load(os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + ".npy"))
+            mel = np.load(
+                os.path.join(
+                    self.base_mels_path,
+                    os.path.splitext(os.path.split(filename)[-1])[0] + ".npy",
+                )
+            )
             mel = torch.from_numpy(mel)
 
             if len(mel.shape) < 3:
@@ -194,10 +213,19 @@ class MelDataset(torch.utils.data.Dataset):
             if audio.size(1) >= self.segment_size:
                 mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
                 mel = mel[:, :, mel_start : mel_start + frames_per_seg]
-                audio = audio[:, mel_start * self.hop_size : (mel_start + frames_per_seg) * self.hop_size]
+                audio = audio[
+                    :,
+                    mel_start
+                    * self.hop_size : (mel_start + frames_per_seg)
+                    * self.hop_size,
+                ]
             else:
-                mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), "constant")
-                audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant")
+                mel = torch.nn.functional.pad(
+                    mel, (0, frames_per_seg - mel.size(2)), "constant"
+                )
+                audio = torch.nn.functional.pad(
+                    audio, (0, self.segment_size - audio.size(1)), "constant"
+                )
 
             mel_loss = mel_spectrogram(
                 audio,
@@ -151,7 +151,9 @@ class Generator(torch.nn.Module):
         self.h = h
         self.num_kernels = len(h.resblock_kernel_sizes)
         self.num_upsamples = len(h.upsample_rates)
-        self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3))
+        self.conv_pre = weight_norm(
+            Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)
+        )
         resblock = ResBlock1 if h.resblock == "1" else ResBlock2
 
         self.ups = nn.ModuleList()
@@ -171,7 +173,9 @@ class Generator(torch.nn.Module):
         self.resblocks = nn.ModuleList()
         for i in range(len(self.ups)):
             ch = h.upsample_initial_channel // (2 ** (i + 1))
-            for _, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
+            for _, (k, d) in enumerate(
+                zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
+            ):
                 self.resblocks.append(resblock(h, ch, k, d))
 
         self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
@@ -213,10 +217,42 @@ class DiscriminatorP(torch.nn.Module):
         norm_f = weight_norm if use_spectral_norm is False else spectral_norm
         self.convs = nn.ModuleList(
             [
-                norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
-                norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
-                norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
-                norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+                norm_f(
+                    Conv2d(
+                        1,
+                        32,
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(5, 1), 0),
+                    )
+                ),
+                norm_f(
+                    Conv2d(
+                        32,
+                        128,
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(5, 1), 0),
+                    )
+                ),
+                norm_f(
+                    Conv2d(
+                        128,
+                        512,
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(5, 1), 0),
+                    )
+                ),
+                norm_f(
+                    Conv2d(
+                        512,
+                        1024,
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(5, 1), 0),
+                    )
+                ),
                 norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
             ]
         )
@@ -313,7 +349,9 @@ class MultiScaleDiscriminator(torch.nn.Module):
                 DiscriminatorS(),
             ]
         )
-        self.meanpools = nn.ModuleList([AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)])
+        self.meanpools = nn.ModuleList(
+            [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)]
+        )
 
     def forward(self, y, y_hat):
         y_d_rs = []