wrap FbankOptions to Python.
This commit is contained in:
parent
53504a705c
commit
9a5567e21b
@ -15,15 +15,14 @@ template <class F>
|
|||||||
torch::Tensor OfflineFeatureTpl<F>::ComputeFeatures(const torch::Tensor &wave,
|
torch::Tensor OfflineFeatureTpl<F>::ComputeFeatures(const torch::Tensor &wave,
|
||||||
float vtln_warp) {
|
float vtln_warp) {
|
||||||
KALDIFEAT_ASSERT(wave.dim() == 1);
|
KALDIFEAT_ASSERT(wave.dim() == 1);
|
||||||
int32_t rows_out = NumFrames(wave.sizes()[0], computer_.GetFrameOptions());
|
|
||||||
int32_t cols_out = computer_.Dim();
|
|
||||||
|
|
||||||
const FrameExtractionOptions &frame_opts = computer_.GetFrameOptions();
|
const FrameExtractionOptions &frame_opts = computer_.GetFrameOptions();
|
||||||
|
|
||||||
torch::Tensor strided_input = GetStrided(wave, frame_opts);
|
torch::Tensor strided_input = GetStrided(wave, frame_opts);
|
||||||
|
|
||||||
if (frame_opts.dither != 0)
|
if (frame_opts.dither != 0.0f) {
|
||||||
strided_input = Dither(strided_input, frame_opts.dither);
|
strided_input = Dither(strided_input, frame_opts.dither);
|
||||||
|
}
|
||||||
|
|
||||||
if (frame_opts.remove_dc_offset) {
|
if (frame_opts.remove_dc_offset) {
|
||||||
torch::Tensor row_means = strided_input.mean(1).unsqueeze(1);
|
torch::Tensor row_means = strided_input.mean(1).unsqueeze(1);
|
||||||
@ -37,12 +36,14 @@ torch::Tensor OfflineFeatureTpl<F>::ComputeFeatures(const torch::Tensor &wave,
|
|||||||
constexpr float kEps = 1.1920928955078125e-07f;
|
constexpr float kEps = 1.1920928955078125e-07f;
|
||||||
|
|
||||||
if (use_raw_log_energy) {
|
if (use_raw_log_energy) {
|
||||||
|
// it is true iff use_energy==true and row_energy==true
|
||||||
log_energy_pre_window =
|
log_energy_pre_window =
|
||||||
torch::clamp_min(strided_input.pow(2).sum(1), kEps).log();
|
torch::clamp_min(strided_input.pow(2).sum(1), kEps).log();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (frame_opts.preemph_coeff != 0.0f)
|
if (frame_opts.preemph_coeff != 0.0f) {
|
||||||
Preemphasize(frame_opts.preemph_coeff, &strided_input);
|
Preemphasize(frame_opts.preemph_coeff, &strided_input);
|
||||||
|
}
|
||||||
|
|
||||||
feature_window_function_.Apply(&strided_input);
|
feature_window_function_.Apply(&strided_input);
|
||||||
|
|
||||||
|
|||||||
@ -13,6 +13,11 @@
|
|||||||
|
|
||||||
namespace kaldifeat {
|
namespace kaldifeat {
|
||||||
|
|
||||||
|
std::ostream &operator<<(std::ostream &os, const FbankOptions &opts) {
|
||||||
|
os << opts.ToString();
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
FbankComputer::FbankComputer(const FbankOptions &opts) : opts_(opts) {
|
FbankComputer::FbankComputer(const FbankOptions &opts) : opts_(opts) {
|
||||||
if (opts.energy_floor > 0.0f) log_energy_floor_ = logf(opts.energy_floor);
|
if (opts.energy_floor > 0.0f) log_energy_floor_ = logf(opts.energy_floor);
|
||||||
|
|
||||||
@ -78,11 +83,6 @@ torch::Tensor FbankComputer::Compute(torch::Tensor signal_raw_log_energy,
|
|||||||
// Use power instead of magnitude if requested.
|
// Use power instead of magnitude if requested.
|
||||||
if (opts_.use_power) spectrum.pow_(2);
|
if (opts_.use_power) spectrum.pow_(2);
|
||||||
|
|
||||||
#if 0
|
|
||||||
int32_t mel_offset = ((opts_.use_energy && !opts_.htk_compat) ? 1 : 0);
|
|
||||||
SubVector<float> mel_energies(*feature, mel_offset, opts_.mel_opts.num_bins);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// TODO(fangjun): remove the last column of spectrum
|
// TODO(fangjun): remove the last column of spectrum
|
||||||
torch::Tensor mel_energies = mel_banks.Compute(spectrum);
|
torch::Tensor mel_energies = mel_banks.Compute(spectrum);
|
||||||
if (opts_.use_log_fbank) {
|
if (opts_.use_log_fbank) {
|
||||||
@ -90,17 +90,26 @@ torch::Tensor FbankComputer::Compute(torch::Tensor signal_raw_log_energy,
|
|||||||
mel_energies = torch::clamp_min(mel_energies, kEps).log();
|
mel_energies = torch::clamp_min(mel_energies, kEps).log();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// if use_energy is true, then we get an extra bin. That is,
|
||||||
|
// if num_mel_bins is 23, the feature will contain 24 bins.
|
||||||
|
//
|
||||||
|
// if htk_compat is false, then the 0th bin is the log energy
|
||||||
|
// if htk_compat is true, then the last bin is the log energy
|
||||||
|
|
||||||
// Copy energy as first value (or the last, if htk_compat == true).
|
// Copy energy as first value (or the last, if htk_compat == true).
|
||||||
if (opts_.use_energy) {
|
if (opts_.use_energy) {
|
||||||
#if 0
|
if (opts_.energy_floor > 0.0f) {
|
||||||
if (opts_.energy_floor > 0.0 && signal_raw_log_energy < log_energy_floor_) {
|
signal_raw_log_energy =
|
||||||
signal_raw_log_energy = log_energy_floor_;
|
torch::clamp_min(signal_raw_log_energy, log_energy_floor_);
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
int32_t energy_index = opts_.htk_compat ? opts_.mel_opts.num_bins : 0;
|
|
||||||
energy_index = 0; // TODO(fangjun): fix it
|
|
||||||
|
|
||||||
mel_energies.index({"...", energy_index}) = signal_raw_log_energy;
|
signal_raw_log_energy.unsqueeze_(1);
|
||||||
|
|
||||||
|
if (opts_.htk_compat) {
|
||||||
|
mel_energies = torch::cat({mel_energies, signal_raw_log_energy}, 1);
|
||||||
|
} else {
|
||||||
|
mel_energies = torch::cat({signal_raw_log_energy, mel_energies}, 1);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return mel_energies;
|
return mel_energies;
|
||||||
|
|||||||
@ -20,12 +20,16 @@ struct FbankOptions {
|
|||||||
MelBanksOptions mel_opts;
|
MelBanksOptions mel_opts;
|
||||||
// append an extra dimension with energy to the filter banks
|
// append an extra dimension with energy to the filter banks
|
||||||
bool use_energy = false;
|
bool use_energy = false;
|
||||||
float energy_floor = 0.0f;
|
float energy_floor = 0.0f; // active iff use_energy==true
|
||||||
|
|
||||||
|
// If true, compute log_energy before preemphasis and windowing
|
||||||
|
// If false, compute log_energy after preemphasis ans windowing
|
||||||
|
bool raw_energy = true; // active iff use_energy==true
|
||||||
|
|
||||||
// If true, compute energy before preemphasis and windowing
|
|
||||||
bool raw_energy = true;
|
|
||||||
// If true, put energy last (if using energy)
|
// If true, put energy last (if using energy)
|
||||||
bool htk_compat = false;
|
// If false, put energy first
|
||||||
|
bool htk_compat = false; // active iff use_energy==true
|
||||||
|
|
||||||
// if true (default), produce log-filterbank, else linear
|
// if true (default), produce log-filterbank, else linear
|
||||||
bool use_log_fbank = true;
|
bool use_log_fbank = true;
|
||||||
|
|
||||||
@ -34,8 +38,28 @@ struct FbankOptions {
|
|||||||
bool use_power = true;
|
bool use_power = true;
|
||||||
|
|
||||||
FbankOptions() { mel_opts.num_bins = 23; }
|
FbankOptions() { mel_opts.num_bins = 23; }
|
||||||
|
|
||||||
|
std::string ToString() const {
|
||||||
|
std::ostringstream os;
|
||||||
|
os << "frame_opts: \n";
|
||||||
|
os << frame_opts << "\n";
|
||||||
|
os << "\n";
|
||||||
|
|
||||||
|
os << "mel_opts: \n";
|
||||||
|
os << mel_opts << "\n";
|
||||||
|
|
||||||
|
os << "use_energy: " << use_energy << "\n";
|
||||||
|
os << "energy_floor: " << energy_floor << "\n";
|
||||||
|
os << "raw_energy: " << raw_energy << "\n";
|
||||||
|
os << "htk_compat: " << htk_compat << "\n";
|
||||||
|
os << "use_log_fbank: " << use_log_fbank << "\n";
|
||||||
|
os << "use_power: " << use_power << "\n";
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
std::ostream &operator<<(std::ostream &os, const FbankOptions &opts);
|
||||||
|
|
||||||
class FbankComputer {
|
class FbankComputer {
|
||||||
public:
|
public:
|
||||||
using Options = FbankOptions;
|
using Options = FbankOptions;
|
||||||
@ -51,12 +75,15 @@ class FbankComputer {
|
|||||||
return opts_.mel_opts.num_bins + (opts_.use_energy ? 1 : 0);
|
return opts_.mel_opts.num_bins + (opts_.use_energy ? 1 : 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// if true, compute log_energy_pre_window but after dithering and dc removal
|
||||||
bool NeedRawLogEnergy() const { return opts_.use_energy && opts_.raw_energy; }
|
bool NeedRawLogEnergy() const { return opts_.use_energy && opts_.raw_energy; }
|
||||||
|
|
||||||
const FrameExtractionOptions &GetFrameOptions() const {
|
const FrameExtractionOptions &GetFrameOptions() const {
|
||||||
return opts_.frame_opts;
|
return opts_.frame_opts;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// signal_raw_log_energy is log_energy_pre_window, which is not empty
|
||||||
|
// iff NeedRawLogEnergy() returns true.
|
||||||
torch::Tensor Compute(torch::Tensor signal_raw_log_energy, float vtln_warp,
|
torch::Tensor Compute(torch::Tensor signal_raw_log_energy, float vtln_warp,
|
||||||
const torch::Tensor &signal_frame);
|
const torch::Tensor &signal_frame);
|
||||||
|
|
||||||
|
|||||||
@ -16,6 +16,11 @@
|
|||||||
|
|
||||||
namespace kaldifeat {
|
namespace kaldifeat {
|
||||||
|
|
||||||
|
std::ostream &operator<<(std::ostream &os, const FrameExtractionOptions &opts) {
|
||||||
|
os << opts.ToString();
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
FeatureWindowFunction::FeatureWindowFunction(
|
FeatureWindowFunction::FeatureWindowFunction(
|
||||||
const FrameExtractionOptions &opts) {
|
const FrameExtractionOptions &opts) {
|
||||||
int32_t frame_length = opts.WindowSize();
|
int32_t frame_length = opts.WindowSize();
|
||||||
|
|||||||
@ -39,9 +39,9 @@ struct FrameExtractionOptions {
|
|||||||
bool round_to_power_of_two = true;
|
bool round_to_power_of_two = true;
|
||||||
float blackman_coeff = 0.42f;
|
float blackman_coeff = 0.42f;
|
||||||
bool snip_edges = true;
|
bool snip_edges = true;
|
||||||
bool allow_downsample = false;
|
// bool allow_downsample = false;
|
||||||
bool allow_upsample = false;
|
// bool allow_upsample = false;
|
||||||
int32_t max_feature_vectors = -1;
|
// int32_t max_feature_vectors = -1;
|
||||||
|
|
||||||
int32_t WindowShift() const {
|
int32_t WindowShift() const {
|
||||||
return static_cast<int32_t>(samp_freq * 0.001f * frame_shift_ms);
|
return static_cast<int32_t>(samp_freq * 0.001f * frame_shift_ms);
|
||||||
@ -53,8 +53,29 @@ struct FrameExtractionOptions {
|
|||||||
return (round_to_power_of_two ? RoundUpToNearestPowerOfTwo(WindowSize())
|
return (round_to_power_of_two ? RoundUpToNearestPowerOfTwo(WindowSize())
|
||||||
: WindowSize());
|
: WindowSize());
|
||||||
}
|
}
|
||||||
|
std::string ToString() const {
|
||||||
|
std::ostringstream os;
|
||||||
|
#define KALDIFEAT_PRINT(x) os << #x << ": " << x << "\n"
|
||||||
|
KALDIFEAT_PRINT(samp_freq);
|
||||||
|
KALDIFEAT_PRINT(frame_shift_ms);
|
||||||
|
KALDIFEAT_PRINT(frame_length_ms);
|
||||||
|
KALDIFEAT_PRINT(dither);
|
||||||
|
KALDIFEAT_PRINT(preemph_coeff);
|
||||||
|
KALDIFEAT_PRINT(remove_dc_offset);
|
||||||
|
KALDIFEAT_PRINT(window_type);
|
||||||
|
KALDIFEAT_PRINT(round_to_power_of_two);
|
||||||
|
KALDIFEAT_PRINT(blackman_coeff);
|
||||||
|
KALDIFEAT_PRINT(snip_edges);
|
||||||
|
// KALDIFEAT_PRINT(allow_downsample);
|
||||||
|
// KALDIFEAT_PRINT(allow_upsample);
|
||||||
|
// KALDIFEAT_PRINT(max_feature_vectors);
|
||||||
|
#undef KALDIFEAT_PRINT
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
std::ostream &operator<<(std::ostream &os, const FrameExtractionOptions &opts);
|
||||||
|
|
||||||
class FeatureWindowFunction {
|
class FeatureWindowFunction {
|
||||||
public:
|
public:
|
||||||
FeatureWindowFunction() = default;
|
FeatureWindowFunction() = default;
|
||||||
|
|||||||
@ -10,6 +10,11 @@
|
|||||||
|
|
||||||
namespace kaldifeat {
|
namespace kaldifeat {
|
||||||
|
|
||||||
|
std::ostream &operator<<(std::ostream &os, const MelBanksOptions &opts) {
|
||||||
|
os << opts.ToString();
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
float MelBanks::VtlnWarpFreq(
|
float MelBanks::VtlnWarpFreq(
|
||||||
float vtln_low_cutoff, // upper+lower frequency cutoffs for VTLN.
|
float vtln_low_cutoff, // upper+lower frequency cutoffs for VTLN.
|
||||||
float vtln_high_cutoff,
|
float vtln_high_cutoff,
|
||||||
|
|||||||
@ -32,8 +32,22 @@ struct MelBanksOptions {
|
|||||||
// Enables more exact compatibility with HTK, for testing purposes. Affects
|
// Enables more exact compatibility with HTK, for testing purposes. Affects
|
||||||
// mel-energy flooring and reproduces a bug in HTK.
|
// mel-energy flooring and reproduces a bug in HTK.
|
||||||
bool htk_mode = false;
|
bool htk_mode = false;
|
||||||
|
|
||||||
|
std::string ToString() const {
|
||||||
|
std::ostringstream os;
|
||||||
|
os << "num_bins: " << num_bins << "\n";
|
||||||
|
os << "low_freq: " << low_freq << "\n";
|
||||||
|
os << "high_freq: " << high_freq << "\n";
|
||||||
|
os << "vtln_low: " << vtln_low << "\n";
|
||||||
|
os << "vtln_high: " << vtln_high << "\n";
|
||||||
|
os << "debug_mel: " << debug_mel << "\n";
|
||||||
|
os << "htk_mode: " << htk_mode << "\n";
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
std::ostream &operator<<(std::ostream &os, const MelBanksOptions &opts);
|
||||||
|
|
||||||
class MelBanks {
|
class MelBanks {
|
||||||
public:
|
public:
|
||||||
static inline float InverseMelScale(float mel_freq) {
|
static inline float InverseMelScale(float mel_freq) {
|
||||||
|
|||||||
@ -62,8 +62,18 @@ static void TestDither() {
|
|||||||
std::cout << (a + b * 2) << "\n";
|
std::cout << (a + b * 2) << "\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void TestCat() {
|
||||||
|
torch::Tensor a = torch::arange(0, 6).reshape({2, 3}).to(torch::kFloat);
|
||||||
|
torch::Tensor b = torch::arange(0, 2).reshape({2, 1}).to(torch::kFloat) * 0.1;
|
||||||
|
torch::Tensor c = torch::cat({a, b}, 1);
|
||||||
|
torch::Tensor d = torch::cat({b, a}, 1);
|
||||||
|
std::cout << a << "\n";
|
||||||
|
std::cout << b << "\n";
|
||||||
|
std::cout << c << "\n";
|
||||||
|
std::cout << d << "\n";
|
||||||
|
}
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
// TestDither();
|
TestCat();
|
||||||
TestGetStrided();
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,4 +1,9 @@
|
|||||||
add_definitions(-DTORCH_API_INCLUDE_EXTENSION_H)
|
add_definitions(-DTORCH_API_INCLUDE_EXTENSION_H)
|
||||||
pybind11_add_module(_kaldifeat kaldifeat.cc)
|
pybind11_add_module(_kaldifeat
|
||||||
|
feature-fbank.cc
|
||||||
|
feature-window.cc
|
||||||
|
kaldifeat.cc
|
||||||
|
mel-computations.cc
|
||||||
|
)
|
||||||
target_link_libraries(_kaldifeat PRIVATE kaldifeat_core)
|
target_link_libraries(_kaldifeat PRIVATE kaldifeat_core)
|
||||||
target_link_libraries(_kaldifeat PRIVATE ${TORCH_DIR}/lib/libtorch_python.so)
|
target_link_libraries(_kaldifeat PRIVATE ${TORCH_DIR}/lib/libtorch_python.so)
|
||||||
|
|||||||
27
kaldifeat/python/csrc/feature-fbank.cc
Normal file
27
kaldifeat/python/csrc/feature-fbank.cc
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
// kaldifeat/python/csrc/feature-fbank.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
#include "kaldifeat/python/csrc/feature-fbank.h"
|
||||||
|
|
||||||
|
#include "kaldifeat/csrc/feature-fbank.h"
|
||||||
|
|
||||||
|
namespace kaldifeat {
|
||||||
|
|
||||||
|
void PybindFbankOptions(py::module &m) {
|
||||||
|
py::class_<FbankOptions>(m, "FbankOptions")
|
||||||
|
.def(py::init<>())
|
||||||
|
.def_readwrite("frame_opts", &FbankOptions::frame_opts)
|
||||||
|
.def_readwrite("mel_opts", &FbankOptions::mel_opts)
|
||||||
|
.def_readwrite("use_energy", &FbankOptions::use_energy)
|
||||||
|
.def_readwrite("energy_floor", &FbankOptions::energy_floor)
|
||||||
|
.def_readwrite("raw_energy", &FbankOptions::raw_energy)
|
||||||
|
.def_readwrite("htk_compat", &FbankOptions::htk_compat)
|
||||||
|
.def_readwrite("use_log_fbank", &FbankOptions::use_log_fbank)
|
||||||
|
.def_readwrite("use_power", &FbankOptions::use_power)
|
||||||
|
.def("__str__", [](const FbankOptions &self) -> std::string {
|
||||||
|
return self.ToString();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace kaldifeat
|
||||||
16
kaldifeat/python/csrc/feature-fbank.h
Normal file
16
kaldifeat/python/csrc/feature-fbank.h
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
// kaldifeat/python/csrc/feature-fbank.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
#ifndef KALDIFEAT_PYTHON_CSRC_FEATURE_FBANK_H_
|
||||||
|
#define KALDIFEAT_PYTHON_CSRC_FEATURE_FBANK_H_
|
||||||
|
|
||||||
|
#include "kaldifeat/python/csrc/kaldifeat.h"
|
||||||
|
|
||||||
|
namespace kaldifeat {
|
||||||
|
|
||||||
|
void PybindFbankOptions(py::module &m);
|
||||||
|
|
||||||
|
} // namespace kaldifeat
|
||||||
|
|
||||||
|
#endif // KALDIFEAT_PYTHON_CSRC_FEATURE_FBANK_H_
|
||||||
39
kaldifeat/python/csrc/feature-window.cc
Normal file
39
kaldifeat/python/csrc/feature-window.cc
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
// kaldifeat/python/csrc/feature-window.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
#include "kaldifeat/python/csrc/feature-window.h"
|
||||||
|
|
||||||
|
#include "kaldifeat/csrc/feature-window.h"
|
||||||
|
|
||||||
|
namespace kaldifeat {
|
||||||
|
|
||||||
|
void PybindFrameExtractionOptions(py::module &m) {
|
||||||
|
py::class_<FrameExtractionOptions>(m, "FrameExtractionOptions")
|
||||||
|
.def(py::init<>())
|
||||||
|
.def_readwrite("samp_freq", &FrameExtractionOptions::samp_freq)
|
||||||
|
.def_readwrite("frame_shift_ms", &FrameExtractionOptions::frame_shift_ms)
|
||||||
|
.def_readwrite("frame_length_ms",
|
||||||
|
&FrameExtractionOptions::frame_length_ms)
|
||||||
|
.def_readwrite("dither", &FrameExtractionOptions::dither)
|
||||||
|
.def_readwrite("preemph_coeff", &FrameExtractionOptions::preemph_coeff)
|
||||||
|
.def_readwrite("remove_dc_offset",
|
||||||
|
&FrameExtractionOptions::remove_dc_offset)
|
||||||
|
.def_readwrite("window_type", &FrameExtractionOptions::window_type)
|
||||||
|
.def_readwrite("round_to_power_of_two",
|
||||||
|
&FrameExtractionOptions::round_to_power_of_two)
|
||||||
|
.def_readwrite("blackman_coeff", &FrameExtractionOptions::blackman_coeff)
|
||||||
|
.def_readwrite("snip_edges", &FrameExtractionOptions::snip_edges)
|
||||||
|
#if 0
|
||||||
|
.def_readwrite("allow_downsample",
|
||||||
|
&FrameExtractionOptions::allow_downsample)
|
||||||
|
.def_readwrite("allow_upsample", &FrameExtractionOptions::allow_upsample)
|
||||||
|
.def_readwrite("max_feature_vectors",
|
||||||
|
&FrameExtractionOptions::max_feature_vectors)
|
||||||
|
#endif
|
||||||
|
.def("__str__", [](const FrameExtractionOptions &self) -> std::string {
|
||||||
|
return self.ToString();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace kaldifeat
|
||||||
16
kaldifeat/python/csrc/feature-window.h
Normal file
16
kaldifeat/python/csrc/feature-window.h
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
// kaldifeat/python/csrc/feature-window.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
#ifndef KALDIFEAT_PYTHON_CSRC_FEATURE_WINDOW_H_
|
||||||
|
#define KALDIFEAT_PYTHON_CSRC_FEATURE_WINDOW_H_
|
||||||
|
|
||||||
|
#include "kaldifeat/python/csrc/kaldifeat.h"
|
||||||
|
|
||||||
|
namespace kaldifeat {
|
||||||
|
|
||||||
|
void PybindFrameExtractionOptions(py::module &m);
|
||||||
|
|
||||||
|
} // namespace kaldifeat
|
||||||
|
|
||||||
|
#endif // KALDIFEAT_PYTHON_CSRC_FEATURE_WINDOW_H_
|
||||||
@ -7,36 +7,55 @@
|
|||||||
#include <chrono>
|
#include <chrono>
|
||||||
|
|
||||||
#include "kaldifeat/csrc/feature-fbank.h"
|
#include "kaldifeat/csrc/feature-fbank.h"
|
||||||
|
#include "kaldifeat/python/csrc/feature-fbank.h"
|
||||||
|
#include "kaldifeat/python/csrc/feature-window.h"
|
||||||
|
#include "kaldifeat/python/csrc/mel-computations.h"
|
||||||
#include "torch/torch.h"
|
#include "torch/torch.h"
|
||||||
|
|
||||||
namespace kaldifeat {
|
namespace kaldifeat {
|
||||||
|
|
||||||
|
static torch::Tensor Compute(const torch::Tensor &wave,
|
||||||
|
const FbankOptions &fbank_opts) {
|
||||||
|
// TODO(fangjun): wrap Fbank to Python
|
||||||
|
|
||||||
|
Fbank fbank(fbank_opts);
|
||||||
|
float vtln_warp = 1.0f;
|
||||||
|
|
||||||
|
torch::Tensor ans = fbank.ComputeFeatures(wave, vtln_warp);
|
||||||
|
return ans;
|
||||||
|
}
|
||||||
|
|
||||||
PYBIND11_MODULE(_kaldifeat, m) {
|
PYBIND11_MODULE(_kaldifeat, m) {
|
||||||
m.doc() = "Python wrapper for kaldifeat";
|
m.doc() = "Python wrapper for kaldifeat";
|
||||||
|
|
||||||
|
PybindFrameExtractionOptions(m);
|
||||||
|
PybindMelBanksOptions(m);
|
||||||
|
PybindFbankOptions(m);
|
||||||
|
|
||||||
|
m.def("compute", &Compute, py::arg("wave"), py::arg("fbank_opts"));
|
||||||
|
|
||||||
// It verifies that the reimplementation produces the same output
|
// It verifies that the reimplementation produces the same output
|
||||||
// as kaldi using default paremters with dither disabled.
|
// as kaldi using default parameters with dither disabled.
|
||||||
m.def("test_default_parameters",
|
m.def(
|
||||||
[](const torch::Tensor &tensor) -> std::pair<torch::Tensor, double> {
|
"_compute_with_elapsed_time", // for benchmark only
|
||||||
FbankOptions fbank_opts;
|
[](const torch::Tensor &wave,
|
||||||
fbank_opts.frame_opts.dither = 0.0f;
|
const FbankOptions &fbank_opts) -> std::pair<torch::Tensor, double> {
|
||||||
|
std::chrono::steady_clock::time_point begin =
|
||||||
|
std::chrono::steady_clock::now();
|
||||||
|
|
||||||
Fbank fbank(fbank_opts);
|
torch::Tensor ans = Compute(wave, fbank_opts);
|
||||||
float vtln_warp = 1.0f;
|
|
||||||
|
|
||||||
std::chrono::steady_clock::time_point begin =
|
std::chrono::steady_clock::time_point end =
|
||||||
std::chrono::steady_clock::now();
|
std::chrono::steady_clock::now();
|
||||||
|
|
||||||
torch::Tensor ans = fbank.ComputeFeatures(tensor, vtln_warp);
|
double elapsed_seconds =
|
||||||
std::chrono::steady_clock::time_point end =
|
std::chrono::duration_cast<std::chrono::microseconds>(end - begin)
|
||||||
std::chrono::steady_clock::now();
|
.count() /
|
||||||
double elapsed_seconds =
|
1000000.;
|
||||||
std::chrono::duration_cast<std::chrono::microseconds>(end - begin)
|
|
||||||
.count() /
|
|
||||||
1000000.;
|
|
||||||
|
|
||||||
return std::make_pair(ans, elapsed_seconds);
|
return std::make_pair(ans, elapsed_seconds);
|
||||||
});
|
},
|
||||||
|
py::arg("wave"), py::arg("fbank_opts"));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace kaldifeat
|
} // namespace kaldifeat
|
||||||
|
|||||||
27
kaldifeat/python/csrc/mel-computations.cc
Normal file
27
kaldifeat/python/csrc/mel-computations.cc
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
// kaldifeat/python/csrc/mel-computations.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
#include "kaldifeat/csrc/mel-computations.h"
|
||||||
|
|
||||||
|
#include "kaldifeat/python/csrc/feature-window.h"
|
||||||
|
|
||||||
|
namespace kaldifeat {
|
||||||
|
|
||||||
|
void PybindMelBanksOptions(py::module &m) {
|
||||||
|
py::class_<MelBanksOptions>(m, "MelBanksOptions")
|
||||||
|
.def(py::init<>())
|
||||||
|
.def_readwrite("num_bins", &MelBanksOptions::num_bins)
|
||||||
|
.def_readwrite("low_freq", &MelBanksOptions::low_freq)
|
||||||
|
.def_readwrite("high_freq", &MelBanksOptions::high_freq)
|
||||||
|
.def_readwrite("vtln_low", &MelBanksOptions::vtln_low)
|
||||||
|
.def_readwrite("vtln_high", &MelBanksOptions::vtln_high)
|
||||||
|
.def_readwrite("debug_mel", &MelBanksOptions::debug_mel)
|
||||||
|
.def_readwrite("htk_mode", &MelBanksOptions::htk_mode)
|
||||||
|
.def("__str__", [](const MelBanksOptions &self) -> std::string {
|
||||||
|
return self.ToString();
|
||||||
|
});
|
||||||
|
;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace kaldifeat
|
||||||
16
kaldifeat/python/csrc/mel-computations.h
Normal file
16
kaldifeat/python/csrc/mel-computations.h
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
// kaldifeat/python/csrc/mel-computations.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
#ifndef KALDIFEAT_PYTHON_CSRC_MEL_COMPUTATIONS_H_
|
||||||
|
#define KALDIFEAT_PYTHON_CSRC_MEL_COMPUTATIONS_H_
|
||||||
|
|
||||||
|
#include "kaldifeat/python/csrc/kaldifeat.h"
|
||||||
|
|
||||||
|
namespace kaldifeat {
|
||||||
|
|
||||||
|
void PybindMelBanksOptions(py::module &m);
|
||||||
|
|
||||||
|
} // namespace kaldifeat
|
||||||
|
|
||||||
|
#endif // KALDIFEAT_PYTHON_CSRC_MEL_COMPUTATIONS_H_
|
||||||
@ -1,49 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
#
|
|
||||||
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
cur_dir = Path(__file__).resolve().parent
|
|
||||||
kaldi_feat_dir = cur_dir.parent.parent.parent
|
|
||||||
|
|
||||||
import sys
|
|
||||||
sys.path.insert(0, f'{kaldi_feat_dir}/build/lib')
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import soundfile as sf
|
|
||||||
import torch
|
|
||||||
|
|
||||||
import _kaldifeat
|
|
||||||
|
|
||||||
|
|
||||||
def read_ark_txt() -> torch.Tensor:
|
|
||||||
test_data_dir = cur_dir / 'test_data'
|
|
||||||
filename = test_data_dir / 'abc.txt'
|
|
||||||
features = []
|
|
||||||
with open(filename) as f:
|
|
||||||
for line in f:
|
|
||||||
if '[' in line: continue
|
|
||||||
line = line.strip('').split()
|
|
||||||
data = [float(d) for d in line if d != ']']
|
|
||||||
features.append(data)
|
|
||||||
ans = torch.tensor(features)
|
|
||||||
return ans
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
test_data_dir = cur_dir / 'test_data'
|
|
||||||
filename = test_data_dir / 'abc.wav'
|
|
||||||
with sf.SoundFile(filename) as sf_desc:
|
|
||||||
sampling_rate = sf_desc.samplerate
|
|
||||||
assert sampling_rate == 16000
|
|
||||||
data = sf_desc.read(dtype=np.float32, always_2d=False)
|
|
||||||
data *= 32768
|
|
||||||
tensor = torch.from_numpy(data)
|
|
||||||
ans, elapsed_seconds = _kaldifeat.test_default_parameters(tensor)
|
|
||||||
expected = read_ark_txt()
|
|
||||||
assert torch.allclose(ans, expected, rtol=1e-3)
|
|
||||||
print('elapsed seconds:', elapsed_seconds)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
||||||
121
kaldifeat/python/tests/test_kaldifeat.py
Executable file
121
kaldifeat/python/tests/test_kaldifeat.py
Executable file
@ -0,0 +1,121 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
#
|
||||||
|
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
cur_dir = Path(__file__).resolve().parent
|
||||||
|
kaldi_feat_dir = cur_dir.parent.parent.parent
|
||||||
|
|
||||||
|
import sys
|
||||||
|
sys.path.insert(0, f'{kaldi_feat_dir}/build/lib')
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import soundfile as sf
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import _kaldifeat
|
||||||
|
|
||||||
|
|
||||||
|
def read_ark_txt() -> torch.Tensor:
|
||||||
|
test_data_dir = cur_dir / 'test_data'
|
||||||
|
filename = test_data_dir / 'abc.txt'
|
||||||
|
features = []
|
||||||
|
with open(filename) as f:
|
||||||
|
for line in f:
|
||||||
|
if '[' in line: continue
|
||||||
|
line = line.strip('').split()
|
||||||
|
data = [float(d) for d in line if d != ']']
|
||||||
|
features.append(data)
|
||||||
|
ans = torch.tensor(features)
|
||||||
|
return ans
|
||||||
|
|
||||||
|
|
||||||
|
def parse_str(s) -> torch.Tensor:
|
||||||
|
'''
|
||||||
|
Args:
|
||||||
|
s:
|
||||||
|
It consists of several lines. Each line contains several numbers
|
||||||
|
separated by spaces.
|
||||||
|
'''
|
||||||
|
ans = []
|
||||||
|
for line in s.strip().split('\n'):
|
||||||
|
data = [float(d) for d in line.strip().split()]
|
||||||
|
ans.append(data)
|
||||||
|
return torch.tensor(ans)
|
||||||
|
|
||||||
|
|
||||||
|
def read_wave() -> torch.Tensor:
|
||||||
|
test_data_dir = cur_dir / 'test_data'
|
||||||
|
filename = test_data_dir / 'abc.wav'
|
||||||
|
with sf.SoundFile(filename) as sf_desc:
|
||||||
|
sampling_rate = sf_desc.samplerate
|
||||||
|
assert sampling_rate == 16000
|
||||||
|
data = sf_desc.read(dtype=np.float32, always_2d=False)
|
||||||
|
data *= 32768
|
||||||
|
return torch.from_numpy(data)
|
||||||
|
|
||||||
|
|
||||||
|
def test_and_benchmark_default_parameters():
|
||||||
|
fbank_opts = _kaldifeat.FbankOptions()
|
||||||
|
fbank_opts.frame_opts.dither = 0
|
||||||
|
|
||||||
|
data = read_wave()
|
||||||
|
|
||||||
|
ans, elapsed_seconds = _kaldifeat._compute_with_elapsed_time(
|
||||||
|
data, fbank_opts)
|
||||||
|
|
||||||
|
expected = read_ark_txt()
|
||||||
|
assert torch.allclose(ans, expected, rtol=1e-3)
|
||||||
|
print('elapsed seconds:', elapsed_seconds)
|
||||||
|
|
||||||
|
|
||||||
|
def test_use_energy_htk_compat_true():
|
||||||
|
fbank_opts = _kaldifeat.FbankOptions()
|
||||||
|
fbank_opts.frame_opts.dither = 0
|
||||||
|
fbank_opts.use_energy = True
|
||||||
|
fbank_opts.htk_compat = True
|
||||||
|
|
||||||
|
data = read_wave()
|
||||||
|
|
||||||
|
ans = _kaldifeat.compute(data, fbank_opts)
|
||||||
|
|
||||||
|
# ./compute-fbank-feats --dither=0 --use-energy=1 --htk-compat=1 scp:abc.scp ark,t:abc.txt
|
||||||
|
# the first 3 rows are:
|
||||||
|
expected_str = '''
|
||||||
|
15.576 21.93211 25.55334 24.08283 15.93041 12.47176 10.47909 9.024426 7.899537 6.935482 6.21563 6.035741 6.140291 5.94696 6.146772 6.860236 6.702379 7.087324 6.929666 7.66336 7.935287 8.405977 8.309303 25.38995
|
||||||
|
15.5755 21.93212 25.55334 24.08282 15.93044 12.47107 10.47753 9.026523 7.901362 6.939464 6.189109 5.926141 5.678882 5.553694 6.006057 6.066478 6.500169 7.277717 7.248817 7.699819 7.990362 8.033764 8.220113 25.38996
|
||||||
|
15.57543 21.93211 25.55334 24.08282 15.93052 12.47129 10.4782 9.028108 7.90429 6.946663 6.310408 5.903729 5.777827 6.027511 6.000434 6.190129 5.968217 6.455313 7.450428 7.993948 8.512851 8.341401 8.14073 25.38995
|
||||||
|
'''
|
||||||
|
expected = parse_str(expected_str)
|
||||||
|
assert torch.allclose(ans[:3, :], expected, rtol=1e-3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_use_energy_htk_compat_false():
|
||||||
|
fbank_opts = _kaldifeat.FbankOptions()
|
||||||
|
fbank_opts.frame_opts.dither = 0
|
||||||
|
fbank_opts.use_energy = True
|
||||||
|
fbank_opts.htk_compat = False
|
||||||
|
|
||||||
|
data = read_wave()
|
||||||
|
|
||||||
|
ans = _kaldifeat.compute(data, fbank_opts)
|
||||||
|
|
||||||
|
# ./compute-fbank-feats --dither=0 --use-energy=1 --htk-compat=0 scp:abc.scp ark,t:abc.txt
|
||||||
|
# the first 3 rows are:
|
||||||
|
expected_str = '''
|
||||||
|
25.38995 15.576 21.93211 25.55334 24.08283 15.93041 12.47176 10.47909 9.024426 7.899537 6.935482 6.21563 6.035741 6.140291 5.94696 6.146772 6.860236 6.702379 7.087324 6.929666 7.66336 7.935287 8.405977 8.309303
|
||||||
|
25.38996 15.5755 21.93212 25.55334 24.08282 15.93044 12.47107 10.47753 9.026523 7.901362 6.939464 6.189109 5.926141 5.678882 5.553694 6.006057 6.066478 6.500169 7.277717 7.248817 7.699819 7.990362 8.033764 8.220113
|
||||||
|
25.38995 15.57543 21.93211 25.55334 24.08282 15.93052 12.47129 10.4782 9.028108 7.90429 6.946663 6.310408 5.903729 5.777827 6.027511 6.000434 6.190129 5.968217 6.455313 7.450428 7.993948 8.512851 8.341401 8.14073
|
||||||
|
'''
|
||||||
|
expected = parse_str(expected_str)
|
||||||
|
assert torch.allclose(ans[:3, :], expected, rtol=1e-3)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
test_and_benchmark_default_parameters()
|
||||||
|
test_use_energy_htk_compat_true()
|
||||||
|
test_use_energy_htk_compat_false()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
83
kaldifeat/python/tests/test_options.py
Executable file
83
kaldifeat/python/tests/test_options.py
Executable file
@ -0,0 +1,83 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
#
|
||||||
|
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
cur_dir = Path(__file__).resolve().parent
|
||||||
|
kaldi_feat_dir = cur_dir.parent.parent.parent
|
||||||
|
|
||||||
|
import sys
|
||||||
|
sys.path.insert(0, f'{kaldi_feat_dir}/build/lib')
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import _kaldifeat
|
||||||
|
|
||||||
|
|
||||||
|
def test_frame_extraction_options():
|
||||||
|
opts = _kaldifeat.FrameExtractionOptions()
|
||||||
|
opts.samp_freq = 220500
|
||||||
|
opts.frame_shift_ms = 15
|
||||||
|
opts.frame_length_ms = 40
|
||||||
|
opts.dither = 0.1
|
||||||
|
opts.preemph_coeff = 0.98
|
||||||
|
opts.remove_dc_offset = False
|
||||||
|
opts.window_type = 'hanning'
|
||||||
|
opts.round_to_power_of_two = False
|
||||||
|
opts.blackman_coeff = 0.422
|
||||||
|
opts.snip_edges = False
|
||||||
|
print(opts)
|
||||||
|
|
||||||
|
|
||||||
|
def test_mel_banks_options():
|
||||||
|
opts = _kaldifeat.MelBanksOptions()
|
||||||
|
opts.num_bins = 23
|
||||||
|
opts.low_freq = 21
|
||||||
|
opts.high_freq = 8000
|
||||||
|
opts.vtln_low = 101
|
||||||
|
opts.vtln_high = -501
|
||||||
|
opts.debug_mel = True
|
||||||
|
opts.htk_mode = True
|
||||||
|
print(opts)
|
||||||
|
|
||||||
|
|
||||||
|
def test_fbank_options():
|
||||||
|
opts = _kaldifeat.FbankOptions()
|
||||||
|
frame_opts = opts.frame_opts
|
||||||
|
mel_opts = opts.mel_opts
|
||||||
|
|
||||||
|
opts.energy_floor = 0
|
||||||
|
opts.htk_compat = False
|
||||||
|
opts.raw_energy = True
|
||||||
|
opts.use_energy = False
|
||||||
|
opts.use_log_fbank = True
|
||||||
|
opts.use_power = True
|
||||||
|
|
||||||
|
frame_opts.blackman_coeff = 0.42
|
||||||
|
frame_opts.dither = 1
|
||||||
|
frame_opts.frame_length_ms = 25
|
||||||
|
frame_opts.frame_shift_ms = 10
|
||||||
|
frame_opts.preemph_coeff = 0.97
|
||||||
|
frame_opts.remove_dc_offset = True
|
||||||
|
frame_opts.round_to_power_of_two = True
|
||||||
|
frame_opts.samp_freq = 16000
|
||||||
|
frame_opts.snip_edges = True
|
||||||
|
frame_opts.window_type = 'povey'
|
||||||
|
|
||||||
|
mel_opts.debug_mel = True
|
||||||
|
mel_opts.high_freq = 0
|
||||||
|
mel_opts.low_freq = 20
|
||||||
|
mel_opts.num_bins = 23
|
||||||
|
mel_opts.vtln_high = -500
|
||||||
|
mel_opts.vtln_low = 100
|
||||||
|
|
||||||
|
print(opts)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# test_frame_extraction_options()
|
||||||
|
# test_mel_banks_options()
|
||||||
|
test_fbank_options()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
Loading…
x
Reference in New Issue
Block a user