mirror of
https://github.com/csukuangfj/kaldifeat.git
synced 2025-08-10 02:22:16 +00:00
wrap FbankOptions to Python.
This commit is contained in:
parent
53504a705c
commit
9a5567e21b
@ -15,15 +15,14 @@ template <class F>
|
||||
torch::Tensor OfflineFeatureTpl<F>::ComputeFeatures(const torch::Tensor &wave,
|
||||
float vtln_warp) {
|
||||
KALDIFEAT_ASSERT(wave.dim() == 1);
|
||||
int32_t rows_out = NumFrames(wave.sizes()[0], computer_.GetFrameOptions());
|
||||
int32_t cols_out = computer_.Dim();
|
||||
|
||||
const FrameExtractionOptions &frame_opts = computer_.GetFrameOptions();
|
||||
|
||||
torch::Tensor strided_input = GetStrided(wave, frame_opts);
|
||||
|
||||
if (frame_opts.dither != 0)
|
||||
if (frame_opts.dither != 0.0f) {
|
||||
strided_input = Dither(strided_input, frame_opts.dither);
|
||||
}
|
||||
|
||||
if (frame_opts.remove_dc_offset) {
|
||||
torch::Tensor row_means = strided_input.mean(1).unsqueeze(1);
|
||||
@ -37,12 +36,14 @@ torch::Tensor OfflineFeatureTpl<F>::ComputeFeatures(const torch::Tensor &wave,
|
||||
constexpr float kEps = 1.1920928955078125e-07f;
|
||||
|
||||
if (use_raw_log_energy) {
|
||||
// it is true iff use_energy==true and raw_energy==true
|
||||
log_energy_pre_window =
|
||||
torch::clamp_min(strided_input.pow(2).sum(1), kEps).log();
|
||||
}
|
||||
|
||||
if (frame_opts.preemph_coeff != 0.0f)
|
||||
if (frame_opts.preemph_coeff != 0.0f) {
|
||||
Preemphasize(frame_opts.preemph_coeff, &strided_input);
|
||||
}
|
||||
|
||||
feature_window_function_.Apply(&strided_input);
|
||||
|
||||
|
@ -13,6 +13,11 @@
|
||||
|
||||
namespace kaldifeat {
|
||||
|
||||
// Streams the text produced by FbankOptions::ToString() into `os` and
// returns `os` so that further insertions can be chained.
std::ostream &operator<<(std::ostream &os, const FbankOptions &opts) {
  return os << opts.ToString();
}
|
||||
|
||||
FbankComputer::FbankComputer(const FbankOptions &opts) : opts_(opts) {
|
||||
if (opts.energy_floor > 0.0f) log_energy_floor_ = logf(opts.energy_floor);
|
||||
|
||||
@ -78,11 +83,6 @@ torch::Tensor FbankComputer::Compute(torch::Tensor signal_raw_log_energy,
|
||||
// Use power instead of magnitude if requested.
|
||||
if (opts_.use_power) spectrum.pow_(2);
|
||||
|
||||
#if 0
|
||||
int32_t mel_offset = ((opts_.use_energy && !opts_.htk_compat) ? 1 : 0);
|
||||
SubVector<float> mel_energies(*feature, mel_offset, opts_.mel_opts.num_bins);
|
||||
#endif
|
||||
|
||||
// TODO(fangjun): remove the last column of spectrum
|
||||
torch::Tensor mel_energies = mel_banks.Compute(spectrum);
|
||||
if (opts_.use_log_fbank) {
|
||||
@ -90,17 +90,26 @@ torch::Tensor FbankComputer::Compute(torch::Tensor signal_raw_log_energy,
|
||||
mel_energies = torch::clamp_min(mel_energies, kEps).log();
|
||||
}
|
||||
|
||||
// if use_energy is true, then we get an extra bin. That is,
|
||||
// if num_mel_bins is 23, the feature will contain 24 bins.
|
||||
//
|
||||
// if htk_compat is false, then the 0th bin is the log energy
|
||||
// if htk_compat is true, then the last bin is the log energy
|
||||
|
||||
// Copy energy as first value (or the last, if htk_compat == true).
|
||||
if (opts_.use_energy) {
|
||||
#if 0
|
||||
if (opts_.energy_floor > 0.0 && signal_raw_log_energy < log_energy_floor_) {
|
||||
signal_raw_log_energy = log_energy_floor_;
|
||||
if (opts_.energy_floor > 0.0f) {
|
||||
signal_raw_log_energy =
|
||||
torch::clamp_min(signal_raw_log_energy, log_energy_floor_);
|
||||
}
|
||||
#endif
|
||||
int32_t energy_index = opts_.htk_compat ? opts_.mel_opts.num_bins : 0;
|
||||
energy_index = 0; // TODO(fangjun): fix it
|
||||
|
||||
mel_energies.index({"...", energy_index}) = signal_raw_log_energy;
|
||||
signal_raw_log_energy.unsqueeze_(1);
|
||||
|
||||
if (opts_.htk_compat) {
|
||||
mel_energies = torch::cat({mel_energies, signal_raw_log_energy}, 1);
|
||||
} else {
|
||||
mel_energies = torch::cat({signal_raw_log_energy, mel_energies}, 1);
|
||||
}
|
||||
}
|
||||
|
||||
return mel_energies;
|
||||
|
@ -20,12 +20,16 @@ struct FbankOptions {
|
||||
MelBanksOptions mel_opts;
|
||||
// append an extra dimension with energy to the filter banks
|
||||
bool use_energy = false;
|
||||
float energy_floor = 0.0f;
|
||||
float energy_floor = 0.0f; // active iff use_energy==true
|
||||
|
||||
// If true, compute log_energy before preemphasis and windowing
|
||||
// If false, compute log_energy after preemphasis and windowing
|
||||
bool raw_energy = true; // active iff use_energy==true
|
||||
|
||||
// If true, compute energy before preemphasis and windowing
|
||||
bool raw_energy = true;
|
||||
// If true, put energy last (if using energy)
|
||||
bool htk_compat = false;
|
||||
// If false, put energy first
|
||||
bool htk_compat = false; // active iff use_energy==true
|
||||
|
||||
// if true (default), produce log-filterbank, else linear
|
||||
bool use_log_fbank = true;
|
||||
|
||||
@ -34,8 +38,28 @@ struct FbankOptions {
|
||||
bool use_power = true;
|
||||
|
||||
FbankOptions() { mel_opts.num_bins = 23; }
|
||||
|
||||
// Builds a multi-line, human readable dump of the options: the frame
// extraction options first, then the mel bank options, followed by one
// "name: value" line for each remaining field of this struct.
std::string ToString() const {
  std::ostringstream ss;
  ss << "frame_opts: \n"
     << frame_opts << "\n"
     << "\n"
     << "mel_opts: \n"
     << mel_opts << "\n"
     << "use_energy: " << use_energy << "\n"
     << "energy_floor: " << energy_floor << "\n"
     << "raw_energy: " << raw_energy << "\n"
     << "htk_compat: " << htk_compat << "\n"
     << "use_log_fbank: " << use_log_fbank << "\n"
     << "use_power: " << use_power << "\n";
  return ss.str();
}
|
||||
};
|
||||
|
||||
std::ostream &operator<<(std::ostream &os, const FbankOptions &opts);
|
||||
|
||||
class FbankComputer {
|
||||
public:
|
||||
using Options = FbankOptions;
|
||||
@ -51,12 +75,15 @@ class FbankComputer {
|
||||
return opts_.mel_opts.num_bins + (opts_.use_energy ? 1 : 0);
|
||||
}
|
||||
|
||||
// if true, compute log_energy_pre_window but after dithering and dc removal
|
||||
bool NeedRawLogEnergy() const { return opts_.use_energy && opts_.raw_energy; }
|
||||
|
||||
// Read-only access to the frame extraction options stored in opts_.
const FrameExtractionOptions &GetFrameOptions() const { return opts_.frame_opts; }
|
||||
|
||||
// signal_raw_log_energy is log_energy_pre_window, which is not empty
|
||||
// iff NeedRawLogEnergy() returns true.
|
||||
torch::Tensor Compute(torch::Tensor signal_raw_log_energy, float vtln_warp,
|
||||
const torch::Tensor &signal_frame);
|
||||
|
||||
|
@ -16,6 +16,11 @@
|
||||
|
||||
namespace kaldifeat {
|
||||
|
||||
// Streams the text produced by FrameExtractionOptions::ToString() into `os`
// and returns `os` so that further insertions can be chained.
std::ostream &operator<<(std::ostream &os, const FrameExtractionOptions &opts) {
  return os << opts.ToString();
}
|
||||
|
||||
FeatureWindowFunction::FeatureWindowFunction(
|
||||
const FrameExtractionOptions &opts) {
|
||||
int32_t frame_length = opts.WindowSize();
|
||||
|
@ -39,9 +39,9 @@ struct FrameExtractionOptions {
|
||||
bool round_to_power_of_two = true;
|
||||
float blackman_coeff = 0.42f;
|
||||
bool snip_edges = true;
|
||||
bool allow_downsample = false;
|
||||
bool allow_upsample = false;
|
||||
int32_t max_feature_vectors = -1;
|
||||
// bool allow_downsample = false;
|
||||
// bool allow_upsample = false;
|
||||
// int32_t max_feature_vectors = -1;
|
||||
|
||||
int32_t WindowShift() const {
|
||||
return static_cast<int32_t>(samp_freq * 0.001f * frame_shift_ms);
|
||||
@ -53,8 +53,29 @@ struct FrameExtractionOptions {
|
||||
return (round_to_power_of_two ? RoundUpToNearestPowerOfTwo(WindowSize())
|
||||
: WindowSize());
|
||||
}
|
||||
// Builds a multi-line description with one "name: value" line per option.
// Each label is spelled exactly like the corresponding member name.
std::string ToString() const {
  std::ostringstream ss;
  ss << "samp_freq" << ": " << samp_freq << "\n"
     << "frame_shift_ms" << ": " << frame_shift_ms << "\n"
     << "frame_length_ms" << ": " << frame_length_ms << "\n"
     << "dither" << ": " << dither << "\n"
     << "preemph_coeff" << ": " << preemph_coeff << "\n"
     << "remove_dc_offset" << ": " << remove_dc_offset << "\n"
     << "window_type" << ": " << window_type << "\n"
     << "round_to_power_of_two" << ": " << round_to_power_of_two << "\n"
     << "blackman_coeff" << ": " << blackman_coeff << "\n"
     << "snip_edges" << ": " << snip_edges << "\n";
  // Not printed (disabled in the original as well):
  //   allow_downsample, allow_upsample, max_feature_vectors
  return ss.str();
}
|
||||
};
|
||||
|
||||
std::ostream &operator<<(std::ostream &os, const FrameExtractionOptions &opts);
|
||||
|
||||
class FeatureWindowFunction {
|
||||
public:
|
||||
FeatureWindowFunction() = default;
|
||||
|
@ -10,6 +10,11 @@
|
||||
|
||||
namespace kaldifeat {
|
||||
|
||||
// Streams the text produced by MelBanksOptions::ToString() into `os` and
// returns `os` so that further insertions can be chained.
std::ostream &operator<<(std::ostream &os, const MelBanksOptions &opts) {
  return os << opts.ToString();
}
|
||||
|
||||
float MelBanks::VtlnWarpFreq(
|
||||
float vtln_low_cutoff, // upper+lower frequency cutoffs for VTLN.
|
||||
float vtln_high_cutoff,
|
||||
|
@ -32,8 +32,22 @@ struct MelBanksOptions {
|
||||
// Enables more exact compatibility with HTK, for testing purposes. Affects
|
||||
// mel-energy flooring and reproduces a bug in HTK.
|
||||
bool htk_mode = false;
|
||||
|
||||
// Builds a multi-line description with one "name: value" line for each
// mel bank option.
std::string ToString() const {
  std::ostringstream ss;
  ss << "num_bins: " << num_bins << "\n"
     << "low_freq: " << low_freq << "\n"
     << "high_freq: " << high_freq << "\n"
     << "vtln_low: " << vtln_low << "\n"
     << "vtln_high: " << vtln_high << "\n"
     << "debug_mel: " << debug_mel << "\n"
     << "htk_mode: " << htk_mode << "\n";
  return ss.str();
}
|
||||
};
|
||||
|
||||
std::ostream &operator<<(std::ostream &os, const MelBanksOptions &opts);
|
||||
|
||||
class MelBanks {
|
||||
public:
|
||||
static inline float InverseMelScale(float mel_freq) {
|
||||
|
@ -62,8 +62,18 @@ static void TestDither() {
|
||||
std::cout << (a + b * 2) << "\n";
|
||||
}
|
||||
|
||||
static void TestCat() {
|
||||
torch::Tensor a = torch::arange(0, 6).reshape({2, 3}).to(torch::kFloat);
|
||||
torch::Tensor b = torch::arange(0, 2).reshape({2, 1}).to(torch::kFloat) * 0.1;
|
||||
torch::Tensor c = torch::cat({a, b}, 1);
|
||||
torch::Tensor d = torch::cat({b, a}, 1);
|
||||
std::cout << a << "\n";
|
||||
std::cout << b << "\n";
|
||||
std::cout << c << "\n";
|
||||
std::cout << d << "\n";
|
||||
}
|
||||
|
||||
int main() {
|
||||
// TestDither();
|
||||
TestGetStrided();
|
||||
TestCat();
|
||||
return 0;
|
||||
}
|
||||
|
@ -1,4 +1,9 @@
|
||||
add_definitions(-DTORCH_API_INCLUDE_EXTENSION_H)
|
||||
pybind11_add_module(_kaldifeat kaldifeat.cc)
|
||||
pybind11_add_module(_kaldifeat
|
||||
feature-fbank.cc
|
||||
feature-window.cc
|
||||
kaldifeat.cc
|
||||
mel-computations.cc
|
||||
)
|
||||
target_link_libraries(_kaldifeat PRIVATE kaldifeat_core)
|
||||
target_link_libraries(_kaldifeat PRIVATE ${TORCH_DIR}/lib/libtorch_python.so)
|
||||
|
27
kaldifeat/python/csrc/feature-fbank.cc
Normal file
27
kaldifeat/python/csrc/feature-fbank.cc
Normal file
@ -0,0 +1,27 @@
|
||||
// kaldifeat/python/csrc/feature-fbank.cc
|
||||
//
|
||||
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
|
||||
#include "kaldifeat/python/csrc/feature-fbank.h"
|
||||
|
||||
#include "kaldifeat/csrc/feature-fbank.h"
|
||||
|
||||
namespace kaldifeat {
|
||||
|
||||
// Registers FbankOptions in the Python module `m` under the name
// "FbankOptions": a default constructor, a read/write attribute for every
// option, and __str__ backed by FbankOptions::ToString().
void PybindFbankOptions(py::module &m) {
  using Opts = FbankOptions;

  py::class_<Opts> cls(m, "FbankOptions");
  cls.def(py::init<>());
  cls.def_readwrite("frame_opts", &Opts::frame_opts);
  cls.def_readwrite("mel_opts", &Opts::mel_opts);
  cls.def_readwrite("use_energy", &Opts::use_energy);
  cls.def_readwrite("energy_floor", &Opts::energy_floor);
  cls.def_readwrite("raw_energy", &Opts::raw_energy);
  cls.def_readwrite("htk_compat", &Opts::htk_compat);
  cls.def_readwrite("use_log_fbank", &Opts::use_log_fbank);
  cls.def_readwrite("use_power", &Opts::use_power);
  cls.def("__str__", &Opts::ToString);
}
|
||||
|
||||
} // namespace kaldifeat
|
16
kaldifeat/python/csrc/feature-fbank.h
Normal file
16
kaldifeat/python/csrc/feature-fbank.h
Normal file
@ -0,0 +1,16 @@
|
||||
// kaldifeat/python/csrc/feature-fbank.h
|
||||
//
|
||||
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
|
||||
#ifndef KALDIFEAT_PYTHON_CSRC_FEATURE_FBANK_H_
|
||||
#define KALDIFEAT_PYTHON_CSRC_FEATURE_FBANK_H_
|
||||
|
||||
#include "kaldifeat/python/csrc/kaldifeat.h"
|
||||
|
||||
namespace kaldifeat {
|
||||
|
||||
void PybindFbankOptions(py::module &m);
|
||||
|
||||
} // namespace kaldifeat
|
||||
|
||||
#endif // KALDIFEAT_PYTHON_CSRC_FEATURE_FBANK_H_
|
39
kaldifeat/python/csrc/feature-window.cc
Normal file
39
kaldifeat/python/csrc/feature-window.cc
Normal file
@ -0,0 +1,39 @@
|
||||
// kaldifeat/python/csrc/feature-window.cc
|
||||
//
|
||||
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
|
||||
#include "kaldifeat/python/csrc/feature-window.h"
|
||||
|
||||
#include "kaldifeat/csrc/feature-window.h"
|
||||
|
||||
namespace kaldifeat {
|
||||
|
||||
// Registers FrameExtractionOptions in the Python module `m` under the name
// "FrameExtractionOptions": a default constructor, a read/write attribute
// for every enabled option, and __str__ backed by
// FrameExtractionOptions::ToString().
void PybindFrameExtractionOptions(py::module &m) {
  using Opts = FrameExtractionOptions;

  py::class_<Opts> cls(m, "FrameExtractionOptions");
  cls.def(py::init<>());
  cls.def_readwrite("samp_freq", &Opts::samp_freq);
  cls.def_readwrite("frame_shift_ms", &Opts::frame_shift_ms);
  cls.def_readwrite("frame_length_ms", &Opts::frame_length_ms);
  cls.def_readwrite("dither", &Opts::dither);
  cls.def_readwrite("preemph_coeff", &Opts::preemph_coeff);
  cls.def_readwrite("remove_dc_offset", &Opts::remove_dc_offset);
  cls.def_readwrite("window_type", &Opts::window_type);
  cls.def_readwrite("round_to_power_of_two", &Opts::round_to_power_of_two);
  cls.def_readwrite("blackman_coeff", &Opts::blackman_coeff);
  cls.def_readwrite("snip_edges", &Opts::snip_edges);
#if 0
  // Disabled bindings, kept for reference.
  cls.def_readwrite("allow_downsample", &Opts::allow_downsample);
  cls.def_readwrite("allow_upsample", &Opts::allow_upsample);
  cls.def_readwrite("max_feature_vectors", &Opts::max_feature_vectors);
#endif
  cls.def("__str__", &Opts::ToString);
}
|
||||
|
||||
} // namespace kaldifeat
|
16
kaldifeat/python/csrc/feature-window.h
Normal file
16
kaldifeat/python/csrc/feature-window.h
Normal file
@ -0,0 +1,16 @@
|
||||
// kaldifeat/python/csrc/feature-window.h
|
||||
//
|
||||
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
|
||||
#ifndef KALDIFEAT_PYTHON_CSRC_FEATURE_WINDOW_H_
|
||||
#define KALDIFEAT_PYTHON_CSRC_FEATURE_WINDOW_H_
|
||||
|
||||
#include "kaldifeat/python/csrc/kaldifeat.h"
|
||||
|
||||
namespace kaldifeat {
|
||||
|
||||
void PybindFrameExtractionOptions(py::module &m);
|
||||
|
||||
} // namespace kaldifeat
|
||||
|
||||
#endif // KALDIFEAT_PYTHON_CSRC_FEATURE_WINDOW_H_
|
@ -7,36 +7,55 @@
|
||||
#include <chrono>
|
||||
|
||||
#include "kaldifeat/csrc/feature-fbank.h"
|
||||
#include "kaldifeat/python/csrc/feature-fbank.h"
|
||||
#include "kaldifeat/python/csrc/feature-window.h"
|
||||
#include "kaldifeat/python/csrc/mel-computations.h"
|
||||
#include "torch/torch.h"
|
||||
|
||||
namespace kaldifeat {
|
||||
|
||||
// Computes fbank features of `wave` with the given options, using a vtln
// warp factor of 1.0. A new Fbank object is constructed on every call.
static torch::Tensor Compute(const torch::Tensor &wave,
                             const FbankOptions &fbank_opts) {
  // TODO(fangjun): wrap Fbank to Python
  constexpr float kVtlnWarp = 1.0f;

  Fbank extractor(fbank_opts);
  return extractor.ComputeFeatures(wave, kVtlnWarp);
}
|
||||
|
||||
PYBIND11_MODULE(_kaldifeat, m) {
|
||||
m.doc() = "Python wrapper for kaldifeat";
|
||||
|
||||
PybindFrameExtractionOptions(m);
|
||||
PybindMelBanksOptions(m);
|
||||
PybindFbankOptions(m);
|
||||
|
||||
m.def("compute", &Compute, py::arg("wave"), py::arg("fbank_opts"));
|
||||
|
||||
// It verifies that the reimplementation produces the same output
|
||||
// as kaldi using default paremters with dither disabled.
|
||||
m.def("test_default_parameters",
|
||||
[](const torch::Tensor &tensor) -> std::pair<torch::Tensor, double> {
|
||||
FbankOptions fbank_opts;
|
||||
fbank_opts.frame_opts.dither = 0.0f;
|
||||
// as kaldi using default parameters with dither disabled.
|
||||
m.def(
|
||||
"_compute_with_elapsed_time", // for benchmark only
|
||||
[](const torch::Tensor &wave,
|
||||
const FbankOptions &fbank_opts) -> std::pair<torch::Tensor, double> {
|
||||
std::chrono::steady_clock::time_point begin =
|
||||
std::chrono::steady_clock::now();
|
||||
|
||||
Fbank fbank(fbank_opts);
|
||||
float vtln_warp = 1.0f;
|
||||
torch::Tensor ans = Compute(wave, fbank_opts);
|
||||
|
||||
std::chrono::steady_clock::time_point begin =
|
||||
std::chrono::steady_clock::now();
|
||||
std::chrono::steady_clock::time_point end =
|
||||
std::chrono::steady_clock::now();
|
||||
|
||||
torch::Tensor ans = fbank.ComputeFeatures(tensor, vtln_warp);
|
||||
std::chrono::steady_clock::time_point end =
|
||||
std::chrono::steady_clock::now();
|
||||
double elapsed_seconds =
|
||||
std::chrono::duration_cast<std::chrono::microseconds>(end - begin)
|
||||
.count() /
|
||||
1000000.;
|
||||
double elapsed_seconds =
|
||||
std::chrono::duration_cast<std::chrono::microseconds>(end - begin)
|
||||
.count() /
|
||||
1000000.;
|
||||
|
||||
return std::make_pair(ans, elapsed_seconds);
|
||||
});
|
||||
return std::make_pair(ans, elapsed_seconds);
|
||||
},
|
||||
py::arg("wave"), py::arg("fbank_opts"));
|
||||
}
|
||||
|
||||
} // namespace kaldifeat
|
||||
|
27
kaldifeat/python/csrc/mel-computations.cc
Normal file
27
kaldifeat/python/csrc/mel-computations.cc
Normal file
@ -0,0 +1,27 @@
|
||||
// kaldifeat/python/csrc/mel-computations.cc
|
||||
//
|
||||
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
|
||||
#include "kaldifeat/csrc/mel-computations.h"
|
||||
|
||||
#include "kaldifeat/python/csrc/feature-window.h"
|
||||
|
||||
namespace kaldifeat {
|
||||
|
||||
// Registers MelBanksOptions in the Python module `m` under the name
// "MelBanksOptions": a default constructor, a read/write attribute for
// every option, and __str__ backed by MelBanksOptions::ToString().
void PybindMelBanksOptions(py::module &m) {
  using Opts = MelBanksOptions;

  py::class_<Opts> cls(m, "MelBanksOptions");
  cls.def(py::init<>());
  cls.def_readwrite("num_bins", &Opts::num_bins);
  cls.def_readwrite("low_freq", &Opts::low_freq);
  cls.def_readwrite("high_freq", &Opts::high_freq);
  cls.def_readwrite("vtln_low", &Opts::vtln_low);
  cls.def_readwrite("vtln_high", &Opts::vtln_high);
  cls.def_readwrite("debug_mel", &Opts::debug_mel);
  cls.def_readwrite("htk_mode", &Opts::htk_mode);
  cls.def("__str__", &Opts::ToString);
}
|
||||
|
||||
} // namespace kaldifeat
|
16
kaldifeat/python/csrc/mel-computations.h
Normal file
16
kaldifeat/python/csrc/mel-computations.h
Normal file
@ -0,0 +1,16 @@
|
||||
// kaldifeat/python/csrc/mel-computations.h
|
||||
//
|
||||
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
|
||||
#ifndef KALDIFEAT_PYTHON_CSRC_MEL_COMPUTATIONS_H_
|
||||
#define KALDIFEAT_PYTHON_CSRC_MEL_COMPUTATIONS_H_
|
||||
|
||||
#include "kaldifeat/python/csrc/kaldifeat.h"
|
||||
|
||||
namespace kaldifeat {
|
||||
|
||||
void PybindMelBanksOptions(py::module &m);
|
||||
|
||||
} // namespace kaldifeat
|
||||
|
||||
#endif // KALDIFEAT_PYTHON_CSRC_MEL_COMPUTATIONS_H_
|
@ -1,49 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
|
||||
from pathlib import Path
|
||||
cur_dir = Path(__file__).resolve().parent
|
||||
kaldi_feat_dir = cur_dir.parent.parent.parent
|
||||
|
||||
import sys
|
||||
sys.path.insert(0, f'{kaldi_feat_dir}/build/lib')
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import torch
|
||||
|
||||
import _kaldifeat
|
||||
|
||||
|
||||
def read_ark_txt() -> torch.Tensor:
|
||||
test_data_dir = cur_dir / 'test_data'
|
||||
filename = test_data_dir / 'abc.txt'
|
||||
features = []
|
||||
with open(filename) as f:
|
||||
for line in f:
|
||||
if '[' in line: continue
|
||||
line = line.strip('').split()
|
||||
data = [float(d) for d in line if d != ']']
|
||||
features.append(data)
|
||||
ans = torch.tensor(features)
|
||||
return ans
|
||||
|
||||
|
||||
def main():
|
||||
test_data_dir = cur_dir / 'test_data'
|
||||
filename = test_data_dir / 'abc.wav'
|
||||
with sf.SoundFile(filename) as sf_desc:
|
||||
sampling_rate = sf_desc.samplerate
|
||||
assert sampling_rate == 16000
|
||||
data = sf_desc.read(dtype=np.float32, always_2d=False)
|
||||
data *= 32768
|
||||
tensor = torch.from_numpy(data)
|
||||
ans, elapsed_seconds = _kaldifeat.test_default_parameters(tensor)
|
||||
expected = read_ark_txt()
|
||||
assert torch.allclose(ans, expected, rtol=1e-3)
|
||||
print('elapsed seconds:', elapsed_seconds)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
121
kaldifeat/python/tests/test_kaldifeat.py
Executable file
121
kaldifeat/python/tests/test_kaldifeat.py
Executable file
@ -0,0 +1,121 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
|
||||
from pathlib import Path
|
||||
cur_dir = Path(__file__).resolve().parent
|
||||
kaldi_feat_dir = cur_dir.parent.parent.parent
|
||||
|
||||
import sys
|
||||
sys.path.insert(0, f'{kaldi_feat_dir}/build/lib')
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import torch
|
||||
|
||||
import _kaldifeat
|
||||
|
||||
|
||||
def read_ark_txt() -> torch.Tensor:
    """Load the reference features stored in ``test_data/abc.txt``.

    Every line containing ``[`` is skipped. On each remaining line the
    whitespace-separated tokens, except a closing ``]``, are parsed as
    floats and form one row of the result.

    Returns:
      A tensor built from the list of parsed rows.
    """
    ark_path = cur_dir / 'test_data' / 'abc.txt'
    with open(ark_path) as f:
        rows = [
            [float(token) for token in line.split() if token != ']']
            for line in f
            if '[' not in line
        ]
    return torch.tensor(rows)
|
||||
|
||||
|
||||
def parse_str(s) -> torch.Tensor:
    '''Convert a block of text into a tensor with one row per line.

    Args:
      s:
        It consists of several lines. Each line contains several numbers
        separated by spaces. Leading and trailing whitespace of the whole
        string and of every line is ignored.
    Returns:
      A tensor whose i-th row holds the numbers of the i-th line.
    '''
    rows = [[float(token) for token in line.split()]
            for line in s.strip().split('\n')]
    return torch.tensor(rows)
|
||||
|
||||
|
||||
def read_wave() -> torch.Tensor:
    """Read ``test_data/abc.wav`` and return its samples as a tensor.

    The file must have a sampling rate of 16000. Samples are read as float32
    and multiplied by 32768 before being wrapped in a tensor
    (NOTE(review): presumably to match the 16-bit integer range used by the
    Kaldi reference output -- confirm).
    """
    wav_path = cur_dir / 'test_data' / 'abc.wav'
    with sf.SoundFile(wav_path) as wav:
        assert wav.samplerate == 16000
        samples = wav.read(dtype=np.float32, always_2d=False)
    samples *= 32768
    return torch.from_numpy(samples)
|
||||
|
||||
|
||||
def test_and_benchmark_default_parameters():
    """Compare default-option features (dither disabled) with the reference.

    The features of the test wave are computed with
    ``_compute_with_elapsed_time`` and must match the content of
    ``test_data/abc.txt`` within ``rtol=1e-3``. The reported elapsed time is
    printed.
    """
    opts = _kaldifeat.FbankOptions()
    opts.frame_opts.dither = 0

    wave = read_wave()
    features, seconds = _kaldifeat._compute_with_elapsed_time(wave, opts)

    reference = read_ark_txt()
    assert torch.allclose(features, reference, rtol=1e-3)
    print('elapsed seconds:', seconds)
|
||||
|
||||
|
||||
def test_use_energy_htk_compat_true():
    """Check the first 3 feature rows with use_energy and htk_compat enabled.

    With htk_compat set, the energy is expected in the last column.
    """
    opts = _kaldifeat.FbankOptions()
    opts.frame_opts.dither = 0
    opts.use_energy = True
    opts.htk_compat = True

    wave = read_wave()
    features = _kaldifeat.compute(wave, opts)

    # Reference produced with:
    # ./compute-fbank-feats --dither=0 --use-energy=1 --htk-compat=1 scp:abc.scp ark,t:abc.txt
    # the first 3 rows are:
    reference_str = '''
15.576 21.93211 25.55334 24.08283 15.93041 12.47176 10.47909 9.024426 7.899537 6.935482 6.21563 6.035741 6.140291 5.94696 6.146772 6.860236 6.702379 7.087324 6.929666 7.66336 7.935287 8.405977 8.309303 25.38995
15.5755 21.93212 25.55334 24.08282 15.93044 12.47107 10.47753 9.026523 7.901362 6.939464 6.189109 5.926141 5.678882 5.553694 6.006057 6.066478 6.500169 7.277717 7.248817 7.699819 7.990362 8.033764 8.220113 25.38996
15.57543 21.93211 25.55334 24.08282 15.93052 12.47129 10.4782 9.028108 7.90429 6.946663 6.310408 5.903729 5.777827 6.027511 6.000434 6.190129 5.968217 6.455313 7.450428 7.993948 8.512851 8.341401 8.14073 25.38995
'''
    reference = parse_str(reference_str)
    assert torch.allclose(features[:3, :], reference, rtol=1e-3)
|
||||
|
||||
|
||||
def test_use_energy_htk_compat_false():
    """Check the first 3 feature rows with use_energy on and htk_compat off.

    With htk_compat unset, the energy is expected in the first column.
    """
    opts = _kaldifeat.FbankOptions()
    opts.frame_opts.dither = 0
    opts.use_energy = True
    opts.htk_compat = False

    wave = read_wave()
    features = _kaldifeat.compute(wave, opts)

    # Reference produced with:
    # ./compute-fbank-feats --dither=0 --use-energy=1 --htk-compat=0 scp:abc.scp ark,t:abc.txt
    # the first 3 rows are:
    reference_str = '''
25.38995 15.576 21.93211 25.55334 24.08283 15.93041 12.47176 10.47909 9.024426 7.899537 6.935482 6.21563 6.035741 6.140291 5.94696 6.146772 6.860236 6.702379 7.087324 6.929666 7.66336 7.935287 8.405977 8.309303
25.38996 15.5755 21.93212 25.55334 24.08282 15.93044 12.47107 10.47753 9.026523 7.901362 6.939464 6.189109 5.926141 5.678882 5.553694 6.006057 6.066478 6.500169 7.277717 7.248817 7.699819 7.990362 8.033764 8.220113
25.38995 15.57543 21.93211 25.55334 24.08282 15.93052 12.47129 10.4782 9.028108 7.90429 6.946663 6.310408 5.903729 5.777827 6.027511 6.000434 6.190129 5.968217 6.455313 7.450428 7.993948 8.512851 8.341401 8.14073
'''
    reference = parse_str(reference_str)
    assert torch.allclose(features[:3, :], reference, rtol=1e-3)
|
||||
|
||||
|
||||
def main():
    """Run all tests of this file one after another."""
    for run_test in (
            test_and_benchmark_default_parameters,
            test_use_energy_htk_compat_true,
            test_use_energy_htk_compat_false,
    ):
        run_test()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
83
kaldifeat/python/tests/test_options.py
Executable file
83
kaldifeat/python/tests/test_options.py
Executable file
@ -0,0 +1,83 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
|
||||
from pathlib import Path
|
||||
cur_dir = Path(__file__).resolve().parent
|
||||
kaldi_feat_dir = cur_dir.parent.parent.parent
|
||||
|
||||
import sys
|
||||
sys.path.insert(0, f'{kaldi_feat_dir}/build/lib')
|
||||
|
||||
import torch
|
||||
import _kaldifeat
|
||||
|
||||
|
||||
def test_frame_extraction_options():
    """Assign every wrapped FrameExtractionOptions attribute, then print it."""
    opts = _kaldifeat.FrameExtractionOptions()
    new_values = {
        'samp_freq': 220500,
        'frame_shift_ms': 15,
        'frame_length_ms': 40,
        'dither': 0.1,
        'preemph_coeff': 0.98,
        'remove_dc_offset': False,
        'window_type': 'hanning',
        'round_to_power_of_two': False,
        'blackman_coeff': 0.422,
        'snip_edges': False,
    }
    for name, value in new_values.items():
        setattr(opts, name, value)
    print(opts)
|
||||
|
||||
|
||||
def test_mel_banks_options():
    """Assign every wrapped MelBanksOptions attribute, then print it."""
    opts = _kaldifeat.MelBanksOptions()
    new_values = {
        'num_bins': 23,
        'low_freq': 21,
        'high_freq': 8000,
        'vtln_low': 101,
        'vtln_high': -501,
        'debug_mel': True,
        'htk_mode': True,
    }
    for name, value in new_values.items():
        setattr(opts, name, value)
    print(opts)
|
||||
|
||||
|
||||
def test_fbank_options():
    """Assign FbankOptions attributes and those of its sub-options, then print.

    The nested ``frame_opts`` and ``mel_opts`` objects are fetched once and
    modified through those handles before the top-level object is printed.
    """
    opts = _kaldifeat.FbankOptions()
    frame_opts = opts.frame_opts
    mel_opts = opts.mel_opts

    top_level_values = {
        'energy_floor': 0,
        'htk_compat': False,
        'raw_energy': True,
        'use_energy': False,
        'use_log_fbank': True,
        'use_power': True,
    }
    frame_values = {
        'blackman_coeff': 0.42,
        'dither': 1,
        'frame_length_ms': 25,
        'frame_shift_ms': 10,
        'preemph_coeff': 0.97,
        'remove_dc_offset': True,
        'round_to_power_of_two': True,
        'samp_freq': 16000,
        'snip_edges': True,
        'window_type': 'povey',
    }
    mel_values = {
        'debug_mel': True,
        'high_freq': 0,
        'low_freq': 20,
        'num_bins': 23,
        'vtln_high': -500,
        'vtln_low': 100,
    }

    for target, values in ((opts, top_level_values),
                           (frame_opts, frame_values),
                           (mel_opts, mel_values)):
        for name, value in values.items():
            setattr(target, name, value)

    print(opts)
|
||||
|
||||
|
||||
def main():
    """Run the option tests; only the FbankOptions test is enabled."""
    enabled_tests = (
        # test_frame_extraction_options,
        # test_mel_banks_options,
        test_fbank_options,
    )
    for run_test in enabled_tests:
        run_test()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Loading…
x
Reference in New Issue
Block a user