Change the way how MelBankOptions is displayed

This commit is contained in:
Fangjun Kuang 2022-12-03 12:14:58 +08:00
parent f28857495d
commit 5b97eeadb5
4 changed files with 35 additions and 11 deletions

View File

@ -36,13 +36,14 @@ struct MelBanksOptions {
std::string ToString() const {
std::ostringstream os;
os << "num_bins: " << num_bins << "\n";
os << "low_freq: " << low_freq << "\n";
os << "high_freq: " << high_freq << "\n";
os << "vtln_low: " << vtln_low << "\n";
os << "vtln_high: " << vtln_high << "\n";
os << "debug_mel: " << debug_mel << "\n";
os << "htk_mode: " << htk_mode << "\n";
os << "MelBanksOptions(";
os << "num_bins=" << num_bins << ", ";
os << "low_freq=" << low_freq << ", ";
os << "high_freq=" << high_freq << ", ";
os << "vtln_low=" << vtln_low << ", ";
os << "vtln_high=" << vtln_high << ", ";
os << "debug_mel=" << (debug_mel ? "True" : "False") << ", ";
os << "htk_mode=" << (htk_mode ? "True" : "False") << ")";
return os.str();
}
};

View File

@ -4,6 +4,7 @@
#include "kaldifeat/python/csrc/mel-computations.h"
#include <memory>
#include <string>
#include "kaldifeat/csrc/mel-computations.h"
@ -14,7 +15,24 @@ namespace kaldifeat {
static void PybindMelBanksOptions(py::module &m) {
using PyClass = MelBanksOptions;
py::class_<PyClass>(m, "MelBanksOptions")
.def(py::init<>())
.def(py::init(
[](int32_t num_bins = 25, float low_freq = 20,
float high_freq = 0, float vtln_low = 100,
float vtln_high = -500,
bool debug_mel = false) -> std::unique_ptr<MelBanksOptions> {
auto opts = std::make_unique<MelBanksOptions>();
opts->num_bins = num_bins;
opts->low_freq = low_freq;
opts->high_freq = high_freq;
opts->vtln_low = vtln_low;
opts->vtln_high = vtln_high;
return opts;
}),
py::arg("num_bins") = 25, py::arg("low_freq") = 20,
py::arg("high_freq") = 0, py::arg("vtln_low") = 100,
py::arg("vtln_high") = -500, py::arg("debug_mel") = false)
.def_readwrite("num_bins", &PyClass::num_bins)
.def_readwrite("low_freq", &PyClass::low_freq)
.def_readwrite("high_freq", &PyClass::high_freq)

View File

@ -23,7 +23,9 @@ def test_default():
def test_set_get():
opts = kaldifeat.FrameExtractionOptions()
opts = kaldifeat.FrameExtractionOptions(samp_freq=22150)
assert opts.samp_freq == 22150
opts.samp_freq = 44100
assert opts.samp_freq == 44100

View File

@ -9,6 +9,7 @@ import kaldifeat
def test_default():
opts = kaldifeat.MelBanksOptions()
print(opts)
assert opts.num_bins == 25
assert opts.low_freq == 20
assert opts.high_freq == 0
@ -19,10 +20,12 @@ def test_default():
def test_set_get():
opts = kaldifeat.MelBanksOptions()
opts.num_bins = 100
opts = kaldifeat.MelBanksOptions(num_bins=100)
assert opts.num_bins == 100
opts.num_bins = 200
assert opts.num_bins == 200
opts.low_freq = 22
assert opts.low_freq == 22