fix style issues

Fangjun Kuang 2024-10-29 10:14:09 +08:00
parent 3a986335d7
commit a531c92711
5 changed files with 100 additions and 34 deletions

View File

@@ -101,13 +101,12 @@ function export_onnx() {
  ls -lh *.onnx

  python3 ./matcha/onnx_pretrained.py \
    --acoustic-model ./model-steps-6.onnx \
    --vocoder ./hifigan_v1.onnx \
    --tokens ./data/tokens.txt \
    --input-text "how are you doing?" \
    --output-wav /icefall/generated-matcha-tts-steps-6-v1.wav

  ls -lh /icefall/*.wav
  soxi /icefall/generated-matcha-tts-steps-6-v1.wav
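For readers tracing what this CI step exercises: a minimal onnxruntime sketch of the same two-stage pipeline (acoustic model, then vocoder). The tensor names, token IDs, and single-output assumption below are illustrative guesses, not the actual I/O of onnx_pretrained.py; inspect the exported models with sess.get_inputs() before copying any of it.

import numpy as np
import onnxruntime as ort
import soundfile as sf

# Hypothetical sketch; input/output names are assumed, not taken from the script.
acoustic = ort.InferenceSession("./model-steps-6.onnx")
vocoder = ort.InferenceSession("./hifigan_v1.onnx")

tokens = np.array([[10, 42, 7, 3]], dtype=np.int64)       # made-up token IDs
token_lens = np.array([tokens.shape[1]], dtype=np.int64)

(mel,) = acoustic.run(None, {"x": tokens, "x_lengths": token_lens})  # tokens -> mel
(audio,) = vocoder.run(None, {"mel": mel})                           # mel -> waveform

sf.write("generated.wav", audio.squeeze(), 22050)  # 22.05 kHz assumed (LJSpeech)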

View File

@@ -84,12 +84,6 @@ jobs:
           name: generated-test-files-${{ matrix.python-version }}-${{ matrix.torch-version }}
           path: ./*.wav

-      - uses: actions/upload-artifact@v4
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
-        with:
-          name: generated-models-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}
-          path: ./*.wav
-
       - name: Release exported onnx models
         if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
         uses: svenstaro/upload-release-action@v2

View File

@@ -7,13 +7,18 @@ import torch
 class Denoiser(torch.nn.Module):
     """Removes model bias from audio produced with waveglow"""

-    def __init__(self, vocoder, filter_length=1024, n_overlap=4, win_length=1024, mode="zeros"):
+    def __init__(
+        self, vocoder, filter_length=1024, n_overlap=4, win_length=1024, mode="zeros"
+    ):
         super().__init__()
         self.filter_length = filter_length
         self.hop_length = int(filter_length / n_overlap)
         self.win_length = win_length
-        dtype, device = next(vocoder.parameters()).dtype, next(vocoder.parameters()).device
+        dtype, device = (
+            next(vocoder.parameters()).dtype,
+            next(vocoder.parameters()).device,
+        )
         self.device = device
         if mode == "zeros":
             mel_input = torch.zeros((1, 80, 88), dtype=dtype, device=device)
@@ -32,7 +37,9 @@ class Denoiser(torch.nn.Module):
                 return_complex=True,
             )
             spec = torch.view_as_real(spec)
-            return torch.sqrt(spec.pow(2).sum(-1)), torch.atan2(spec[..., -1], spec[..., 0])
+            return torch.sqrt(spec.pow(2).sum(-1)), torch.atan2(
+                spec[..., -1], spec[..., 0]
+            )

         self.stft = lambda x: stft_fn(
             audio=x,
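Context for this hunk: the denoiser precomputes the vocoder's bias spectrum from an all-zeros mel input and later subtracts a scaled copy of it from generated audio via the STFT helper being reformatted here. A minimal usage sketch, assuming the forward(audio, strength) signature used by WaveGlow-style denoisers and a vocoder that maps a (batch, 80, frames) mel to audio; the strength value is illustrative, not read from this repo:

import torch

# Usage sketch (assumed API); `vocoder` and `mel` are placeholders.
denoiser = Denoiser(vocoder, filter_length=1024, n_overlap=4, win_length=1024, mode="zeros")
with torch.no_grad():
    audio = vocoder(mel).squeeze(1)            # (batch, samples)
    audio = denoiser(audio, strength=0.00025)  # subtract scaled bias spectrum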

View File

@@ -49,7 +49,9 @@ mel_basis = {}
 hann_window = {}

-def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
+def mel_spectrogram(
+    y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
+):
     if torch.min(y) < -1.0:
         print("min value is ", torch.min(y))
     if torch.max(y) > 1.0:
@@ -58,11 +60,15 @@ def mel_spectrogram(
     global mel_basis, hann_window  # pylint: disable=global-statement
     if fmax not in mel_basis:
         mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
-        mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
+        mel_basis[str(fmax) + "_" + str(y.device)] = (
+            torch.from_numpy(mel).float().to(y.device)
+        )
         hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)

     y = torch.nn.functional.pad(
-        y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
+        y.unsqueeze(1),
+        (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
+        mode="reflect",
     )
     y = y.squeeze(1)
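A typical call to the function reformatted above, with parameter values assumed from common LJSpeech-style HiFi-GAN configs rather than taken from this commit:

import torch

# Example call; parameter values are assumed. Relies on the mel_spectrogram
# defined above (and its librosa_mel_fn import).
y = torch.randn(1, 22050).clamp(-1.0, 1.0)  # (batch, samples) in [-1, 1]
mel = mel_spectrogram(
    y, n_fft=1024, num_mels=80, sampling_rate=22050,
    hop_size=256, win_size=1024, fmin=0, fmax=8000, center=False,
)
print(mel.shape)  # roughly (1, 80, samples // hop_size)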
@@ -92,12 +98,16 @@ def mel_spectrogram(
 def get_dataset_filelist(a):
     with open(a.input_training_file, encoding="utf-8") as fi:
         training_files = [
-            os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
+            os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav")
+            for x in fi.read().split("\n")
+            if len(x) > 0
         ]

     with open(a.input_validation_file, encoding="utf-8") as fi:
         validation_files = [
-            os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
+            os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav")
+            for x in fi.read().split("\n")
+            if len(x) > 0
         ]
     return training_files, validation_files
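get_dataset_filelist expects pipe-separated metadata lines (LJSpeech-style "id|transcript"), keeps only the id column, and maps it to a .wav path. A tiny illustration with made-up entries:

import os

# Made-up metadata lines in the "id|transcript" format the code assumes.
lines = ["LJ001-0001|Printing, in the only sense...", "LJ001-0002|in being comparatively modern."]
wavs = [os.path.join("wavs", x.split("|")[0] + ".wav") for x in lines if len(x) > 0]
print(wavs)  # ['wavs/LJ001-0001.wav', 'wavs/LJ001-0002.wav']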
@@ -152,7 +162,9 @@ class MelDataset(torch.utils.data.Dataset):
                 audio = normalize(audio) * 0.95
             self.cached_wav = audio
             if sampling_rate != self.sampling_rate:
-                raise ValueError(f"{sampling_rate} SR doesn't match target {self.sampling_rate} SR")
+                raise ValueError(
+                    f"{sampling_rate} SR doesn't match target {self.sampling_rate} SR"
+                )
             self._cache_ref_count = self.n_cache_reuse
         else:
             audio = self.cached_wav
@@ -168,7 +180,9 @@ class MelDataset(torch.utils.data.Dataset):
                 audio_start = random.randint(0, max_audio_start)
                 audio = audio[:, audio_start : audio_start + self.segment_size]
             else:
-                audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant")
+                audio = torch.nn.functional.pad(
+                    audio, (0, self.segment_size - audio.size(1)), "constant"
+                )

             mel = mel_spectrogram(
                 audio,
@@ -182,7 +196,12 @@ class MelDataset(torch.utils.data.Dataset):
                 center=False,
             )
         else:
-            mel = np.load(os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + ".npy"))
+            mel = np.load(
+                os.path.join(
+                    self.base_mels_path,
+                    os.path.splitext(os.path.split(filename)[-1])[0] + ".npy",
+                )
+            )
             mel = torch.from_numpy(mel)

             if len(mel.shape) < 3:
@@ -194,10 +213,19 @@ class MelDataset(torch.utils.data.Dataset):
             if audio.size(1) >= self.segment_size:
                 mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
                 mel = mel[:, :, mel_start : mel_start + frames_per_seg]
-                audio = audio[:, mel_start * self.hop_size : (mel_start + frames_per_seg) * self.hop_size]
+                audio = audio[
+                    :,
+                    mel_start
+                    * self.hop_size : (mel_start + frames_per_seg)
+                    * self.hop_size,
+                ]
             else:
-                mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), "constant")
-                audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant")
+                mel = torch.nn.functional.pad(
+                    mel, (0, frames_per_seg - mel.size(2)), "constant"
+                )
+                audio = torch.nn.functional.pad(
+                    audio, (0, self.segment_size - audio.size(1)), "constant"
+                )

         mel_loss = mel_spectrogram(
             audio,
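The slicing reformatted in this hunk keeps mel frames and audio samples aligned: one mel frame covers hop_size samples, so the mel slice [mel_start, mel_start + frames_per_seg) must pair with audio samples [mel_start * hop_size, (mel_start + frames_per_seg) * hop_size). A quick numeric check, with segment_size and hop_size values assumed from typical LJSpeech configs:

# Numeric sanity check of the mel/audio alignment (values assumed).
segment_size, hop_size = 8192, 256
frames_per_seg = segment_size // hop_size            # 32 mel frames per segment
mel_start = 5
audio_start = mel_start * hop_size                   # 1280
audio_end = (mel_start + frames_per_seg) * hop_size  # 9472
assert audio_end - audio_start == segment_size       # exactly one segment of audio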

View File

@@ -151,7 +151,9 @@ class Generator(torch.nn.Module):
         self.h = h
         self.num_kernels = len(h.resblock_kernel_sizes)
         self.num_upsamples = len(h.upsample_rates)
-        self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3))
+        self.conv_pre = weight_norm(
+            Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)
+        )
         resblock = ResBlock1 if h.resblock == "1" else ResBlock2

         self.ups = nn.ModuleList()
@@ -171,7 +173,9 @@ class Generator(torch.nn.Module):
         self.resblocks = nn.ModuleList()
         for i in range(len(self.ups)):
             ch = h.upsample_initial_channel // (2 ** (i + 1))
-            for _, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
+            for _, (k, d) in enumerate(
+                zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
+            ):
                 self.resblocks.append(resblock(h, ch, k, d))

         self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
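As an illustration of the channel schedule in this loop: every upsampling stage halves the channel count, and each stage gets one resblock per kernel/dilation pair. With config values assumed from the HiFi-GAN v1 defaults (not read from this repo):

# Channel schedule illustration; config values assumed (HiFi-GAN v1 defaults).
upsample_initial_channel = 512
upsample_rates = [8, 8, 2, 2]        # 4 upsampling stages
resblock_kernel_sizes = [3, 7, 11]   # 3 resblocks per stage

for i in range(len(upsample_rates)):
    ch = upsample_initial_channel // (2 ** (i + 1))
    print(f"stage {i}: {ch} channels, {len(resblock_kernel_sizes)} resblocks")
# stage 0: 256 channels ... stage 3: 32 channels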
@@ -213,10 +217,42 @@ class DiscriminatorP(torch.nn.Module):
         norm_f = weight_norm if use_spectral_norm is False else spectral_norm
         self.convs = nn.ModuleList(
             [
-                norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
-                norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
-                norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
-                norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+                norm_f(
+                    Conv2d(
+                        1,
+                        32,
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(5, 1), 0),
+                    )
+                ),
+                norm_f(
+                    Conv2d(
+                        32,
+                        128,
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(5, 1), 0),
+                    )
+                ),
+                norm_f(
+                    Conv2d(
+                        128,
+                        512,
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(5, 1), 0),
+                    )
+                ),
+                norm_f(
+                    Conv2d(
+                        512,
+                        1024,
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(5, 1), 0),
+                    )
+                ),
                 norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
             ]
         )
@@ -313,7 +349,9 @@ class MultiScaleDiscriminator(torch.nn.Module):
                 DiscriminatorS(),
             ]
         )
-        self.meanpools = nn.ModuleList([AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)])
+        self.meanpools = nn.ModuleList(
+            [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)]
+        )

     def forward(self, y, y_hat):
         y_d_rs = []
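For context on the DiscriminatorP hunk above: HiFi-GAN's period discriminator folds the 1-D waveform into a 2-D tensor so that each column holds every period-th sample, and the (kernel_size, 1) convolutions in the diff then slide along the time axis only. A minimal sketch of that folding step (the helper name and test shapes are mine, not from models.py):

import torch
import torch.nn.functional as F

def fold_by_period(x: torch.Tensor, period: int) -> torch.Tensor:
    """Reshape (batch, 1, t) audio to (batch, 1, t // period, period),
    reflect-padding the tail so t is a multiple of period."""
    b, c, t = x.shape
    if t % period != 0:
        pad = period - (t % period)
        x = F.pad(x, (0, pad), "reflect")
        t += pad
    return x.view(b, c, t // period, period)

x = torch.randn(1, 1, 22050)       # one second of fake audio at 22.05 kHz
print(fold_by_period(x, 3).shape)  # torch.Size([1, 1, 7350, 3])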