diff --git a/kaldifeat/python/csrc/CMakeLists.txt b/kaldifeat/python/csrc/CMakeLists.txt index bd1ef25..affb69c 100644 --- a/kaldifeat/python/csrc/CMakeLists.txt +++ b/kaldifeat/python/csrc/CMakeLists.txt @@ -7,6 +7,7 @@ pybind11_add_module(_kaldifeat feature-window.cc kaldifeat.cc mel-computations.cc + utils.cc ) target_link_libraries(_kaldifeat PRIVATE kaldifeat_core) if(UNIX AND NOT APPLE) diff --git a/kaldifeat/python/csrc/feature-window.cc b/kaldifeat/python/csrc/feature-window.cc index 5d76688..81cc275 100644 --- a/kaldifeat/python/csrc/feature-window.cc +++ b/kaldifeat/python/csrc/feature-window.cc @@ -7,6 +7,7 @@ #include #include "kaldifeat/csrc/feature-window.h" +#include "kaldifeat/python/csrc/utils.h" namespace kaldifeat { @@ -26,6 +27,14 @@ static void PybindFrameExtractionOptions(py::module &m) { &FrameExtractionOptions::round_to_power_of_two) .def_readwrite("blackman_coeff", &FrameExtractionOptions::blackman_coeff) .def_readwrite("snip_edges", &FrameExtractionOptions::snip_edges) + .def("as_dict", + [](const FrameExtractionOptions &self) -> py::dict { + return AsDict(self); + }) + .def_static("from_dict", + [](py::dict dict) -> FrameExtractionOptions { + return FrameExtractionOptionsFromDict(dict); + }) #if 0 .def_readwrite("allow_downsample", &FrameExtractionOptions::allow_downsample) diff --git a/kaldifeat/python/csrc/utils.cc b/kaldifeat/python/csrc/utils.cc new file mode 100644 index 0000000..0476491 --- /dev/null +++ b/kaldifeat/python/csrc/utils.cc @@ -0,0 +1,47 @@ +// kaldifeat/python/csrc/utils.cc +// +// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +#include "kaldifeat/python/csrc/utils.h" + +#include "kaldifeat/csrc/feature-window.h" + +#define FROM_DICT(type, key) opts.key = py::type(dict[#key]) +#define AS_DICT(key) dict[#key] = opts.key + +namespace kaldifeat { + +FrameExtractionOptions FrameExtractionOptionsFromDict(py::dict dict) { + FrameExtractionOptions opts; + FROM_DICT(float_, samp_freq); + FROM_DICT(float_, frame_shift_ms); + FROM_DICT(float_, frame_length_ms); + FROM_DICT(float_, dither); + FROM_DICT(float_, preemph_coeff); + FROM_DICT(bool_, remove_dc_offset); + FROM_DICT(str, window_type); + FROM_DICT(bool_, round_to_power_of_two); + FROM_DICT(float_, blackman_coeff); + FROM_DICT(bool_, snip_edges); + return opts; +} + +py::dict AsDict(const FrameExtractionOptions &opts) { + py::dict dict; + AS_DICT(samp_freq); + AS_DICT(frame_shift_ms); + AS_DICT(frame_length_ms); + AS_DICT(dither); + AS_DICT(preemph_coeff); + AS_DICT(remove_dc_offset); + AS_DICT(window_type); + AS_DICT(round_to_power_of_two); + AS_DICT(blackman_coeff); + AS_DICT(snip_edges); + return dict; +} + +#undef FROM_DICT +#undef AS_DICT + +} // namespace kaldifeat diff --git a/kaldifeat/python/csrc/utils.h b/kaldifeat/python/csrc/utils.h new file mode 100644 index 0000000..5d87a0e --- /dev/null +++ b/kaldifeat/python/csrc/utils.h @@ -0,0 +1,18 @@ +// kaldifeat/python/csrc/utils.h +// +// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +#ifndef KALDIFEAT_PYTHON_CSRC_UTILS_H_ +#define KALDIFEAT_PYTHON_CSRC_UTILS_H_ + +#include "kaldifeat/csrc/feature-window.h" +#include "kaldifeat/python/csrc/kaldifeat.h" + +namespace kaldifeat { + +FrameExtractionOptions FrameExtractionOptionsFromDict(py::dict dict); +py::dict AsDict(const FrameExtractionOptions &opts); + +} // namespace kaldifeat + +#endif // KALDIFEAT_PYTHON_CSRC_UTILS_H_ diff --git a/kaldifeat/python/tests/test_options.py b/kaldifeat/python/tests/test_options.py index e048629..96e2f63 100755 --- a/kaldifeat/python/tests/test_options.py +++ b/kaldifeat/python/tests/test_options.py @@ -28,7 +28,19 @@ def test_frame_extraction_options(): opts.round_to_power_of_two = False opts.blackman_coeff = 0.422 opts.snip_edges = False - print(opts) + + opts_dict = opts.as_dict() + for key, value in opts_dict.items(): + assert value == getattr(opts, key) + + opts2 = kaldifeat.FrameExtractionOptions.from_dict(opts_dict) + + for key, value in opts_dict.items(): + assert value == getattr(opts2, key) + + assert str(opts) == str(opts2) + + assert opts_dict == opts2.as_dict() def test_mel_banks_options():