mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-13 12:02:21 +00:00
minor updates
This commit is contained in:
parent
e06ce7c63a
commit
deff0140cd
@ -27,102 +27,17 @@ The generated fbank features are saved in data/fbank.
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from audio import mel_spectrogram
|
||||
from fbank import MatchaFbank, MatchaFbankConfig
|
||||
from lhotse import CutSet, LilcomChunkyWriter, load_manifest
|
||||
from lhotse.audio import RecordingSet
|
||||
from lhotse.features.base import FeatureExtractor, register_extractor
|
||||
from lhotse.supervision import SupervisionSet
|
||||
from lhotse.utils import Seconds, compute_num_frames
|
||||
|
||||
from icefall.utils import get_executor
|
||||
|
||||
|
||||
@dataclass
|
||||
class MyFbankConfig:
|
||||
n_fft: int
|
||||
n_mels: int
|
||||
sampling_rate: int
|
||||
hop_length: int
|
||||
win_length: int
|
||||
f_min: float
|
||||
f_max: float
|
||||
|
||||
|
||||
@register_extractor
|
||||
class MyFbank(FeatureExtractor):
|
||||
|
||||
name = "MyFbank"
|
||||
config_type = MyFbankConfig
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config=config)
|
||||
|
||||
@property
|
||||
def device(self) -> Union[str, torch.device]:
|
||||
return self.config.device
|
||||
|
||||
def feature_dim(self, sampling_rate: int) -> int:
|
||||
return self.config.n_mels
|
||||
|
||||
def extract(
|
||||
self,
|
||||
samples: np.ndarray,
|
||||
sampling_rate: int,
|
||||
) -> torch.Tensor:
|
||||
# Check for sampling rate compatibility.
|
||||
expected_sr = self.config.sampling_rate
|
||||
assert sampling_rate == expected_sr, (
|
||||
f"Mismatched sampling rate: extractor expects {expected_sr}, "
|
||||
f"got {sampling_rate}"
|
||||
)
|
||||
samples = torch.from_numpy(samples)
|
||||
assert samples.ndim == 2, samples.shape
|
||||
assert samples.shape[0] == 1, samples.shape
|
||||
|
||||
mel = (
|
||||
mel_spectrogram(
|
||||
samples,
|
||||
self.config.n_fft,
|
||||
self.config.n_mels,
|
||||
self.config.sampling_rate,
|
||||
self.config.hop_length,
|
||||
self.config.win_length,
|
||||
self.config.f_min,
|
||||
self.config.f_max,
|
||||
center=False,
|
||||
)
|
||||
.squeeze()
|
||||
.t()
|
||||
)
|
||||
|
||||
assert mel.ndim == 2, mel.shape
|
||||
assert mel.shape[1] == self.config.n_mels, mel.shape
|
||||
|
||||
num_frames = compute_num_frames(
|
||||
samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate
|
||||
)
|
||||
|
||||
if mel.shape[0] > num_frames:
|
||||
mel = mel[:num_frames]
|
||||
elif mel.shape[0] < num_frames:
|
||||
mel = mel.unsqueeze(0)
|
||||
mel = torch.nn.functional.pad(
|
||||
mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate"
|
||||
).squeeze(0)
|
||||
|
||||
return mel.numpy()
|
||||
|
||||
@property
|
||||
def frame_shift(self) -> Seconds:
|
||||
return self.config.hop_length / self.config.sampling_rate
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
@ -149,7 +64,7 @@ def compute_fbank_ljspeech(num_jobs: int):
|
||||
logging.info(f"num_jobs: {num_jobs}")
|
||||
logging.info(f"src_dir: {src_dir}")
|
||||
logging.info(f"output_dir: {output_dir}")
|
||||
config = MyFbankConfig(
|
||||
config = MatchaFbankConfig(
|
||||
n_fft=1024,
|
||||
n_mels=80,
|
||||
sampling_rate=22050,
|
||||
@ -170,7 +85,7 @@ def compute_fbank_ljspeech(num_jobs: int):
|
||||
src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet
|
||||
)
|
||||
|
||||
extractor = MyFbank(config)
|
||||
extractor = MatchaFbank(config)
|
||||
|
||||
with get_executor() as ex: # Initialize the executor only once.
|
||||
cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
|
||||
|
1
egs/ljspeech/TTS/local/fbank.py
Symbolic link
1
egs/ljspeech/TTS/local/fbank.py
Symbolic link
@ -0,0 +1 @@
|
||||
../matcha/fbank.py
|
@ -7,7 +7,7 @@ from typing import Any, Dict
|
||||
|
||||
import onnx
|
||||
import torch
|
||||
from inference import load_vocoder
|
||||
from infer import load_vocoder
|
||||
|
||||
|
||||
def add_meta_data(filename: str, meta_data: Dict[str, Any]):
|
||||
|
89
egs/ljspeech/TTS/matcha/fbank.py
Normal file
89
egs/ljspeech/TTS/matcha/fbank.py
Normal file
@ -0,0 +1,89 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from audio import mel_spectrogram
|
||||
from lhotse.features.base import FeatureExtractor, register_extractor
|
||||
from lhotse.utils import Seconds, compute_num_frames
|
||||
|
||||
|
||||
@dataclass
|
||||
class MatchaFbankConfig:
|
||||
n_fft: int
|
||||
n_mels: int
|
||||
sampling_rate: int
|
||||
hop_length: int
|
||||
win_length: int
|
||||
f_min: float
|
||||
f_max: float
|
||||
|
||||
|
||||
@register_extractor
|
||||
class MatchaFbank(FeatureExtractor):
|
||||
|
||||
name = "MatchaFbank"
|
||||
config_type = MatchaFbankConfig
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config=config)
|
||||
|
||||
@property
|
||||
def device(self) -> Union[str, torch.device]:
|
||||
return self.config.device
|
||||
|
||||
def feature_dim(self, sampling_rate: int) -> int:
|
||||
return self.config.n_mels
|
||||
|
||||
def extract(
|
||||
self,
|
||||
samples: np.ndarray,
|
||||
sampling_rate: int,
|
||||
) -> torch.Tensor:
|
||||
# Check for sampling rate compatibility.
|
||||
expected_sr = self.config.sampling_rate
|
||||
assert sampling_rate == expected_sr, (
|
||||
f"Mismatched sampling rate: extractor expects {expected_sr}, "
|
||||
f"got {sampling_rate}"
|
||||
)
|
||||
samples = torch.from_numpy(samples)
|
||||
assert samples.ndim == 2, samples.shape
|
||||
assert samples.shape[0] == 1, samples.shape
|
||||
|
||||
mel = (
|
||||
mel_spectrogram(
|
||||
samples,
|
||||
self.config.n_fft,
|
||||
self.config.n_mels,
|
||||
self.config.sampling_rate,
|
||||
self.config.hop_length,
|
||||
self.config.win_length,
|
||||
self.config.f_min,
|
||||
self.config.f_max,
|
||||
center=False,
|
||||
)
|
||||
.squeeze()
|
||||
.t()
|
||||
)
|
||||
|
||||
assert mel.ndim == 2, mel.shape
|
||||
assert mel.shape[1] == self.config.n_mels, mel.shape
|
||||
|
||||
num_frames = compute_num_frames(
|
||||
samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate
|
||||
)
|
||||
|
||||
if mel.shape[0] > num_frames:
|
||||
mel = mel[:num_frames]
|
||||
elif mel.shape[0] < num_frames:
|
||||
mel = mel.unsqueeze(0)
|
||||
mel = torch.nn.functional.pad(
|
||||
mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate"
|
||||
).squeeze(0)
|
||||
|
||||
return mel.numpy()
|
||||
|
||||
@property
|
||||
def frame_shift(self) -> Seconds:
|
||||
return self.config.hop_length / self.config.sampling_rate
|
||||
|
@ -10,7 +10,6 @@ from shutil import copyfile
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import k2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
@ -24,7 +24,7 @@ from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from compute_fbank_ljspeech import MyFbank, MyFbankConfig
|
||||
from fbank import MatchaFbank, MatchaFbankConfig
|
||||
from lhotse import CutSet, load_manifest_lazy
|
||||
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||
CutConcatenate,
|
||||
@ -32,7 +32,6 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||
DynamicBucketingSampler,
|
||||
PrecomputedFeatures,
|
||||
SimpleCutSampler,
|
||||
SpecAugment,
|
||||
SpeechSynthesisDataset,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
||||
@ -177,7 +176,7 @@ class LJSpeechTtsDataModule:
|
||||
|
||||
if self.args.on_the_fly_feats:
|
||||
sampling_rate = 22050
|
||||
config = MyFbankConfig(
|
||||
config = MatchaFbankConfig(
|
||||
n_fft=1024,
|
||||
n_mels=80,
|
||||
sampling_rate=sampling_rate,
|
||||
@ -189,7 +188,7 @@ class LJSpeechTtsDataModule:
|
||||
train = SpeechSynthesisDataset(
|
||||
return_text=False,
|
||||
return_tokens=True,
|
||||
feature_input_strategy=OnTheFlyFeatures(MyFbank(config)),
|
||||
feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
@ -238,7 +237,7 @@ class LJSpeechTtsDataModule:
|
||||
logging.info("About to create dev dataset")
|
||||
if self.args.on_the_fly_feats:
|
||||
sampling_rate = 22050
|
||||
config = MyFbankConfig(
|
||||
config = MatchaFbankConfig(
|
||||
n_fft=1024,
|
||||
n_mels=80,
|
||||
sampling_rate=sampling_rate,
|
||||
@ -250,7 +249,7 @@ class LJSpeechTtsDataModule:
|
||||
validate = SpeechSynthesisDataset(
|
||||
return_text=False,
|
||||
return_tokens=True,
|
||||
feature_input_strategy=OnTheFlyFeatures(MyFbank(config)),
|
||||
feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
else:
|
||||
@ -282,7 +281,7 @@ class LJSpeechTtsDataModule:
|
||||
logging.info("About to create test dataset")
|
||||
if self.args.on_the_fly_feats:
|
||||
sampling_rate = 22050
|
||||
config = MyFbankConfig(
|
||||
config = MatchaFbankConfig(
|
||||
n_fft=1024,
|
||||
n_mels=80,
|
||||
sampling_rate=sampling_rate,
|
||||
@ -294,7 +293,7 @@ class LJSpeechTtsDataModule:
|
||||
test = SpeechSynthesisDataset(
|
||||
return_text=False,
|
||||
return_tokens=True,
|
||||
feature_input_strategy=OnTheFlyFeatures(MyFbank(config)),
|
||||
feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
else:
|
||||
|
Loading…
x
Reference in New Issue
Block a user