minor updates

This commit is contained in:
zr_jin 2024-11-04 14:43:01 +08:00
parent e06ce7c63a
commit deff0140cd
6 changed files with 101 additions and 98 deletions

View File

@ -27,102 +27,17 @@ The generated fbank features are saved in data/fbank.
import argparse
import logging
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Union
import numpy as np
import torch
from audio import mel_spectrogram
from fbank import MatchaFbank, MatchaFbankConfig
from lhotse import CutSet, LilcomChunkyWriter, load_manifest
from lhotse.audio import RecordingSet
from lhotse.features.base import FeatureExtractor, register_extractor
from lhotse.supervision import SupervisionSet
from lhotse.utils import Seconds, compute_num_frames
from icefall.utils import get_executor
@dataclass
class MyFbankConfig:
n_fft: int
n_mels: int
sampling_rate: int
hop_length: int
win_length: int
f_min: float
f_max: float
@register_extractor
class MyFbank(FeatureExtractor):
name = "MyFbank"
config_type = MyFbankConfig
def __init__(self, config):
super().__init__(config=config)
@property
def device(self) -> Union[str, torch.device]:
return self.config.device
def feature_dim(self, sampling_rate: int) -> int:
return self.config.n_mels
def extract(
self,
samples: np.ndarray,
sampling_rate: int,
) -> torch.Tensor:
# Check for sampling rate compatibility.
expected_sr = self.config.sampling_rate
assert sampling_rate == expected_sr, (
f"Mismatched sampling rate: extractor expects {expected_sr}, "
f"got {sampling_rate}"
)
samples = torch.from_numpy(samples)
assert samples.ndim == 2, samples.shape
assert samples.shape[0] == 1, samples.shape
mel = (
mel_spectrogram(
samples,
self.config.n_fft,
self.config.n_mels,
self.config.sampling_rate,
self.config.hop_length,
self.config.win_length,
self.config.f_min,
self.config.f_max,
center=False,
)
.squeeze()
.t()
)
assert mel.ndim == 2, mel.shape
assert mel.shape[1] == self.config.n_mels, mel.shape
num_frames = compute_num_frames(
samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate
)
if mel.shape[0] > num_frames:
mel = mel[:num_frames]
elif mel.shape[0] < num_frames:
mel = mel.unsqueeze(0)
mel = torch.nn.functional.pad(
mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate"
).squeeze(0)
return mel.numpy()
@property
def frame_shift(self) -> Seconds:
return self.config.hop_length / self.config.sampling_rate
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@ -149,7 +64,7 @@ def compute_fbank_ljspeech(num_jobs: int):
logging.info(f"num_jobs: {num_jobs}")
logging.info(f"src_dir: {src_dir}")
logging.info(f"output_dir: {output_dir}")
config = MyFbankConfig(
config = MatchaFbankConfig(
n_fft=1024,
n_mels=80,
sampling_rate=22050,
@ -170,7 +85,7 @@ def compute_fbank_ljspeech(num_jobs: int):
src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet
)
extractor = MyFbank(config)
extractor = MatchaFbank(config)
with get_executor() as ex: # Initialize the executor only once.
cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"

View File

@ -0,0 +1 @@
../matcha/fbank.py

View File

@ -7,7 +7,7 @@ from typing import Any, Dict
import onnx
import torch
from inference import load_vocoder
from infer import load_vocoder
def add_meta_data(filename: str, meta_data: Dict[str, Any]):

View File

@ -0,0 +1,89 @@
from dataclasses import dataclass
from typing import Union
import numpy as np
import torch
from audio import mel_spectrogram
from lhotse.features.base import FeatureExtractor, register_extractor
from lhotse.utils import Seconds, compute_num_frames
@dataclass
class MatchaFbankConfig:
n_fft: int
n_mels: int
sampling_rate: int
hop_length: int
win_length: int
f_min: float
f_max: float
@register_extractor
class MatchaFbank(FeatureExtractor):
name = "MatchaFbank"
config_type = MatchaFbankConfig
def __init__(self, config):
super().__init__(config=config)
@property
def device(self) -> Union[str, torch.device]:
return self.config.device
def feature_dim(self, sampling_rate: int) -> int:
return self.config.n_mels
def extract(
self,
samples: np.ndarray,
sampling_rate: int,
) -> torch.Tensor:
# Check for sampling rate compatibility.
expected_sr = self.config.sampling_rate
assert sampling_rate == expected_sr, (
f"Mismatched sampling rate: extractor expects {expected_sr}, "
f"got {sampling_rate}"
)
samples = torch.from_numpy(samples)
assert samples.ndim == 2, samples.shape
assert samples.shape[0] == 1, samples.shape
mel = (
mel_spectrogram(
samples,
self.config.n_fft,
self.config.n_mels,
self.config.sampling_rate,
self.config.hop_length,
self.config.win_length,
self.config.f_min,
self.config.f_max,
center=False,
)
.squeeze()
.t()
)
assert mel.ndim == 2, mel.shape
assert mel.shape[1] == self.config.n_mels, mel.shape
num_frames = compute_num_frames(
samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate
)
if mel.shape[0] > num_frames:
mel = mel[:num_frames]
elif mel.shape[0] < num_frames:
mel = mel.unsqueeze(0)
mel = torch.nn.functional.pad(
mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate"
).squeeze(0)
return mel.numpy()
@property
def frame_shift(self) -> Seconds:
return self.config.hop_length / self.config.sampling_rate

View File

@ -10,7 +10,6 @@ from shutil import copyfile
from typing import Any, Dict, Optional, Union
import k2
import numpy as np
import torch
import torch.multiprocessing as mp
import torch.nn as nn

View File

@ -24,7 +24,7 @@ from pathlib import Path
from typing import Any, Dict, Optional
import torch
from compute_fbank_ljspeech import MyFbank, MyFbankConfig
from fbank import MatchaFbank, MatchaFbankConfig
from lhotse import CutSet, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
@ -32,7 +32,6 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
DynamicBucketingSampler,
PrecomputedFeatures,
SimpleCutSampler,
SpecAugment,
SpeechSynthesisDataset,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
@ -177,7 +176,7 @@ class LJSpeechTtsDataModule:
if self.args.on_the_fly_feats:
sampling_rate = 22050
config = MyFbankConfig(
config = MatchaFbankConfig(
n_fft=1024,
n_mels=80,
sampling_rate=sampling_rate,
@ -189,7 +188,7 @@ class LJSpeechTtsDataModule:
train = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=OnTheFlyFeatures(MyFbank(config)),
feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
return_cuts=self.args.return_cuts,
)
@ -238,7 +237,7 @@ class LJSpeechTtsDataModule:
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
sampling_rate = 22050
config = MyFbankConfig(
config = MatchaFbankConfig(
n_fft=1024,
n_mels=80,
sampling_rate=sampling_rate,
@ -250,7 +249,7 @@ class LJSpeechTtsDataModule:
validate = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=OnTheFlyFeatures(MyFbank(config)),
feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
return_cuts=self.args.return_cuts,
)
else:
@ -282,7 +281,7 @@ class LJSpeechTtsDataModule:
logging.info("About to create test dataset")
if self.args.on_the_fly_feats:
sampling_rate = 22050
config = MyFbankConfig(
config = MatchaFbankConfig(
n_fft=1024,
n_mels=80,
sampling_rate=sampling_rate,
@ -294,7 +293,7 @@ class LJSpeechTtsDataModule:
test = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=OnTheFlyFeatures(MyFbank(config)),
feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
return_cuts=self.args.return_cuts,
)
else: