mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-13 20:12:24 +00:00
minor updates
This commit is contained in:
parent
e06ce7c63a
commit
deff0140cd
@ -27,102 +27,17 @@ The generated fbank features are saved in data/fbank.
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
from audio import mel_spectrogram
|
from fbank import MatchaFbank, MatchaFbankConfig
|
||||||
from lhotse import CutSet, LilcomChunkyWriter, load_manifest
|
from lhotse import CutSet, LilcomChunkyWriter, load_manifest
|
||||||
from lhotse.audio import RecordingSet
|
from lhotse.audio import RecordingSet
|
||||||
from lhotse.features.base import FeatureExtractor, register_extractor
|
|
||||||
from lhotse.supervision import SupervisionSet
|
from lhotse.supervision import SupervisionSet
|
||||||
from lhotse.utils import Seconds, compute_num_frames
|
|
||||||
|
|
||||||
from icefall.utils import get_executor
|
from icefall.utils import get_executor
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MyFbankConfig:
|
|
||||||
n_fft: int
|
|
||||||
n_mels: int
|
|
||||||
sampling_rate: int
|
|
||||||
hop_length: int
|
|
||||||
win_length: int
|
|
||||||
f_min: float
|
|
||||||
f_max: float
|
|
||||||
|
|
||||||
|
|
||||||
@register_extractor
|
|
||||||
class MyFbank(FeatureExtractor):
|
|
||||||
|
|
||||||
name = "MyFbank"
|
|
||||||
config_type = MyFbankConfig
|
|
||||||
|
|
||||||
def __init__(self, config):
|
|
||||||
super().__init__(config=config)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def device(self) -> Union[str, torch.device]:
|
|
||||||
return self.config.device
|
|
||||||
|
|
||||||
def feature_dim(self, sampling_rate: int) -> int:
|
|
||||||
return self.config.n_mels
|
|
||||||
|
|
||||||
def extract(
|
|
||||||
self,
|
|
||||||
samples: np.ndarray,
|
|
||||||
sampling_rate: int,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
# Check for sampling rate compatibility.
|
|
||||||
expected_sr = self.config.sampling_rate
|
|
||||||
assert sampling_rate == expected_sr, (
|
|
||||||
f"Mismatched sampling rate: extractor expects {expected_sr}, "
|
|
||||||
f"got {sampling_rate}"
|
|
||||||
)
|
|
||||||
samples = torch.from_numpy(samples)
|
|
||||||
assert samples.ndim == 2, samples.shape
|
|
||||||
assert samples.shape[0] == 1, samples.shape
|
|
||||||
|
|
||||||
mel = (
|
|
||||||
mel_spectrogram(
|
|
||||||
samples,
|
|
||||||
self.config.n_fft,
|
|
||||||
self.config.n_mels,
|
|
||||||
self.config.sampling_rate,
|
|
||||||
self.config.hop_length,
|
|
||||||
self.config.win_length,
|
|
||||||
self.config.f_min,
|
|
||||||
self.config.f_max,
|
|
||||||
center=False,
|
|
||||||
)
|
|
||||||
.squeeze()
|
|
||||||
.t()
|
|
||||||
)
|
|
||||||
|
|
||||||
assert mel.ndim == 2, mel.shape
|
|
||||||
assert mel.shape[1] == self.config.n_mels, mel.shape
|
|
||||||
|
|
||||||
num_frames = compute_num_frames(
|
|
||||||
samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate
|
|
||||||
)
|
|
||||||
|
|
||||||
if mel.shape[0] > num_frames:
|
|
||||||
mel = mel[:num_frames]
|
|
||||||
elif mel.shape[0] < num_frames:
|
|
||||||
mel = mel.unsqueeze(0)
|
|
||||||
mel = torch.nn.functional.pad(
|
|
||||||
mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate"
|
|
||||||
).squeeze(0)
|
|
||||||
|
|
||||||
return mel.numpy()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def frame_shift(self) -> Seconds:
|
|
||||||
return self.config.hop_length / self.config.sampling_rate
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
@ -149,7 +64,7 @@ def compute_fbank_ljspeech(num_jobs: int):
|
|||||||
logging.info(f"num_jobs: {num_jobs}")
|
logging.info(f"num_jobs: {num_jobs}")
|
||||||
logging.info(f"src_dir: {src_dir}")
|
logging.info(f"src_dir: {src_dir}")
|
||||||
logging.info(f"output_dir: {output_dir}")
|
logging.info(f"output_dir: {output_dir}")
|
||||||
config = MyFbankConfig(
|
config = MatchaFbankConfig(
|
||||||
n_fft=1024,
|
n_fft=1024,
|
||||||
n_mels=80,
|
n_mels=80,
|
||||||
sampling_rate=22050,
|
sampling_rate=22050,
|
||||||
@ -170,7 +85,7 @@ def compute_fbank_ljspeech(num_jobs: int):
|
|||||||
src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet
|
src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet
|
||||||
)
|
)
|
||||||
|
|
||||||
extractor = MyFbank(config)
|
extractor = MatchaFbank(config)
|
||||||
|
|
||||||
with get_executor() as ex: # Initialize the executor only once.
|
with get_executor() as ex: # Initialize the executor only once.
|
||||||
cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
|
cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
|
||||||
|
1
egs/ljspeech/TTS/local/fbank.py
Symbolic link
1
egs/ljspeech/TTS/local/fbank.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../matcha/fbank.py
|
@ -7,7 +7,7 @@ from typing import Any, Dict
|
|||||||
|
|
||||||
import onnx
|
import onnx
|
||||||
import torch
|
import torch
|
||||||
from inference import load_vocoder
|
from infer import load_vocoder
|
||||||
|
|
||||||
|
|
||||||
def add_meta_data(filename: str, meta_data: Dict[str, Any]):
|
def add_meta_data(filename: str, meta_data: Dict[str, Any]):
|
||||||
|
89
egs/ljspeech/TTS/matcha/fbank.py
Normal file
89
egs/ljspeech/TTS/matcha/fbank.py
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from audio import mel_spectrogram
|
||||||
|
from lhotse.features.base import FeatureExtractor, register_extractor
|
||||||
|
from lhotse.utils import Seconds, compute_num_frames
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MatchaFbankConfig:
|
||||||
|
n_fft: int
|
||||||
|
n_mels: int
|
||||||
|
sampling_rate: int
|
||||||
|
hop_length: int
|
||||||
|
win_length: int
|
||||||
|
f_min: float
|
||||||
|
f_max: float
|
||||||
|
|
||||||
|
|
||||||
|
@register_extractor
|
||||||
|
class MatchaFbank(FeatureExtractor):
|
||||||
|
|
||||||
|
name = "MatchaFbank"
|
||||||
|
config_type = MatchaFbankConfig
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config=config)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self) -> Union[str, torch.device]:
|
||||||
|
return self.config.device
|
||||||
|
|
||||||
|
def feature_dim(self, sampling_rate: int) -> int:
|
||||||
|
return self.config.n_mels
|
||||||
|
|
||||||
|
def extract(
|
||||||
|
self,
|
||||||
|
samples: np.ndarray,
|
||||||
|
sampling_rate: int,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# Check for sampling rate compatibility.
|
||||||
|
expected_sr = self.config.sampling_rate
|
||||||
|
assert sampling_rate == expected_sr, (
|
||||||
|
f"Mismatched sampling rate: extractor expects {expected_sr}, "
|
||||||
|
f"got {sampling_rate}"
|
||||||
|
)
|
||||||
|
samples = torch.from_numpy(samples)
|
||||||
|
assert samples.ndim == 2, samples.shape
|
||||||
|
assert samples.shape[0] == 1, samples.shape
|
||||||
|
|
||||||
|
mel = (
|
||||||
|
mel_spectrogram(
|
||||||
|
samples,
|
||||||
|
self.config.n_fft,
|
||||||
|
self.config.n_mels,
|
||||||
|
self.config.sampling_rate,
|
||||||
|
self.config.hop_length,
|
||||||
|
self.config.win_length,
|
||||||
|
self.config.f_min,
|
||||||
|
self.config.f_max,
|
||||||
|
center=False,
|
||||||
|
)
|
||||||
|
.squeeze()
|
||||||
|
.t()
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mel.ndim == 2, mel.shape
|
||||||
|
assert mel.shape[1] == self.config.n_mels, mel.shape
|
||||||
|
|
||||||
|
num_frames = compute_num_frames(
|
||||||
|
samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate
|
||||||
|
)
|
||||||
|
|
||||||
|
if mel.shape[0] > num_frames:
|
||||||
|
mel = mel[:num_frames]
|
||||||
|
elif mel.shape[0] < num_frames:
|
||||||
|
mel = mel.unsqueeze(0)
|
||||||
|
mel = torch.nn.functional.pad(
|
||||||
|
mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate"
|
||||||
|
).squeeze(0)
|
||||||
|
|
||||||
|
return mel.numpy()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def frame_shift(self) -> Seconds:
|
||||||
|
return self.config.hop_length / self.config.sampling_rate
|
||||||
|
|
@ -10,7 +10,6 @@ from shutil import copyfile
|
|||||||
from typing import Any, Dict, Optional, Union
|
from typing import Any, Dict, Optional, Union
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
@ -24,7 +24,7 @@ from pathlib import Path
|
|||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from compute_fbank_ljspeech import MyFbank, MyFbankConfig
|
from fbank import MatchaFbank, MatchaFbankConfig
|
||||||
from lhotse import CutSet, load_manifest_lazy
|
from lhotse import CutSet, load_manifest_lazy
|
||||||
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||||
CutConcatenate,
|
CutConcatenate,
|
||||||
@ -32,7 +32,6 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
|||||||
DynamicBucketingSampler,
|
DynamicBucketingSampler,
|
||||||
PrecomputedFeatures,
|
PrecomputedFeatures,
|
||||||
SimpleCutSampler,
|
SimpleCutSampler,
|
||||||
SpecAugment,
|
|
||||||
SpeechSynthesisDataset,
|
SpeechSynthesisDataset,
|
||||||
)
|
)
|
||||||
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
||||||
@ -177,7 +176,7 @@ class LJSpeechTtsDataModule:
|
|||||||
|
|
||||||
if self.args.on_the_fly_feats:
|
if self.args.on_the_fly_feats:
|
||||||
sampling_rate = 22050
|
sampling_rate = 22050
|
||||||
config = MyFbankConfig(
|
config = MatchaFbankConfig(
|
||||||
n_fft=1024,
|
n_fft=1024,
|
||||||
n_mels=80,
|
n_mels=80,
|
||||||
sampling_rate=sampling_rate,
|
sampling_rate=sampling_rate,
|
||||||
@ -189,7 +188,7 @@ class LJSpeechTtsDataModule:
|
|||||||
train = SpeechSynthesisDataset(
|
train = SpeechSynthesisDataset(
|
||||||
return_text=False,
|
return_text=False,
|
||||||
return_tokens=True,
|
return_tokens=True,
|
||||||
feature_input_strategy=OnTheFlyFeatures(MyFbank(config)),
|
feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -238,7 +237,7 @@ class LJSpeechTtsDataModule:
|
|||||||
logging.info("About to create dev dataset")
|
logging.info("About to create dev dataset")
|
||||||
if self.args.on_the_fly_feats:
|
if self.args.on_the_fly_feats:
|
||||||
sampling_rate = 22050
|
sampling_rate = 22050
|
||||||
config = MyFbankConfig(
|
config = MatchaFbankConfig(
|
||||||
n_fft=1024,
|
n_fft=1024,
|
||||||
n_mels=80,
|
n_mels=80,
|
||||||
sampling_rate=sampling_rate,
|
sampling_rate=sampling_rate,
|
||||||
@ -250,7 +249,7 @@ class LJSpeechTtsDataModule:
|
|||||||
validate = SpeechSynthesisDataset(
|
validate = SpeechSynthesisDataset(
|
||||||
return_text=False,
|
return_text=False,
|
||||||
return_tokens=True,
|
return_tokens=True,
|
||||||
feature_input_strategy=OnTheFlyFeatures(MyFbank(config)),
|
feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -282,7 +281,7 @@ class LJSpeechTtsDataModule:
|
|||||||
logging.info("About to create test dataset")
|
logging.info("About to create test dataset")
|
||||||
if self.args.on_the_fly_feats:
|
if self.args.on_the_fly_feats:
|
||||||
sampling_rate = 22050
|
sampling_rate = 22050
|
||||||
config = MyFbankConfig(
|
config = MatchaFbankConfig(
|
||||||
n_fft=1024,
|
n_fft=1024,
|
||||||
n_mels=80,
|
n_mels=80,
|
||||||
sampling_rate=sampling_rate,
|
sampling_rate=sampling_rate,
|
||||||
@ -294,7 +293,7 @@ class LJSpeechTtsDataModule:
|
|||||||
test = SpeechSynthesisDataset(
|
test = SpeechSynthesisDataset(
|
||||||
return_text=False,
|
return_text=False,
|
||||||
return_tokens=True,
|
return_tokens=True,
|
||||||
feature_input_strategy=OnTheFlyFeatures(MyFbank(config)),
|
feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user