From 5b97eeadb57b96382ce4d90119a505498b26f8c5 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 3 Dec 2022 12:14:58 +0800 Subject: [PATCH] Change the way how MelBankOptions is displayed --- kaldifeat/csrc/mel-computations.h | 15 +++++++------- kaldifeat/python/csrc/mel-computations.cc | 20 ++++++++++++++++++- .../tests/test_frame_extraction_options.py | 4 +++- .../python/tests/test_mel_bank_options.py | 7 +++++-- 4 files changed, 35 insertions(+), 11 deletions(-) diff --git a/kaldifeat/csrc/mel-computations.h b/kaldifeat/csrc/mel-computations.h index 0abfabd..9aa8d9c 100644 --- a/kaldifeat/csrc/mel-computations.h +++ b/kaldifeat/csrc/mel-computations.h @@ -36,13 +36,14 @@ struct MelBanksOptions { std::string ToString() const { std::ostringstream os; - os << "num_bins: " << num_bins << "\n"; - os << "low_freq: " << low_freq << "\n"; - os << "high_freq: " << high_freq << "\n"; - os << "vtln_low: " << vtln_low << "\n"; - os << "vtln_high: " << vtln_high << "\n"; - os << "debug_mel: " << debug_mel << "\n"; - os << "htk_mode: " << htk_mode << "\n"; + os << "MelBanksOptions("; + os << "num_bins=" << num_bins << ", "; + os << "low_freq=" << low_freq << ", "; + os << "high_freq=" << high_freq << ", "; + os << "vtln_low=" << vtln_low << ", "; + os << "vtln_high=" << vtln_high << ", "; + os << "debug_mel=" << (debug_mel ? "True" : "False") << ", "; + os << "htk_mode=" << (htk_mode ? "True" : "False") << ")"; return os.str(); } }; diff --git a/kaldifeat/python/csrc/mel-computations.cc b/kaldifeat/python/csrc/mel-computations.cc index e8f1c31..fa9544f 100644 --- a/kaldifeat/python/csrc/mel-computations.cc +++ b/kaldifeat/python/csrc/mel-computations.cc @@ -4,6 +4,7 @@ #include "kaldifeat/python/csrc/mel-computations.h" +#include #include #include "kaldifeat/csrc/mel-computations.h" @@ -14,7 +15,24 @@ namespace kaldifeat { static void PybindMelBanksOptions(py::module &m) { using PyClass = MelBanksOptions; py::class_(m, "MelBanksOptions") - .def(py::init<>()) + .def(py::init( + [](int32_t num_bins = 25, float low_freq = 20, + float high_freq = 0, float vtln_low = 100, + float vtln_high = -500, + bool debug_mel = false) -> std::unique_ptr { + auto opts = std::make_unique(); + + opts->num_bins = num_bins; + opts->low_freq = low_freq; + opts->high_freq = high_freq; + opts->vtln_low = vtln_low; + opts->vtln_high = vtln_high; + + return opts; + }), + py::arg("num_bins") = 25, py::arg("low_freq") = 20, + py::arg("high_freq") = 0, py::arg("vtln_low") = 100, + py::arg("vtln_high") = -500, py::arg("debug_mel") = false) .def_readwrite("num_bins", &PyClass::num_bins) .def_readwrite("low_freq", &PyClass::low_freq) .def_readwrite("high_freq", &PyClass::high_freq) diff --git a/kaldifeat/python/tests/test_frame_extraction_options.py b/kaldifeat/python/tests/test_frame_extraction_options.py index 27e5405..c4a2e73 100755 --- a/kaldifeat/python/tests/test_frame_extraction_options.py +++ b/kaldifeat/python/tests/test_frame_extraction_options.py @@ -23,7 +23,9 @@ def test_default(): def test_set_get(): - opts = kaldifeat.FrameExtractionOptions() + opts = kaldifeat.FrameExtractionOptions(samp_freq=22150) + assert opts.samp_freq == 22150 + opts.samp_freq = 44100 assert opts.samp_freq == 44100 diff --git a/kaldifeat/python/tests/test_mel_bank_options.py b/kaldifeat/python/tests/test_mel_bank_options.py index 70624a1..c064980 100755 --- a/kaldifeat/python/tests/test_mel_bank_options.py +++ b/kaldifeat/python/tests/test_mel_bank_options.py @@ -9,6 +9,7 @@ import kaldifeat def test_default(): opts = kaldifeat.MelBanksOptions() + print(opts) assert opts.num_bins == 25 assert opts.low_freq == 20 assert opts.high_freq == 0 @@ -19,10 +20,12 @@ def test_default(): def test_set_get(): - opts = kaldifeat.MelBanksOptions() - opts.num_bins = 100 + opts = kaldifeat.MelBanksOptions(num_bins=100) assert opts.num_bins == 100 + opts.num_bins = 200 + assert opts.num_bins == 200 + opts.low_freq = 22 assert opts.low_freq == 22