From a531c92711cbbddc9beaed8f4c23f53d5dc2f3df Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Tue, 29 Oct 2024 10:14:09 +0800
Subject: [PATCH] fix style issues

---
 .github/scripts/ljspeech/TTS/run-matcha.sh    | 13 +++--
 .github/workflows/ljspeech.yml                |  6 ---
 egs/ljspeech/TTS/matcha/hifigan/denoiser.py   | 13 +++--
 egs/ljspeech/TTS/matcha/hifigan/meldataset.py | 50 ++++++++++++++----
 egs/ljspeech/TTS/matcha/hifigan/models.py     | 52 ++++++++++++++++---
 5 files changed, 100 insertions(+), 34 deletions(-)

diff --git a/.github/scripts/ljspeech/TTS/run-matcha.sh b/.github/scripts/ljspeech/TTS/run-matcha.sh
index b6eb81020..37e1bc320 100755
--- a/.github/scripts/ljspeech/TTS/run-matcha.sh
+++ b/.github/scripts/ljspeech/TTS/run-matcha.sh
@@ -101,13 +101,12 @@ function export_onnx() {
 
   ls -lh *.onnx
 
-
-  python3 ./matcha/onnx_pretrained.py \
-    --acoustic-model ./model-steps-6.onnx \
-    --vocoder ./hifigan_v1.onnx \
-    --tokens ./data/tokens.txt \
-    --input-text "how are you doing?" \
-    --output-wav /icefall/generated-matcha-tts-steps-6-v1.wav
+  python3 ./matcha/onnx_pretrained.py \
+    --acoustic-model ./model-steps-6.onnx \
+    --vocoder ./hifigan_v1.onnx \
+    --tokens ./data/tokens.txt \
+    --input-text "how are you doing?" \
+    --output-wav /icefall/generated-matcha-tts-steps-6-v1.wav
 
   ls -lh /icefall/*.wav
   soxi /icefall/generated-matcha-tts-steps-6-v1.wav
diff --git a/.github/workflows/ljspeech.yml b/.github/workflows/ljspeech.yml
index 34a3797fa..7dca96b37 100644
--- a/.github/workflows/ljspeech.yml
+++ b/.github/workflows/ljspeech.yml
@@ -84,12 +84,6 @@ jobs:
           name: generated-test-files-${{ matrix.python-version }}-${{ matrix.torch-version }}
           path: ./*.wav
 
-      - uses: actions/upload-artifact@v4
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
-        with:
-          name: generated-models-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}
-          path: ./*.wav
-
       - name: Release exported onnx models
         if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
         uses: svenstaro/upload-release-action@v2
diff --git a/egs/ljspeech/TTS/matcha/hifigan/denoiser.py b/egs/ljspeech/TTS/matcha/hifigan/denoiser.py
index 9fd33312a..b9aea61b8 100644
--- a/egs/ljspeech/TTS/matcha/hifigan/denoiser.py
+++ b/egs/ljspeech/TTS/matcha/hifigan/denoiser.py
@@ -7,13 +7,18 @@ import torch
 class Denoiser(torch.nn.Module):
     """Removes model bias from audio produced with waveglow"""
 
-    def __init__(self, vocoder, filter_length=1024, n_overlap=4, win_length=1024, mode="zeros"):
+    def __init__(
+        self, vocoder, filter_length=1024, n_overlap=4, win_length=1024, mode="zeros"
+    ):
         super().__init__()
         self.filter_length = filter_length
         self.hop_length = int(filter_length / n_overlap)
         self.win_length = win_length
 
-        dtype, device = next(vocoder.parameters()).dtype, next(vocoder.parameters()).device
+        dtype, device = (
+            next(vocoder.parameters()).dtype,
+            next(vocoder.parameters()).device,
+        )
         self.device = device
         if mode == "zeros":
             mel_input = torch.zeros((1, 80, 88), dtype=dtype, device=device)
@@ -32,7 +37,9 @@ class Denoiser(torch.nn.Module):
                 return_complex=True,
             )
             spec = torch.view_as_real(spec)
-            return torch.sqrt(spec.pow(2).sum(-1)), torch.atan2(spec[..., -1], spec[..., 0])
+            return torch.sqrt(spec.pow(2).sum(-1)), torch.atan2(
+                spec[..., -1], spec[..., 0]
+            )
 
         self.stft = lambda x: stft_fn(
             audio=x,
diff --git a/egs/ljspeech/TTS/matcha/hifigan/meldataset.py b/egs/ljspeech/TTS/matcha/hifigan/meldataset.py
index 8b43ea796..6eb15a326 100644
--- a/egs/ljspeech/TTS/matcha/hifigan/meldataset.py
+++ b/egs/ljspeech/TTS/matcha/hifigan/meldataset.py
@@ -49,7 +49,9 @@ mel_basis = {}
 hann_window = {}
 
 
-def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
+def mel_spectrogram(
+    y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
+):
     if torch.min(y) < -1.0:
         print("min value is ", torch.min(y))
     if torch.max(y) > 1.0:
@@ -58,11 +60,15 @@ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin,
     global mel_basis, hann_window  # pylint: disable=global-statement
     if fmax not in mel_basis:
         mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
-        mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
+        mel_basis[str(fmax) + "_" + str(y.device)] = (
+            torch.from_numpy(mel).float().to(y.device)
+        )
         hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
 
     y = torch.nn.functional.pad(
-        y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
+        y.unsqueeze(1),
+        (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
+        mode="reflect",
     )
     y = y.squeeze(1)
 
@@ -92,12 +98,16 @@ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin,
 def get_dataset_filelist(a):
     with open(a.input_training_file, encoding="utf-8") as fi:
         training_files = [
-            os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
+            os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav")
+            for x in fi.read().split("\n")
+            if len(x) > 0
         ]
 
     with open(a.input_validation_file, encoding="utf-8") as fi:
         validation_files = [
-            os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
+            os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav")
+            for x in fi.read().split("\n")
+            if len(x) > 0
         ]
 
     return training_files, validation_files
@@ -152,7 +162,9 @@ class MelDataset(torch.utils.data.Dataset):
                 audio = normalize(audio) * 0.95
             self.cached_wav = audio
             if sampling_rate != self.sampling_rate:
-                raise ValueError(f"{sampling_rate} SR doesn't match target {self.sampling_rate} SR")
+                raise ValueError(
+                    f"{sampling_rate} SR doesn't match target {self.sampling_rate} SR"
+                )
             self._cache_ref_count = self.n_cache_reuse
         else:
             audio = self.cached_wav
@@ -168,7 +180,9 @@ class MelDataset(torch.utils.data.Dataset):
                     audio_start = random.randint(0, max_audio_start)
                     audio = audio[:, audio_start : audio_start + self.segment_size]
                 else:
-                    audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant")
+                    audio = torch.nn.functional.pad(
+                        audio, (0, self.segment_size - audio.size(1)), "constant"
+                    )
 
             mel = mel_spectrogram(
                 audio,
@@ -182,7 +196,12 @@ class MelDataset(torch.utils.data.Dataset):
                 center=False,
            )
         else:
-            mel = np.load(os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + ".npy"))
+            mel = np.load(
+                os.path.join(
+                    self.base_mels_path,
+                    os.path.splitext(os.path.split(filename)[-1])[0] + ".npy",
+                )
+            )
             mel = torch.from_numpy(mel)
 
             if len(mel.shape) < 3:
@@ -194,10 +213,19 @@ class MelDataset(torch.utils.data.Dataset):
                 if audio.size(1) >= self.segment_size:
                     mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
                     mel = mel[:, :, mel_start : mel_start + frames_per_seg]
-                    audio = audio[:, mel_start * self.hop_size : (mel_start + frames_per_seg) * self.hop_size]
+                    audio = audio[
+                        :,
+                        mel_start
+                        * self.hop_size : (mel_start + frames_per_seg)
+                        * self.hop_size,
+                    ]
                 else:
-                    mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), "constant")
-                    audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant")
+                    mel = torch.nn.functional.pad(
+                        mel, (0, frames_per_seg - mel.size(2)), "constant"
+                    )
+                    audio = torch.nn.functional.pad(
+                        audio, (0, self.segment_size - audio.size(1)), "constant"
+                    )
 
         mel_loss = mel_spectrogram(
             audio,
diff --git a/egs/ljspeech/TTS/matcha/hifigan/models.py b/egs/ljspeech/TTS/matcha/hifigan/models.py
index d209d9a4e..e6da20610 100644
--- a/egs/ljspeech/TTS/matcha/hifigan/models.py
+++ b/egs/ljspeech/TTS/matcha/hifigan/models.py
@@ -151,7 +151,9 @@ class Generator(torch.nn.Module):
         self.h = h
         self.num_kernels = len(h.resblock_kernel_sizes)
         self.num_upsamples = len(h.upsample_rates)
-        self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3))
+        self.conv_pre = weight_norm(
+            Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)
+        )
         resblock = ResBlock1 if h.resblock == "1" else ResBlock2
 
         self.ups = nn.ModuleList()
@@ -171,7 +173,9 @@ class Generator(torch.nn.Module):
         self.resblocks = nn.ModuleList()
         for i in range(len(self.ups)):
             ch = h.upsample_initial_channel // (2 ** (i + 1))
-            for _, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
+            for _, (k, d) in enumerate(
+                zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
+            ):
                 self.resblocks.append(resblock(h, ch, k, d))
 
         self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
@@ -213,10 +217,42 @@ class DiscriminatorP(torch.nn.Module):
         norm_f = weight_norm if use_spectral_norm is False else spectral_norm
         self.convs = nn.ModuleList(
             [
-                norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
-                norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
-                norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
-                norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+                norm_f(
+                    Conv2d(
+                        1,
+                        32,
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(5, 1), 0),
+                    )
+                ),
+                norm_f(
+                    Conv2d(
+                        32,
+                        128,
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(5, 1), 0),
+                    )
+                ),
+                norm_f(
+                    Conv2d(
+                        128,
+                        512,
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(5, 1), 0),
+                    )
+                ),
+                norm_f(
+                    Conv2d(
+                        512,
+                        1024,
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(5, 1), 0),
+                    )
+                ),
                 norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
             ]
         )
@@ -313,7 +349,9 @@ class MultiScaleDiscriminator(torch.nn.Module):
                 DiscriminatorS(),
             ]
         )
-        self.meanpools = nn.ModuleList([AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)])
+        self.meanpools = nn.ModuleList(
+            [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)]
+        )
 
     def forward(self, y, y_hat):
         y_d_rs = []