diff --git a/kaldifeat/python/csrc/mel-computations.cc b/kaldifeat/python/csrc/mel-computations.cc index 572f0ba..77e692b 100644 --- a/kaldifeat/python/csrc/mel-computations.cc +++ b/kaldifeat/python/csrc/mel-computations.cc @@ -7,6 +7,7 @@ #include #include "kaldifeat/csrc/mel-computations.h" +#include "kaldifeat/python/csrc/utils.h" namespace kaldifeat { @@ -22,7 +23,12 @@ static void PybindMelBanksOptions(py::module &m) { .def_readwrite("debug_mel", &PyClass::debug_mel) .def_readwrite("htk_mode", &PyClass::htk_mode) .def("__str__", - [](const PyClass &self) -> std::string { return self.ToString(); }); + [](const PyClass &self) -> std::string { return self.ToString(); }) + .def("as_dict", + [](const PyClass &self) -> py::dict { return AsDict(self); }) + .def_static("from_dict", [](py::dict dict) -> PyClass { + return MelBanksOptionsFromDict(dict); + }); } void PybindMelComputations(py::module &m) { PybindMelBanksOptions(m); } diff --git a/kaldifeat/python/csrc/utils.cc b/kaldifeat/python/csrc/utils.cc index c50a4a0..b298e61 100644 --- a/kaldifeat/python/csrc/utils.cc +++ b/kaldifeat/python/csrc/utils.cc @@ -17,6 +17,7 @@ namespace kaldifeat { FrameExtractionOptions FrameExtractionOptionsFromDict(py::dict dict) { FrameExtractionOptions opts; + FROM_DICT(float_, samp_freq); FROM_DICT(float_, frame_shift_ms); FROM_DICT(float_, frame_length_ms); @@ -27,11 +28,13 @@ FrameExtractionOptions FrameExtractionOptionsFromDict(py::dict dict) { FROM_DICT(bool_, round_to_power_of_two); FROM_DICT(float_, blackman_coeff); FROM_DICT(bool_, snip_edges); + return opts; } py::dict AsDict(const FrameExtractionOptions &opts) { py::dict dict; + AS_DICT(samp_freq); AS_DICT(frame_shift_ms); AS_DICT(frame_length_ms); @@ -45,6 +48,33 @@ py::dict AsDict(const FrameExtractionOptions &opts) { return dict; } +MelBanksOptions MelBanksOptionsFromDict(py::dict dict) { + MelBanksOptions opts; + + FROM_DICT(int_, num_bins); + FROM_DICT(float_, low_freq); + FROM_DICT(float_, high_freq); + FROM_DICT(float_, vtln_low); + FROM_DICT(float_, vtln_high); + FROM_DICT(bool_, debug_mel); + FROM_DICT(bool_, htk_mode); + + return opts; +} +py::dict AsDict(const MelBanksOptions &opts) { + py::dict dict; + + AS_DICT(num_bins); + AS_DICT(low_freq); + AS_DICT(high_freq); + AS_DICT(vtln_low); + AS_DICT(vtln_high); + AS_DICT(debug_mel); + AS_DICT(htk_mode); + + return dict; +} + #undef FROM_DICT #undef AS_DICT diff --git a/kaldifeat/python/csrc/utils.h b/kaldifeat/python/csrc/utils.h index 5d87a0e..5b39402 100644 --- a/kaldifeat/python/csrc/utils.h +++ b/kaldifeat/python/csrc/utils.h @@ -6,13 +6,29 @@ #define KALDIFEAT_PYTHON_CSRC_UTILS_H_ #include "kaldifeat/csrc/feature-window.h" +#include "kaldifeat/csrc/mel-computations.h" #include "kaldifeat/python/csrc/kaldifeat.h" +/* + * This file contains code about `from_dict` and + * `to_dict` for various options in kaldifeat. + * + * Regarding `from_dict`, users don't need to provide + * all the fields in the options. If some fields + * are not provided, it just uses the default one. + * + * If the provided dict in `from_dict` is empty, + * all fields use their default values. + */ + namespace kaldifeat { FrameExtractionOptions FrameExtractionOptionsFromDict(py::dict dict); py::dict AsDict(const FrameExtractionOptions &opts); +MelBanksOptions MelBanksOptionsFromDict(py::dict dict); +py::dict AsDict(const MelBanksOptions &opts); + } // namespace kaldifeat #endif // KALDIFEAT_PYTHON_CSRC_UTILS_H_ diff --git a/kaldifeat/python/tests/CMakeLists.txt b/kaldifeat/python/tests/CMakeLists.txt index 0a0cb08..93e81ae 100644 --- a/kaldifeat/python/tests/CMakeLists.txt +++ b/kaldifeat/python/tests/CMakeLists.txt @@ -23,6 +23,7 @@ set(py_test_files test_spectrogram.py test_options.py test_frame_extraction_options.py + test_mel_bank_options.py ) foreach(source IN LISTS py_test_files) diff --git a/kaldifeat/python/tests/test_mel_bank_options.py b/kaldifeat/python/tests/test_mel_bank_options.py new file mode 100755 index 0000000..bb2924f --- /dev/null +++ b/kaldifeat/python/tests/test_mel_bank_options.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +import kaldifeat + + +def test_default(): + opts = kaldifeat.MelBanksOptions() + assert opts.num_bins == 25 + assert opts.low_freq == 20 + assert opts.high_freq == 0 + assert opts.vtln_low == 100 + assert opts.vtln_high == -500 + assert opts.debug_mel is False + assert opts.htk_mode is False + + +def test_set_get(): + opts = kaldifeat.MelBanksOptions() + opts.num_bins = 100 + assert opts.num_bins == 100 + + opts.low_freq = 22 + assert opts.low_freq == 22 + + opts.high_freq = 1 + assert opts.high_freq == 1 + + opts.vtln_low = 101 + assert opts.vtln_low == 101 + + opts.vtln_high = -100 + assert opts.vtln_high == -100 + + opts.debug_mel = True + assert opts.debug_mel is True + + opts.htk_mode = True + assert opts.htk_mode is True + + +def test_from_empty_dict(): + opts = kaldifeat.MelBanksOptions.from_dict({}) + opts2 = kaldifeat.MelBanksOptions() + + assert str(opts) == str(opts2) + + +def test_from_dict_partial(): + d = {"num_bins": 10, "debug_mel": True} + + opts = kaldifeat.MelBanksOptions.from_dict(d) + + opts2 = kaldifeat.MelBanksOptions() + assert str(opts) != str(opts2) + + opts2.num_bins = 10 + assert str(opts) != str(opts2) + + opts2.debug_mel = True + assert str(opts) == str(opts2) + + opts2.debug_mel = False + assert str(opts) != str(opts2) + + +def test_from_dict_full_and_as_dict(): + opts = kaldifeat.MelBanksOptions() + opts.num_bins = 80 + opts.vtln_high = 2 + + d = opts.as_dict() + for key, value in d.items(): + assert value == getattr(opts, key) + + opts2 = kaldifeat.MelBanksOptions.from_dict(d) + assert str(opts2) == str(opts) + + d["htk_mode"] = True + opts3 = kaldifeat.MelBanksOptions.from_dict(d) + assert opts3.htk_mode is True + + +def main(): + test_default() + test_set_get() + test_from_empty_dict() + test_from_dict_partial() + test_from_dict_full_and_as_dict() + + +if __name__ == "__main__": + main() diff --git a/kaldifeat/python/tests/test_options.py b/kaldifeat/python/tests/test_options.py index 2373779..92375f5 100755 --- a/kaldifeat/python/tests/test_options.py +++ b/kaldifeat/python/tests/test_options.py @@ -16,18 +16,6 @@ sys.path.insert(0, f"{kaldi_feat_dir}/build/lib") import kaldifeat -def test_mel_banks_options(): - opts = kaldifeat.MelBanksOptions() - opts.num_bins = 23 - opts.low_freq = 21 - opts.high_freq = 8000 - opts.vtln_low = 101 - opts.vtln_high = -501 - opts.debug_mel = True - opts.htk_mode = True - print(opts) - - def test_fbank_options(): opts = kaldifeat.FbankOptions() frame_opts = opts.frame_opts @@ -153,7 +141,6 @@ def test_plp_options(): def main(): - test_mel_banks_options() test_fbank_options() test_mfcc_options() test_spectogram_options()