From deff0140cdcc6939eec6ddaea16b52022742cd79 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Mon, 4 Nov 2024 14:43:01 +0800 Subject: [PATCH] minor updates --- .../TTS/local/compute_fbank_ljspeech.py | 91 +------------------ egs/ljspeech/TTS/local/fbank.py | 1 + .../TTS/matcha/export_onnx_hifigan.py | 2 +- egs/ljspeech/TTS/matcha/fbank.py | 89 ++++++++++++++++++ egs/ljspeech/TTS/matcha/train.py | 1 - egs/ljspeech/TTS/matcha/tts_datamodule.py | 15 ++- 6 files changed, 101 insertions(+), 98 deletions(-) create mode 120000 egs/ljspeech/TTS/local/fbank.py create mode 100644 egs/ljspeech/TTS/matcha/fbank.py diff --git a/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py b/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py index 69f572ae1..296f9a4f4 100755 --- a/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py +++ b/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py @@ -27,102 +27,17 @@ The generated fbank features are saved in data/fbank. import argparse import logging import os -from dataclasses import dataclass from pathlib import Path -from typing import Union -import numpy as np import torch -from audio import mel_spectrogram +from fbank import MatchaFbank, MatchaFbankConfig from lhotse import CutSet, LilcomChunkyWriter, load_manifest from lhotse.audio import RecordingSet -from lhotse.features.base import FeatureExtractor, register_extractor from lhotse.supervision import SupervisionSet -from lhotse.utils import Seconds, compute_num_frames from icefall.utils import get_executor -@dataclass -class MyFbankConfig: - n_fft: int - n_mels: int - sampling_rate: int - hop_length: int - win_length: int - f_min: float - f_max: float - - -@register_extractor -class MyFbank(FeatureExtractor): - - name = "MyFbank" - config_type = MyFbankConfig - - def __init__(self, config): - super().__init__(config=config) - - @property - def device(self) -> Union[str, torch.device]: - return self.config.device - - def feature_dim(self, sampling_rate: int) -> int: - return self.config.n_mels - - def extract( - self, - samples: np.ndarray, - sampling_rate: int, - ) -> torch.Tensor: - # Check for sampling rate compatibility. - expected_sr = self.config.sampling_rate - assert sampling_rate == expected_sr, ( - f"Mismatched sampling rate: extractor expects {expected_sr}, " - f"got {sampling_rate}" - ) - samples = torch.from_numpy(samples) - assert samples.ndim == 2, samples.shape - assert samples.shape[0] == 1, samples.shape - - mel = ( - mel_spectrogram( - samples, - self.config.n_fft, - self.config.n_mels, - self.config.sampling_rate, - self.config.hop_length, - self.config.win_length, - self.config.f_min, - self.config.f_max, - center=False, - ) - .squeeze() - .t() - ) - - assert mel.ndim == 2, mel.shape - assert mel.shape[1] == self.config.n_mels, mel.shape - - num_frames = compute_num_frames( - samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate - ) - - if mel.shape[0] > num_frames: - mel = mel[:num_frames] - elif mel.shape[0] < num_frames: - mel = mel.unsqueeze(0) - mel = torch.nn.functional.pad( - mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate" - ).squeeze(0) - - return mel.numpy() - - @property - def frame_shift(self) -> Seconds: - return self.config.hop_length / self.config.sampling_rate - - def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -149,7 +64,7 @@ def compute_fbank_ljspeech(num_jobs: int): logging.info(f"num_jobs: {num_jobs}") logging.info(f"src_dir: {src_dir}") logging.info(f"output_dir: {output_dir}") - config = MyFbankConfig( + config = MatchaFbankConfig( n_fft=1024, n_mels=80, sampling_rate=22050, @@ -170,7 +85,7 @@ def compute_fbank_ljspeech(num_jobs: int): src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet ) - extractor = MyFbank(config) + extractor = MatchaFbank(config) with get_executor() as ex: # Initialize the executor only once. cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" diff --git a/egs/ljspeech/TTS/local/fbank.py b/egs/ljspeech/TTS/local/fbank.py new file mode 120000 index 000000000..5bcf1fde5 --- /dev/null +++ b/egs/ljspeech/TTS/local/fbank.py @@ -0,0 +1 @@ +../matcha/fbank.py \ No newline at end of file diff --git a/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py b/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py index 63d1fac20..5c96b3bc7 100755 --- a/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py +++ b/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py @@ -7,7 +7,7 @@ from typing import Any, Dict import onnx import torch -from inference import load_vocoder +from infer import load_vocoder def add_meta_data(filename: str, meta_data: Dict[str, Any]): diff --git a/egs/ljspeech/TTS/matcha/fbank.py b/egs/ljspeech/TTS/matcha/fbank.py new file mode 100644 index 000000000..e6c07f0ea --- /dev/null +++ b/egs/ljspeech/TTS/matcha/fbank.py @@ -0,0 +1,89 @@ +from dataclasses import dataclass +from typing import Union + +import numpy as np +import torch +from audio import mel_spectrogram +from lhotse.features.base import FeatureExtractor, register_extractor +from lhotse.utils import Seconds, compute_num_frames + + +@dataclass +class MatchaFbankConfig: + n_fft: int + n_mels: int + sampling_rate: int + hop_length: int + win_length: int + f_min: float + f_max: float + + +@register_extractor +class MatchaFbank(FeatureExtractor): + + name = "MatchaFbank" + config_type = MatchaFbankConfig + + def __init__(self, config): + super().__init__(config=config) + + @property + def device(self) -> Union[str, torch.device]: + return self.config.device + + def feature_dim(self, sampling_rate: int) -> int: + return self.config.n_mels + + def extract( + self, + samples: np.ndarray, + sampling_rate: int, + ) -> torch.Tensor: + # Check for sampling rate compatibility. + expected_sr = self.config.sampling_rate + assert sampling_rate == expected_sr, ( + f"Mismatched sampling rate: extractor expects {expected_sr}, " + f"got {sampling_rate}" + ) + samples = torch.from_numpy(samples) + assert samples.ndim == 2, samples.shape + assert samples.shape[0] == 1, samples.shape + + mel = ( + mel_spectrogram( + samples, + self.config.n_fft, + self.config.n_mels, + self.config.sampling_rate, + self.config.hop_length, + self.config.win_length, + self.config.f_min, + self.config.f_max, + center=False, + ) + .squeeze() + .t() + ) + + assert mel.ndim == 2, mel.shape + assert mel.shape[1] == self.config.n_mels, mel.shape + + num_frames = compute_num_frames( + samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate + ) + + if mel.shape[0] > num_frames: + mel = mel[:num_frames] + elif mel.shape[0] < num_frames: + mel = mel.unsqueeze(0) + mel = torch.nn.functional.pad( + mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate" + ).squeeze(0) + + return mel.numpy() + + @property + def frame_shift(self) -> Seconds: + return self.config.hop_length / self.config.sampling_rate + diff --git a/egs/ljspeech/TTS/matcha/train.py b/egs/ljspeech/TTS/matcha/train.py index 8ad307fda..78f4f3373 100755 --- a/egs/ljspeech/TTS/matcha/train.py +++ b/egs/ljspeech/TTS/matcha/train.py @@ -10,7 +10,6 @@ from shutil import copyfile from typing import Any, Dict, Optional, Union import k2 -import numpy as np import torch import torch.multiprocessing as mp import torch.nn as nn diff --git a/egs/ljspeech/TTS/matcha/tts_datamodule.py b/egs/ljspeech/TTS/matcha/tts_datamodule.py index 8e37fc030..1e637b766 100644 --- a/egs/ljspeech/TTS/matcha/tts_datamodule.py +++ b/egs/ljspeech/TTS/matcha/tts_datamodule.py @@ -24,7 +24,7 @@ from pathlib import Path from typing import Any, Dict, Optional import torch -from compute_fbank_ljspeech import MyFbank, MyFbankConfig +from fbank import MatchaFbank, MatchaFbankConfig from lhotse import CutSet, load_manifest_lazy from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures CutConcatenate, @@ -32,7 +32,6 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures DynamicBucketingSampler, PrecomputedFeatures, SimpleCutSampler, - SpecAugment, SpeechSynthesisDataset, ) from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples @@ -177,7 +176,7 @@ class LJSpeechTtsDataModule: if self.args.on_the_fly_feats: sampling_rate = 22050 - config = MyFbankConfig( + config = MatchaFbankConfig( n_fft=1024, n_mels=80, sampling_rate=sampling_rate, @@ -189,7 +188,7 @@ class LJSpeechTtsDataModule: train = SpeechSynthesisDataset( return_text=False, return_tokens=True, - feature_input_strategy=OnTheFlyFeatures(MyFbank(config)), + feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)), return_cuts=self.args.return_cuts, ) @@ -238,7 +237,7 @@ class LJSpeechTtsDataModule: logging.info("About to create dev dataset") if self.args.on_the_fly_feats: sampling_rate = 22050 - config = MyFbankConfig( + config = MatchaFbankConfig( n_fft=1024, n_mels=80, sampling_rate=sampling_rate, @@ -250,7 +249,7 @@ class LJSpeechTtsDataModule: validate = SpeechSynthesisDataset( return_text=False, return_tokens=True, - feature_input_strategy=OnTheFlyFeatures(MyFbank(config)), + feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)), return_cuts=self.args.return_cuts, ) else: @@ -282,7 +281,7 @@ class LJSpeechTtsDataModule: logging.info("About to create test dataset") if self.args.on_the_fly_feats: sampling_rate = 22050 - config = MyFbankConfig( + config = MatchaFbankConfig( n_fft=1024, n_mels=80, sampling_rate=sampling_rate, @@ -294,7 +293,7 @@ class LJSpeechTtsDataModule: test = SpeechSynthesisDataset( return_text=False, return_tokens=True, - feature_input_strategy=OnTheFlyFeatures(MyFbank(config)), + feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)), return_cuts=self.args.return_cuts, ) else: