minor updates

zr_jin 2024-11-04 14:43:01 +08:00
parent e06ce7c63a
commit deff0140cd
6 changed files with 101 additions and 98 deletions

View File

@@ -27,102 +27,17 @@ The generated fbank features are saved in data/fbank.
 import argparse
 import logging
 import os
-from dataclasses import dataclass
 from pathlib import Path
-from typing import Union
 
-import numpy as np
 import torch
-from audio import mel_spectrogram
+from fbank import MatchaFbank, MatchaFbankConfig
 from lhotse import CutSet, LilcomChunkyWriter, load_manifest
 from lhotse.audio import RecordingSet
-from lhotse.features.base import FeatureExtractor, register_extractor
 from lhotse.supervision import SupervisionSet
-from lhotse.utils import Seconds, compute_num_frames
 
 from icefall.utils import get_executor
 
-
-@dataclass
-class MyFbankConfig:
-    n_fft: int
-    n_mels: int
-    sampling_rate: int
-    hop_length: int
-    win_length: int
-    f_min: float
-    f_max: float
-
-
-@register_extractor
-class MyFbank(FeatureExtractor):
-    name = "MyFbank"
-    config_type = MyFbankConfig
-
-    def __init__(self, config):
-        super().__init__(config=config)
-
-    @property
-    def device(self) -> Union[str, torch.device]:
-        return self.config.device
-
-    def feature_dim(self, sampling_rate: int) -> int:
-        return self.config.n_mels
-
-    def extract(
-        self,
-        samples: np.ndarray,
-        sampling_rate: int,
-    ) -> torch.Tensor:
-        # Check for sampling rate compatibility.
-        expected_sr = self.config.sampling_rate
-        assert sampling_rate == expected_sr, (
-            f"Mismatched sampling rate: extractor expects {expected_sr}, "
-            f"got {sampling_rate}"
-        )
-        samples = torch.from_numpy(samples)
-        assert samples.ndim == 2, samples.shape
-        assert samples.shape[0] == 1, samples.shape
-
-        mel = (
-            mel_spectrogram(
-                samples,
-                self.config.n_fft,
-                self.config.n_mels,
-                self.config.sampling_rate,
-                self.config.hop_length,
-                self.config.win_length,
-                self.config.f_min,
-                self.config.f_max,
-                center=False,
-            )
-            .squeeze()
-            .t()
-        )
-
-        assert mel.ndim == 2, mel.shape
-        assert mel.shape[1] == self.config.n_mels, mel.shape
-
-        num_frames = compute_num_frames(
-            samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate
-        )
-
-        if mel.shape[0] > num_frames:
-            mel = mel[:num_frames]
-        elif mel.shape[0] < num_frames:
-            mel = mel.unsqueeze(0)
-            mel = torch.nn.functional.pad(
-                mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate"
-            ).squeeze(0)
-
-        return mel.numpy()
-
-    @property
-    def frame_shift(self) -> Seconds:
-        return self.config.hop_length / self.config.sampling_rate
-
-
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -149,7 +64,7 @@ def compute_fbank_ljspeech(num_jobs: int):
     logging.info(f"num_jobs: {num_jobs}")
     logging.info(f"src_dir: {src_dir}")
     logging.info(f"output_dir: {output_dir}")
-    config = MyFbankConfig(
+    config = MatchaFbankConfig(
         n_fft=1024,
         n_mels=80,
         sampling_rate=22050,
@@ -170,7 +85,7 @@ def compute_fbank_ljspeech(num_jobs: int):
         src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet
     )
 
-    extractor = MyFbank(config)
+    extractor = MatchaFbank(config)
 
     with get_executor() as ex:  # Initialize the executor only once.
        cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
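The hunk above ends before the actual feature-computation call, so for orientation here is a minimal sketch (not part of this commit) of how an extractor like `MatchaFbank` is typically driven in these `compute_fbank_*.py` scripts. The manifest paths, `num_jobs`, and the `hop_length`/`win_length`/`f_min`/`f_max` values are illustrative assumptions; only the imports and the `MatchaFbank`/`MatchaFbankConfig` names come from the diff itself.

```python
# Hedged sketch: feed the renamed extractor into lhotse's feature pipeline.
# Paths and parameter values marked "assumed" are not taken from this commit.
from pathlib import Path

from fbank import MatchaFbank, MatchaFbankConfig
from lhotse import CutSet, LilcomChunkyWriter, load_manifest
from lhotse.audio import RecordingSet
from lhotse.supervision import SupervisionSet

from icefall.utils import get_executor

src_dir = Path("data/manifests")   # assumed layout
output_dir = Path("data/fbank")    # assumed layout

config = MatchaFbankConfig(
    n_fft=1024,
    n_mels=80,
    sampling_rate=22050,
    hop_length=256,   # assumed
    win_length=1024,  # assumed
    f_min=0,          # assumed
    f_max=8000,       # assumed
)
extractor = MatchaFbank(config)

recordings = load_manifest(src_dir / "ljspeech_recordings_all.jsonl.gz", RecordingSet)
supervisions = load_manifest(src_dir / "ljspeech_supervisions_all.jsonl.gz", SupervisionSet)
cut_set = CutSet.from_manifests(recordings=recordings, supervisions=supervisions)

with get_executor() as ex:  # Initialize the executor only once.
    cut_set = cut_set.compute_and_store_features(
        extractor=extractor,
        storage_path=f"{output_dir}/ljspeech_feats_all",  # illustrative path
        num_jobs=4,
        executor=ex,
        storage_type=LilcomChunkyWriter,
    )
cut_set.to_file(output_dir / "ljspeech_cuts_all.jsonl.gz")
```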

View File

@@ -0,0 +1 @@
+../matcha/fbank.py

View File

@@ -7,7 +7,7 @@ from typing import Any, Dict
 import onnx
 import torch
 
-from inference import load_vocoder
+from infer import load_vocoder
 
 
 def add_meta_data(filename: str, meta_data: Dict[str, Any]):

View File

@@ -0,0 +1,89 @@
+from dataclasses import dataclass
+from typing import Union
+
+import numpy as np
+import torch
+from audio import mel_spectrogram
+from lhotse.features.base import FeatureExtractor, register_extractor
+from lhotse.utils import Seconds, compute_num_frames
+
+
+@dataclass
+class MatchaFbankConfig:
+    n_fft: int
+    n_mels: int
+    sampling_rate: int
+    hop_length: int
+    win_length: int
+    f_min: float
+    f_max: float
+
+
+@register_extractor
+class MatchaFbank(FeatureExtractor):
+    name = "MatchaFbank"
+    config_type = MatchaFbankConfig
+
+    def __init__(self, config):
+        super().__init__(config=config)
+
+    @property
+    def device(self) -> Union[str, torch.device]:
+        return self.config.device
+
+    def feature_dim(self, sampling_rate: int) -> int:
+        return self.config.n_mels
+
+    def extract(
+        self,
+        samples: np.ndarray,
+        sampling_rate: int,
+    ) -> torch.Tensor:
+        # Check for sampling rate compatibility.
+        expected_sr = self.config.sampling_rate
+        assert sampling_rate == expected_sr, (
+            f"Mismatched sampling rate: extractor expects {expected_sr}, "
+            f"got {sampling_rate}"
+        )
+        samples = torch.from_numpy(samples)
+        assert samples.ndim == 2, samples.shape
+        assert samples.shape[0] == 1, samples.shape
+
+        mel = (
+            mel_spectrogram(
+                samples,
+                self.config.n_fft,
+                self.config.n_mels,
+                self.config.sampling_rate,
+                self.config.hop_length,
+                self.config.win_length,
+                self.config.f_min,
+                self.config.f_max,
+                center=False,
+            )
+            .squeeze()
+            .t()
+        )
+
+        assert mel.ndim == 2, mel.shape
+        assert mel.shape[1] == self.config.n_mels, mel.shape
+
+        num_frames = compute_num_frames(
+            samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate
+        )
+
+        if mel.shape[0] > num_frames:
+            mel = mel[:num_frames]
+        elif mel.shape[0] < num_frames:
+            mel = mel.unsqueeze(0)
+            mel = torch.nn.functional.pad(
+                mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate"
+            ).squeeze(0)
+
+        return mel.numpy()
+
+    @property
+    def frame_shift(self) -> Seconds:
+        return self.config.hop_length / self.config.sampling_rate
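For readers new to lhotse feature extractors, here is a small, hedged usage sketch (not part of this commit) that exercises `MatchaFbank.extract` directly on a dummy waveform. Only `n_fft`, `n_mels`, and the sampling rate appear in this diff; the remaining config values are assumptions.

```python
# Hedged usage sketch for the extractor defined in fbank.py above.
import numpy as np

from fbank import MatchaFbank, MatchaFbankConfig

config = MatchaFbankConfig(
    n_fft=1024,
    n_mels=80,
    sampling_rate=22050,
    hop_length=256,   # assumed
    win_length=1024,  # assumed
    f_min=0,          # assumed
    f_max=8000,       # assumed
)
extractor = MatchaFbank(config)

# extract() asserts a (1, num_samples) array at the configured sampling rate.
samples = (np.random.randn(1, 22050) * 0.01).astype(np.float32)
mel = extractor.extract(samples, sampling_rate=22050)
print(mel.shape)  # (num_frames, 80); frame_shift is hop_length / sampling_rate
```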

View File

@@ -10,7 +10,6 @@ from shutil import copyfile
 from typing import Any, Dict, Optional, Union
 
 import k2
-import numpy as np
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn

View File

@@ -24,7 +24,7 @@ from pathlib import Path
 from typing import Any, Dict, Optional
 
 import torch
-from compute_fbank_ljspeech import MyFbank, MyFbankConfig
+from fbank import MatchaFbank, MatchaFbankConfig
 from lhotse import CutSet, load_manifest_lazy
 from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
     CutConcatenate,
@@ -32,7 +32,6 @@ from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
     DynamicBucketingSampler,
     PrecomputedFeatures,
     SimpleCutSampler,
-    SpecAugment,
     SpeechSynthesisDataset,
 )
 from lhotse.dataset.input_strategies import (  # noqa F401 For AudioSamples
@@ -177,7 +176,7 @@ class LJSpeechTtsDataModule:
         if self.args.on_the_fly_feats:
             sampling_rate = 22050
-            config = MyFbankConfig(
+            config = MatchaFbankConfig(
                 n_fft=1024,
                 n_mels=80,
                 sampling_rate=sampling_rate,
@@ -189,7 +188,7 @@ class LJSpeechTtsDataModule:
             train = SpeechSynthesisDataset(
                 return_text=False,
                 return_tokens=True,
-                feature_input_strategy=OnTheFlyFeatures(MyFbank(config)),
+                feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
                 return_cuts=self.args.return_cuts,
             )
@@ -238,7 +237,7 @@ class LJSpeechTtsDataModule:
         logging.info("About to create dev dataset")
         if self.args.on_the_fly_feats:
             sampling_rate = 22050
-            config = MyFbankConfig(
+            config = MatchaFbankConfig(
                 n_fft=1024,
                 n_mels=80,
                 sampling_rate=sampling_rate,
@@ -250,7 +249,7 @@ class LJSpeechTtsDataModule:
             validate = SpeechSynthesisDataset(
                 return_text=False,
                 return_tokens=True,
-                feature_input_strategy=OnTheFlyFeatures(MyFbank(config)),
+                feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -282,7 +281,7 @@ class LJSpeechTtsDataModule:
         logging.info("About to create test dataset")
         if self.args.on_the_fly_feats:
             sampling_rate = 22050
-            config = MyFbankConfig(
+            config = MatchaFbankConfig(
                 n_fft=1024,
                 n_mels=80,
                 sampling_rate=sampling_rate,
@@ -294,7 +293,7 @@ class LJSpeechTtsDataModule:
             test = SpeechSynthesisDataset(
                 return_text=False,
                 return_tokens=True,
-                feature_input_strategy=OnTheFlyFeatures(MyFbank(config)),
+                feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
                 return_cuts=self.args.return_cuts,
             )
         else:
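For completeness, a hedged sketch (not part of this commit) of how the on-the-fly branch shown above is typically wired into a DataLoader. The cuts manifest path, sampler settings, and the last four config values are illustrative assumptions; the dataset arguments mirror the diff.

```python
# Hedged sketch: on-the-fly MatchaFbank features feeding a lhotse DataLoader.
import torch
from lhotse import load_manifest_lazy
from lhotse.dataset import DynamicBucketingSampler, SpeechSynthesisDataset
from lhotse.dataset.input_strategies import OnTheFlyFeatures

from fbank import MatchaFbank, MatchaFbankConfig

config = MatchaFbankConfig(
    n_fft=1024, n_mels=80, sampling_rate=22050,
    hop_length=256, win_length=1024, f_min=0, f_max=8000,  # last four assumed
)
cuts = load_manifest_lazy("data/fbank/ljspeech_cuts_train.jsonl.gz")  # assumed path

dataset = SpeechSynthesisDataset(
    return_text=False,
    return_tokens=True,
    feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
    return_cuts=True,
)
sampler = DynamicBucketingSampler(cuts, max_duration=200, shuffle=True)
# batch_size=None because the lhotse sampler already yields whole batches of cuts.
dl = torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=None, num_workers=2)
```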