From 9a5567e21b3509522c5b1af25185c48c65bc7d30 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 27 Feb 2021 22:52:39 +0800 Subject: [PATCH] wrap FbankOptions to Python. --- kaldifeat/csrc/feature-common-inl.h | 9 +- kaldifeat/csrc/feature-fbank.cc | 33 +++-- kaldifeat/csrc/feature-fbank.h | 35 ++++- kaldifeat/csrc/feature-window.cc | 5 + kaldifeat/csrc/feature-window.h | 27 +++- kaldifeat/csrc/mel-computations.cc | 5 + kaldifeat/csrc/mel-computations.h | 14 ++ kaldifeat/csrc/test_kaldifeat.cc | 14 +- kaldifeat/python/csrc/CMakeLists.txt | 7 +- kaldifeat/python/csrc/feature-fbank.cc | 27 ++++ kaldifeat/python/csrc/feature-fbank.h | 16 +++ kaldifeat/python/csrc/feature-window.cc | 39 ++++++ kaldifeat/python/csrc/feature-window.h | 16 +++ kaldifeat/python/csrc/kaldifeat.cc | 55 +++++--- kaldifeat/python/csrc/mel-computations.cc | 27 ++++ kaldifeat/python/csrc/mel-computations.h | 16 +++ .../python/tests/test_default_parameters.py | 49 ------- kaldifeat/python/tests/test_kaldifeat.py | 121 ++++++++++++++++++ kaldifeat/python/tests/test_options.py | 83 ++++++++++++ 19 files changed, 505 insertions(+), 93 deletions(-) create mode 100644 kaldifeat/python/csrc/feature-fbank.cc create mode 100644 kaldifeat/python/csrc/feature-fbank.h create mode 100644 kaldifeat/python/csrc/feature-window.cc create mode 100644 kaldifeat/python/csrc/feature-window.h create mode 100644 kaldifeat/python/csrc/mel-computations.cc create mode 100644 kaldifeat/python/csrc/mel-computations.h delete mode 100755 kaldifeat/python/tests/test_default_parameters.py create mode 100755 kaldifeat/python/tests/test_kaldifeat.py create mode 100755 kaldifeat/python/tests/test_options.py diff --git a/kaldifeat/csrc/feature-common-inl.h b/kaldifeat/csrc/feature-common-inl.h index 243b7dd..0c9e3c4 100644 --- a/kaldifeat/csrc/feature-common-inl.h +++ b/kaldifeat/csrc/feature-common-inl.h @@ -15,15 +15,14 @@ template torch::Tensor OfflineFeatureTpl::ComputeFeatures(const torch::Tensor &wave, float vtln_warp) 
{ KALDIFEAT_ASSERT(wave.dim() == 1); - int32_t rows_out = NumFrames(wave.sizes()[0], computer_.GetFrameOptions()); - int32_t cols_out = computer_.Dim(); const FrameExtractionOptions &frame_opts = computer_.GetFrameOptions(); torch::Tensor strided_input = GetStrided(wave, frame_opts); - if (frame_opts.dither != 0) + if (frame_opts.dither != 0.0f) { strided_input = Dither(strided_input, frame_opts.dither); + } if (frame_opts.remove_dc_offset) { torch::Tensor row_means = strided_input.mean(1).unsqueeze(1); @@ -37,12 +36,14 @@ torch::Tensor OfflineFeatureTpl::ComputeFeatures(const torch::Tensor &wave, constexpr float kEps = 1.1920928955078125e-07f; if (use_raw_log_energy) { + // it is true iff use_energy==true and raw_energy==true log_energy_pre_window = torch::clamp_min(strided_input.pow(2).sum(1), kEps).log(); } - if (frame_opts.preemph_coeff != 0.0f) + if (frame_opts.preemph_coeff != 0.0f) { Preemphasize(frame_opts.preemph_coeff, &strided_input); + } feature_window_function_.Apply(&strided_input); diff --git a/kaldifeat/csrc/feature-fbank.cc b/kaldifeat/csrc/feature-fbank.cc index b21083c..488e441 100644 --- a/kaldifeat/csrc/feature-fbank.cc +++ b/kaldifeat/csrc/feature-fbank.cc @@ -13,6 +13,11 @@ namespace kaldifeat { +std::ostream &operator<<(std::ostream &os, const FbankOptions &opts) { + os << opts.ToString(); + return os; +} + FbankComputer::FbankComputer(const FbankOptions &opts) : opts_(opts) { if (opts.energy_floor > 0.0f) log_energy_floor_ = logf(opts.energy_floor); @@ -78,11 +83,6 @@ torch::Tensor FbankComputer::Compute(torch::Tensor signal_raw_log_energy, // Use power instead of magnitude if requested. if (opts_.use_power) spectrum.pow_(2); -#if 0 - int32_t mel_offset = ((opts_.use_energy && !opts_.htk_compat) ? 
1 : 0); - SubVector mel_energies(*feature, mel_offset, opts_.mel_opts.num_bins); -#endif - // TODO(fangjun): remove the last column of spectrum torch::Tensor mel_energies = mel_banks.Compute(spectrum); if (opts_.use_log_fbank) { @@ -90,17 +90,26 @@ torch::Tensor FbankComputer::Compute(torch::Tensor signal_raw_log_energy, mel_energies = torch::clamp_min(mel_energies, kEps).log(); } + // if use_energy is true, then we get an extra bin. That is, + // if num_mel_bins is 23, the feature will contain 24 bins. + // + // if htk_compat is false, then the 0th bin is the log energy + // if htk_compat is true, then the last bin is the log energy + // Copy energy as first value (or the last, if htk_compat == true). if (opts_.use_energy) { -#if 0 - if (opts_.energy_floor > 0.0 && signal_raw_log_energy < log_energy_floor_) { - signal_raw_log_energy = log_energy_floor_; + if (opts_.energy_floor > 0.0f) { + signal_raw_log_energy = + torch::clamp_min(signal_raw_log_energy, log_energy_floor_); } -#endif - int32_t energy_index = opts_.htk_compat ? 
opts_.mel_opts.num_bins : 0; - energy_index = 0; // TODO(fangjun): fix it - mel_energies.index({"...", energy_index}) = signal_raw_log_energy; + signal_raw_log_energy.unsqueeze_(1); + + if (opts_.htk_compat) { + mel_energies = torch::cat({mel_energies, signal_raw_log_energy}, 1); + } else { + mel_energies = torch::cat({signal_raw_log_energy, mel_energies}, 1); + } } return mel_energies; diff --git a/kaldifeat/csrc/feature-fbank.h b/kaldifeat/csrc/feature-fbank.h index b1da6bb..f937d05 100644 --- a/kaldifeat/csrc/feature-fbank.h +++ b/kaldifeat/csrc/feature-fbank.h @@ -20,12 +20,16 @@ struct FbankOptions { MelBanksOptions mel_opts; // append an extra dimension with energy to the filter banks bool use_energy = false; - float energy_floor = 0.0f; + float energy_floor = 0.0f; // active iff use_energy==true + + // If true, compute log_energy before preemphasis and windowing + // If false, compute log_energy after preemphasis and windowing + bool raw_energy = true; // active iff use_energy==true - // If true, compute energy before preemphasis and windowing - bool raw_energy = true; // If true, put energy last (if using energy) - bool htk_compat = false; + // If false, put energy first + bool htk_compat = false; // active iff use_energy==true + // if true (default), produce log-filterbank, else linear bool use_log_fbank = true; @@ -34,8 +38,28 @@ struct FbankOptions { bool use_power = true; FbankOptions() { mel_opts.num_bins = 23; } + + std::string ToString() const { + std::ostringstream os; + os << "frame_opts: \n"; + os << frame_opts << "\n"; + os << "\n"; + + os << "mel_opts: \n"; + os << mel_opts << "\n"; + + os << "use_energy: " << use_energy << "\n"; + os << "energy_floor: " << energy_floor << "\n"; + os << "raw_energy: " << raw_energy << "\n"; + os << "htk_compat: " << htk_compat << "\n"; + os << "use_log_fbank: " << use_log_fbank << "\n"; + os << "use_power: " << use_power << "\n"; + return os.str(); + } }; +std::ostream &operator<<(std::ostream &os, const 
FbankOptions &opts); + class FbankComputer { public: using Options = FbankOptions; @@ -51,12 +75,15 @@ class FbankComputer { return opts_.mel_opts.num_bins + (opts_.use_energy ? 1 : 0); } + // if true, compute log_energy_pre_window but after dithering and dc removal bool NeedRawLogEnergy() const { return opts_.use_energy && opts_.raw_energy; } const FrameExtractionOptions &GetFrameOptions() const { return opts_.frame_opts; } + // signal_raw_log_energy is log_energy_pre_window, which is not empty + // iff NeedRawLogEnergy() returns true. torch::Tensor Compute(torch::Tensor signal_raw_log_energy, float vtln_warp, const torch::Tensor &signal_frame); diff --git a/kaldifeat/csrc/feature-window.cc b/kaldifeat/csrc/feature-window.cc index 93c749d..bcb4cd9 100644 --- a/kaldifeat/csrc/feature-window.cc +++ b/kaldifeat/csrc/feature-window.cc @@ -16,6 +16,11 @@ namespace kaldifeat { +std::ostream &operator<<(std::ostream &os, const FrameExtractionOptions &opts) { + os << opts.ToString(); + return os; +} + FeatureWindowFunction::FeatureWindowFunction( const FrameExtractionOptions &opts) { int32_t frame_length = opts.WindowSize(); diff --git a/kaldifeat/csrc/feature-window.h b/kaldifeat/csrc/feature-window.h index a4a3ed8..09caa5d 100644 --- a/kaldifeat/csrc/feature-window.h +++ b/kaldifeat/csrc/feature-window.h @@ -39,9 +39,9 @@ struct FrameExtractionOptions { bool round_to_power_of_two = true; float blackman_coeff = 0.42f; bool snip_edges = true; - bool allow_downsample = false; - bool allow_upsample = false; - int32_t max_feature_vectors = -1; + // bool allow_downsample = false; + // bool allow_upsample = false; + // int32_t max_feature_vectors = -1; int32_t WindowShift() const { return static_cast(samp_freq * 0.001f * frame_shift_ms); @@ -53,8 +53,29 @@ struct FrameExtractionOptions { return (round_to_power_of_two ? 
RoundUpToNearestPowerOfTwo(WindowSize()) : WindowSize()); } + std::string ToString() const { + std::ostringstream os; +#define KALDIFEAT_PRINT(x) os << #x << ": " << x << "\n" + KALDIFEAT_PRINT(samp_freq); + KALDIFEAT_PRINT(frame_shift_ms); + KALDIFEAT_PRINT(frame_length_ms); + KALDIFEAT_PRINT(dither); + KALDIFEAT_PRINT(preemph_coeff); + KALDIFEAT_PRINT(remove_dc_offset); + KALDIFEAT_PRINT(window_type); + KALDIFEAT_PRINT(round_to_power_of_two); + KALDIFEAT_PRINT(blackman_coeff); + KALDIFEAT_PRINT(snip_edges); + // KALDIFEAT_PRINT(allow_downsample); + // KALDIFEAT_PRINT(allow_upsample); + // KALDIFEAT_PRINT(max_feature_vectors); +#undef KALDIFEAT_PRINT + return os.str(); + } }; +std::ostream &operator<<(std::ostream &os, const FrameExtractionOptions &opts); + class FeatureWindowFunction { public: FeatureWindowFunction() = default; diff --git a/kaldifeat/csrc/mel-computations.cc b/kaldifeat/csrc/mel-computations.cc index 1c711a8..22f01eb 100644 --- a/kaldifeat/csrc/mel-computations.cc +++ b/kaldifeat/csrc/mel-computations.cc @@ -10,6 +10,11 @@ namespace kaldifeat { +std::ostream &operator<<(std::ostream &os, const MelBanksOptions &opts) { + os << opts.ToString(); + return os; +} + float MelBanks::VtlnWarpFreq( float vtln_low_cutoff, // upper+lower frequency cutoffs for VTLN. float vtln_high_cutoff, diff --git a/kaldifeat/csrc/mel-computations.h b/kaldifeat/csrc/mel-computations.h index b89d74b..775ba9a 100644 --- a/kaldifeat/csrc/mel-computations.h +++ b/kaldifeat/csrc/mel-computations.h @@ -32,8 +32,22 @@ struct MelBanksOptions { // Enables more exact compatibility with HTK, for testing purposes. Affects // mel-energy flooring and reproduces a bug in HTK. 
bool htk_mode = false; + + std::string ToString() const { + std::ostringstream os; + os << "num_bins: " << num_bins << "\n"; + os << "low_freq: " << low_freq << "\n"; + os << "high_freq: " << high_freq << "\n"; + os << "vtln_low: " << vtln_low << "\n"; + os << "vtln_high: " << vtln_high << "\n"; + os << "debug_mel: " << debug_mel << "\n"; + os << "htk_mode: " << htk_mode << "\n"; + return os.str(); + } }; +std::ostream &operator<<(std::ostream &os, const MelBanksOptions &opts); + class MelBanks { public: static inline float InverseMelScale(float mel_freq) { diff --git a/kaldifeat/csrc/test_kaldifeat.cc b/kaldifeat/csrc/test_kaldifeat.cc index 6418ec8..528a7cf 100644 --- a/kaldifeat/csrc/test_kaldifeat.cc +++ b/kaldifeat/csrc/test_kaldifeat.cc @@ -62,8 +62,18 @@ static void TestDither() { std::cout << (a + b * 2) << "\n"; } +static void TestCat() { + torch::Tensor a = torch::arange(0, 6).reshape({2, 3}).to(torch::kFloat); + torch::Tensor b = torch::arange(0, 2).reshape({2, 1}).to(torch::kFloat) * 0.1; + torch::Tensor c = torch::cat({a, b}, 1); + torch::Tensor d = torch::cat({b, a}, 1); + std::cout << a << "\n"; + std::cout << b << "\n"; + std::cout << c << "\n"; + std::cout << d << "\n"; +} + int main() { - // TestDither(); - TestGetStrided(); + TestCat(); return 0; } diff --git a/kaldifeat/python/csrc/CMakeLists.txt b/kaldifeat/python/csrc/CMakeLists.txt index 8c27ab7..17afb2d 100644 --- a/kaldifeat/python/csrc/CMakeLists.txt +++ b/kaldifeat/python/csrc/CMakeLists.txt @@ -1,4 +1,9 @@ add_definitions(-DTORCH_API_INCLUDE_EXTENSION_H) -pybind11_add_module(_kaldifeat kaldifeat.cc) +pybind11_add_module(_kaldifeat + feature-fbank.cc + feature-window.cc + kaldifeat.cc + mel-computations.cc +) target_link_libraries(_kaldifeat PRIVATE kaldifeat_core) target_link_libraries(_kaldifeat PRIVATE ${TORCH_DIR}/lib/libtorch_python.so) diff --git a/kaldifeat/python/csrc/feature-fbank.cc b/kaldifeat/python/csrc/feature-fbank.cc new file mode 100644 index 0000000..bd068e2 --- 
/dev/null +++ b/kaldifeat/python/csrc/feature-fbank.cc @@ -0,0 +1,27 @@ +// kaldifeat/python/csrc/feature-fbank.cc +// +// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +#include "kaldifeat/python/csrc/feature-fbank.h" + +#include "kaldifeat/csrc/feature-fbank.h" + +namespace kaldifeat { + +void PybindFbankOptions(py::module &m) { + py::class_(m, "FbankOptions") + .def(py::init<>()) + .def_readwrite("frame_opts", &FbankOptions::frame_opts) + .def_readwrite("mel_opts", &FbankOptions::mel_opts) + .def_readwrite("use_energy", &FbankOptions::use_energy) + .def_readwrite("energy_floor", &FbankOptions::energy_floor) + .def_readwrite("raw_energy", &FbankOptions::raw_energy) + .def_readwrite("htk_compat", &FbankOptions::htk_compat) + .def_readwrite("use_log_fbank", &FbankOptions::use_log_fbank) + .def_readwrite("use_power", &FbankOptions::use_power) + .def("__str__", [](const FbankOptions &self) -> std::string { + return self.ToString(); + }); +} + +} // namespace kaldifeat diff --git a/kaldifeat/python/csrc/feature-fbank.h b/kaldifeat/python/csrc/feature-fbank.h new file mode 100644 index 0000000..4e0b135 --- /dev/null +++ b/kaldifeat/python/csrc/feature-fbank.h @@ -0,0 +1,16 @@ +// kaldifeat/python/csrc/feature-fbank.h +// +// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +#ifndef KALDIFEAT_PYTHON_CSRC_FEATURE_FBANK_H_ +#define KALDIFEAT_PYTHON_CSRC_FEATURE_FBANK_H_ + +#include "kaldifeat/python/csrc/kaldifeat.h" + +namespace kaldifeat { + +void PybindFbankOptions(py::module &m); + +} // namespace kaldifeat + +#endif // KALDIFEAT_PYTHON_CSRC_FEATURE_FBANK_H_ diff --git a/kaldifeat/python/csrc/feature-window.cc b/kaldifeat/python/csrc/feature-window.cc new file mode 100644 index 0000000..bc3267f --- /dev/null +++ b/kaldifeat/python/csrc/feature-window.cc @@ -0,0 +1,39 @@ +// kaldifeat/python/csrc/feature-window.cc +// +// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +#include 
"kaldifeat/python/csrc/feature-window.h" + +#include "kaldifeat/csrc/feature-window.h" + +namespace kaldifeat { + +void PybindFrameExtractionOptions(py::module &m) { + py::class_(m, "FrameExtractionOptions") + .def(py::init<>()) + .def_readwrite("samp_freq", &FrameExtractionOptions::samp_freq) + .def_readwrite("frame_shift_ms", &FrameExtractionOptions::frame_shift_ms) + .def_readwrite("frame_length_ms", + &FrameExtractionOptions::frame_length_ms) + .def_readwrite("dither", &FrameExtractionOptions::dither) + .def_readwrite("preemph_coeff", &FrameExtractionOptions::preemph_coeff) + .def_readwrite("remove_dc_offset", + &FrameExtractionOptions::remove_dc_offset) + .def_readwrite("window_type", &FrameExtractionOptions::window_type) + .def_readwrite("round_to_power_of_two", + &FrameExtractionOptions::round_to_power_of_two) + .def_readwrite("blackman_coeff", &FrameExtractionOptions::blackman_coeff) + .def_readwrite("snip_edges", &FrameExtractionOptions::snip_edges) +#if 0 + .def_readwrite("allow_downsample", + &FrameExtractionOptions::allow_downsample) + .def_readwrite("allow_upsample", &FrameExtractionOptions::allow_upsample) + .def_readwrite("max_feature_vectors", + &FrameExtractionOptions::max_feature_vectors) +#endif + .def("__str__", [](const FrameExtractionOptions &self) -> std::string { + return self.ToString(); + }); +} + +} // namespace kaldifeat diff --git a/kaldifeat/python/csrc/feature-window.h b/kaldifeat/python/csrc/feature-window.h new file mode 100644 index 0000000..860d83e --- /dev/null +++ b/kaldifeat/python/csrc/feature-window.h @@ -0,0 +1,16 @@ +// kaldifeat/python/csrc/feature-window.h +// +// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +#ifndef KALDIFEAT_PYTHON_CSRC_FEATURE_WINDOW_H_ +#define KALDIFEAT_PYTHON_CSRC_FEATURE_WINDOW_H_ + +#include "kaldifeat/python/csrc/kaldifeat.h" + +namespace kaldifeat { + +void PybindFrameExtractionOptions(py::module &m); + +} // namespace kaldifeat + +#endif // 
KALDIFEAT_PYTHON_CSRC_FEATURE_WINDOW_H_ diff --git a/kaldifeat/python/csrc/kaldifeat.cc b/kaldifeat/python/csrc/kaldifeat.cc index 24851c4..80bc08e 100644 --- a/kaldifeat/python/csrc/kaldifeat.cc +++ b/kaldifeat/python/csrc/kaldifeat.cc @@ -7,36 +7,55 @@ #include #include "kaldifeat/csrc/feature-fbank.h" +#include "kaldifeat/python/csrc/feature-fbank.h" +#include "kaldifeat/python/csrc/feature-window.h" +#include "kaldifeat/python/csrc/mel-computations.h" #include "torch/torch.h" namespace kaldifeat { +static torch::Tensor Compute(const torch::Tensor &wave, + const FbankOptions &fbank_opts) { + // TODO(fangjun): wrap Fbank to Python + + Fbank fbank(fbank_opts); + float vtln_warp = 1.0f; + + torch::Tensor ans = fbank.ComputeFeatures(wave, vtln_warp); + return ans; +} + PYBIND11_MODULE(_kaldifeat, m) { m.doc() = "Python wrapper for kaldifeat"; + PybindFrameExtractionOptions(m); + PybindMelBanksOptions(m); + PybindFbankOptions(m); + + m.def("compute", &Compute, py::arg("wave"), py::arg("fbank_opts")); + // It verifies that the reimplementation produces the same output - // as kaldi using default paremters with dither disabled. - m.def("test_default_parameters", - [](const torch::Tensor &tensor) -> std::pair { - FbankOptions fbank_opts; - fbank_opts.frame_opts.dither = 0.0f; + // as kaldi using default parameters with dither disabled. 
+ m.def( + "_compute_with_elapsed_time", // for benchmark only + [](const torch::Tensor &wave, + const FbankOptions &fbank_opts) -> std::pair { + std::chrono::steady_clock::time_point begin = + std::chrono::steady_clock::now(); - Fbank fbank(fbank_opts); - float vtln_warp = 1.0f; + torch::Tensor ans = Compute(wave, fbank_opts); - std::chrono::steady_clock::time_point begin = - std::chrono::steady_clock::now(); + std::chrono::steady_clock::time_point end = + std::chrono::steady_clock::now(); - torch::Tensor ans = fbank.ComputeFeatures(tensor, vtln_warp); - std::chrono::steady_clock::time_point end = - std::chrono::steady_clock::now(); - double elapsed_seconds = - std::chrono::duration_cast(end - begin) - .count() / - 1000000.; + double elapsed_seconds = + std::chrono::duration_cast(end - begin) + .count() / + 1000000.; - return std::make_pair(ans, elapsed_seconds); - }); + return std::make_pair(ans, elapsed_seconds); + }, + py::arg("wave"), py::arg("fbank_opts")); } } // namespace kaldifeat diff --git a/kaldifeat/python/csrc/mel-computations.cc b/kaldifeat/python/csrc/mel-computations.cc new file mode 100644 index 0000000..24a793a --- /dev/null +++ b/kaldifeat/python/csrc/mel-computations.cc @@ -0,0 +1,27 @@ +// kaldifeat/python/csrc/mel-computations.cc +// +// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +#include "kaldifeat/csrc/mel-computations.h" + +#include "kaldifeat/python/csrc/feature-window.h" + +namespace kaldifeat { + +void PybindMelBanksOptions(py::module &m) { + py::class_(m, "MelBanksOptions") + .def(py::init<>()) + .def_readwrite("num_bins", &MelBanksOptions::num_bins) + .def_readwrite("low_freq", &MelBanksOptions::low_freq) + .def_readwrite("high_freq", &MelBanksOptions::high_freq) + .def_readwrite("vtln_low", &MelBanksOptions::vtln_low) + .def_readwrite("vtln_high", &MelBanksOptions::vtln_high) + .def_readwrite("debug_mel", &MelBanksOptions::debug_mel) + .def_readwrite("htk_mode", &MelBanksOptions::htk_mode) + .def("__str__", 
[](const MelBanksOptions &self) -> std::string { + return self.ToString(); + }); + ; +} + +} // namespace kaldifeat diff --git a/kaldifeat/python/csrc/mel-computations.h b/kaldifeat/python/csrc/mel-computations.h new file mode 100644 index 0000000..0caaa0a --- /dev/null +++ b/kaldifeat/python/csrc/mel-computations.h @@ -0,0 +1,16 @@ +// kaldifeat/python/csrc/mel-computations.h +// +// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +#ifndef KALDIFEAT_PYTHON_CSRC_MEL_COMPUTATIONS_H_ +#define KALDIFEAT_PYTHON_CSRC_MEL_COMPUTATIONS_H_ + +#include "kaldifeat/python/csrc/kaldifeat.h" + +namespace kaldifeat { + +void PybindMelBanksOptions(py::module &m); + +} // namespace kaldifeat + +#endif // KALDIFEAT_PYTHON_CSRC_MEL_COMPUTATIONS_H_ diff --git a/kaldifeat/python/tests/test_default_parameters.py b/kaldifeat/python/tests/test_default_parameters.py deleted file mode 100755 index 9f57c27..0000000 --- a/kaldifeat/python/tests/test_default_parameters.py +++ /dev/null @@ -1,49 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) - -from pathlib import Path -cur_dir = Path(__file__).resolve().parent -kaldi_feat_dir = cur_dir.parent.parent.parent - -import sys -sys.path.insert(0, f'{kaldi_feat_dir}/build/lib') - -import numpy as np -import soundfile as sf -import torch - -import _kaldifeat - - -def read_ark_txt() -> torch.Tensor: - test_data_dir = cur_dir / 'test_data' - filename = test_data_dir / 'abc.txt' - features = [] - with open(filename) as f: - for line in f: - if '[' in line: continue - line = line.strip('').split() - data = [float(d) for d in line if d != ']'] - features.append(data) - ans = torch.tensor(features) - return ans - - -def main(): - test_data_dir = cur_dir / 'test_data' - filename = test_data_dir / 'abc.wav' - with sf.SoundFile(filename) as sf_desc: - sampling_rate = sf_desc.samplerate - assert sampling_rate == 16000 - data = sf_desc.read(dtype=np.float32, always_2d=False) - data *= 32768 - 
tensor = torch.from_numpy(data) - ans, elapsed_seconds = _kaldifeat.test_default_parameters(tensor) - expected = read_ark_txt() - assert torch.allclose(ans, expected, rtol=1e-3) - print('elapsed seconds:', elapsed_seconds) - - -if __name__ == '__main__': - main() diff --git a/kaldifeat/python/tests/test_kaldifeat.py b/kaldifeat/python/tests/test_kaldifeat.py new file mode 100755 index 0000000..7e8c07d --- /dev/null +++ b/kaldifeat/python/tests/test_kaldifeat.py @@ -0,0 +1,121 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +from pathlib import Path +cur_dir = Path(__file__).resolve().parent +kaldi_feat_dir = cur_dir.parent.parent.parent + +import sys +sys.path.insert(0, f'{kaldi_feat_dir}/build/lib') + +import numpy as np +import soundfile as sf +import torch + +import _kaldifeat + + +def read_ark_txt() -> torch.Tensor: + test_data_dir = cur_dir / 'test_data' + filename = test_data_dir / 'abc.txt' + features = [] + with open(filename) as f: + for line in f: + if '[' in line: continue + line = line.strip('').split() + data = [float(d) for d in line if d != ']'] + features.append(data) + ans = torch.tensor(features) + return ans + + +def parse_str(s) -> torch.Tensor: + ''' + Args: + s: + It consists of several lines. Each line contains several numbers + separated by spaces. 
+ ''' + ans = [] + for line in s.strip().split('\n'): + data = [float(d) for d in line.strip().split()] + ans.append(data) + return torch.tensor(ans) + + +def read_wave() -> torch.Tensor: + test_data_dir = cur_dir / 'test_data' + filename = test_data_dir / 'abc.wav' + with sf.SoundFile(filename) as sf_desc: + sampling_rate = sf_desc.samplerate + assert sampling_rate == 16000 + data = sf_desc.read(dtype=np.float32, always_2d=False) + data *= 32768 + return torch.from_numpy(data) + + +def test_and_benchmark_default_parameters(): + fbank_opts = _kaldifeat.FbankOptions() + fbank_opts.frame_opts.dither = 0 + + data = read_wave() + + ans, elapsed_seconds = _kaldifeat._compute_with_elapsed_time( + data, fbank_opts) + + expected = read_ark_txt() + assert torch.allclose(ans, expected, rtol=1e-3) + print('elapsed seconds:', elapsed_seconds) + + +def test_use_energy_htk_compat_true(): + fbank_opts = _kaldifeat.FbankOptions() + fbank_opts.frame_opts.dither = 0 + fbank_opts.use_energy = True + fbank_opts.htk_compat = True + + data = read_wave() + + ans = _kaldifeat.compute(data, fbank_opts) + + # ./compute-fbank-feats --dither=0 --use-energy=1 --htk-compat=1 scp:abc.scp ark,t:abc.txt + # the first 3 rows are: + expected_str = ''' + 15.576 21.93211 25.55334 24.08283 15.93041 12.47176 10.47909 9.024426 7.899537 6.935482 6.21563 6.035741 6.140291 5.94696 6.146772 6.860236 6.702379 7.087324 6.929666 7.66336 7.935287 8.405977 8.309303 25.38995 + 15.5755 21.93212 25.55334 24.08282 15.93044 12.47107 10.47753 9.026523 7.901362 6.939464 6.189109 5.926141 5.678882 5.553694 6.006057 6.066478 6.500169 7.277717 7.248817 7.699819 7.990362 8.033764 8.220113 25.38996 + 15.57543 21.93211 25.55334 24.08282 15.93052 12.47129 10.4782 9.028108 7.90429 6.946663 6.310408 5.903729 5.777827 6.027511 6.000434 6.190129 5.968217 6.455313 7.450428 7.993948 8.512851 8.341401 8.14073 25.38995 + ''' + expected = parse_str(expected_str) + assert torch.allclose(ans[:3, :], expected, rtol=1e-3) + + +def 
test_use_energy_htk_compat_false(): + fbank_opts = _kaldifeat.FbankOptions() + fbank_opts.frame_opts.dither = 0 + fbank_opts.use_energy = True + fbank_opts.htk_compat = False + + data = read_wave() + + ans = _kaldifeat.compute(data, fbank_opts) + + # ./compute-fbank-feats --dither=0 --use-energy=1 --htk-compat=0 scp:abc.scp ark,t:abc.txt + # the first 3 rows are: + expected_str = ''' + 25.38995 15.576 21.93211 25.55334 24.08283 15.93041 12.47176 10.47909 9.024426 7.899537 6.935482 6.21563 6.035741 6.140291 5.94696 6.146772 6.860236 6.702379 7.087324 6.929666 7.66336 7.935287 8.405977 8.309303 + 25.38996 15.5755 21.93212 25.55334 24.08282 15.93044 12.47107 10.47753 9.026523 7.901362 6.939464 6.189109 5.926141 5.678882 5.553694 6.006057 6.066478 6.500169 7.277717 7.248817 7.699819 7.990362 8.033764 8.220113 + 25.38995 15.57543 21.93211 25.55334 24.08282 15.93052 12.47129 10.4782 9.028108 7.90429 6.946663 6.310408 5.903729 5.777827 6.027511 6.000434 6.190129 5.968217 6.455313 7.450428 7.993948 8.512851 8.341401 8.14073 + ''' + expected = parse_str(expected_str) + assert torch.allclose(ans[:3, :], expected, rtol=1e-3) + + +def main(): + test_and_benchmark_default_parameters() + test_use_energy_htk_compat_true() + test_use_energy_htk_compat_false() + + +if __name__ == '__main__': + main() diff --git a/kaldifeat/python/tests/test_options.py b/kaldifeat/python/tests/test_options.py new file mode 100755 index 0000000..0d9d1ea --- /dev/null +++ b/kaldifeat/python/tests/test_options.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +from pathlib import Path +cur_dir = Path(__file__).resolve().parent +kaldi_feat_dir = cur_dir.parent.parent.parent + +import sys +sys.path.insert(0, f'{kaldi_feat_dir}/build/lib') + +import torch +import _kaldifeat + + +def test_frame_extraction_options(): + opts = _kaldifeat.FrameExtractionOptions() + opts.samp_freq = 220500 + opts.frame_shift_ms = 15 + opts.frame_length_ms = 40 + 
opts.dither = 0.1 + opts.preemph_coeff = 0.98 + opts.remove_dc_offset = False + opts.window_type = 'hanning' + opts.round_to_power_of_two = False + opts.blackman_coeff = 0.422 + opts.snip_edges = False + print(opts) + + +def test_mel_banks_options(): + opts = _kaldifeat.MelBanksOptions() + opts.num_bins = 23 + opts.low_freq = 21 + opts.high_freq = 8000 + opts.vtln_low = 101 + opts.vtln_high = -501 + opts.debug_mel = True + opts.htk_mode = True + print(opts) + + +def test_fbank_options(): + opts = _kaldifeat.FbankOptions() + frame_opts = opts.frame_opts + mel_opts = opts.mel_opts + + opts.energy_floor = 0 + opts.htk_compat = False + opts.raw_energy = True + opts.use_energy = False + opts.use_log_fbank = True + opts.use_power = True + + frame_opts.blackman_coeff = 0.42 + frame_opts.dither = 1 + frame_opts.frame_length_ms = 25 + frame_opts.frame_shift_ms = 10 + frame_opts.preemph_coeff = 0.97 + frame_opts.remove_dc_offset = True + frame_opts.round_to_power_of_two = True + frame_opts.samp_freq = 16000 + frame_opts.snip_edges = True + frame_opts.window_type = 'povey' + + mel_opts.debug_mel = True + mel_opts.high_freq = 0 + mel_opts.low_freq = 20 + mel_opts.num_bins = 23 + mel_opts.vtln_high = -500 + mel_opts.vtln_low = 100 + + print(opts) + + +def main(): + # test_frame_extraction_options() + # test_mel_banks_options() + test_fbank_options() + + +if __name__ == '__main__': + main()