diff --git a/kaldifeat/python/csrc/feature-window.cc b/kaldifeat/python/csrc/feature-window.cc index 81cc275..d7af62c 100644 --- a/kaldifeat/python/csrc/feature-window.cc +++ b/kaldifeat/python/csrc/feature-window.cc @@ -12,39 +12,38 @@ namespace kaldifeat { static void PybindFrameExtractionOptions(py::module &m) { - py::class_(m, "FrameExtractionOptions") + using PyClass = FrameExtractionOptions; + py::class_(m, "FrameExtractionOptions") .def(py::init<>()) - .def_readwrite("samp_freq", &FrameExtractionOptions::samp_freq) - .def_readwrite("frame_shift_ms", &FrameExtractionOptions::frame_shift_ms) - .def_readwrite("frame_length_ms", - &FrameExtractionOptions::frame_length_ms) - .def_readwrite("dither", &FrameExtractionOptions::dither) - .def_readwrite("preemph_coeff", &FrameExtractionOptions::preemph_coeff) - .def_readwrite("remove_dc_offset", - &FrameExtractionOptions::remove_dc_offset) - .def_readwrite("window_type", &FrameExtractionOptions::window_type) - .def_readwrite("round_to_power_of_two", - &FrameExtractionOptions::round_to_power_of_two) - .def_readwrite("blackman_coeff", &FrameExtractionOptions::blackman_coeff) - .def_readwrite("snip_edges", &FrameExtractionOptions::snip_edges) + .def_readwrite("samp_freq", &PyClass::samp_freq) + .def_readwrite("frame_shift_ms", &PyClass::frame_shift_ms) + .def_readwrite("frame_length_ms", &PyClass::frame_length_ms) + .def_readwrite("dither", &PyClass::dither) + .def_readwrite("preemph_coeff", &PyClass::preemph_coeff) + .def_readwrite("remove_dc_offset", &PyClass::remove_dc_offset) + .def_readwrite("window_type", &PyClass::window_type) + .def_readwrite("round_to_power_of_two", &PyClass::round_to_power_of_two) + .def_readwrite("blackman_coeff", &PyClass::blackman_coeff) + .def_readwrite("snip_edges", &PyClass::snip_edges) .def("as_dict", - [](const FrameExtractionOptions &self) -> py::dict { - return AsDict(self); - }) + [](const PyClass &self) -> py::dict { return AsDict(self); }) .def_static("from_dict", - [](py::dict dict) -> FrameExtractionOptions { + [](py::dict dict) -> PyClass { return FrameExtractionOptionsFromDict(dict); }) #if 0 .def_readwrite("allow_downsample", - &FrameExtractionOptions::allow_downsample) - .def_readwrite("allow_upsample", &FrameExtractionOptions::allow_upsample) + &PyClass::allow_downsample) + .def_readwrite("allow_upsample", &PyClass::allow_upsample) .def_readwrite("max_feature_vectors", - &FrameExtractionOptions::max_feature_vectors) + &PyClass::max_feature_vectors) #endif - .def("__str__", [](const FrameExtractionOptions &self) -> std::string { - return self.ToString(); - }); + .def("__str__", + [](const PyClass &self) -> std::string { return self.ToString(); }) + .def(py::pickle([](const PyClass &self) { return AsDict(self); }, + [](py::dict dict) -> PyClass { + return FrameExtractionOptionsFromDict(dict); + })); m.def("num_frames", &NumFrames, py::arg("num_samples"), py::arg("opts"), py::arg("flush") = true); diff --git a/kaldifeat/python/csrc/utils.h b/kaldifeat/python/csrc/utils.h index 472d7c9..9ecac6d 100644 --- a/kaldifeat/python/csrc/utils.h +++ b/kaldifeat/python/csrc/utils.h @@ -15,7 +15,7 @@ /* * This file contains code about `from_dict` and - * `to_dict` for various options in kaldifeat. + * `as_dict` for various options in kaldifeat. * * Regarding `from_dict`, users don't need to provide * all the fields in the options. If some fields diff --git a/kaldifeat/python/tests/test_frame_extraction_options.py b/kaldifeat/python/tests/test_frame_extraction_options.py index 4fa90d9..511d0a7 100755 --- a/kaldifeat/python/tests/test_frame_extraction_options.py +++ b/kaldifeat/python/tests/test_frame_extraction_options.py @@ -2,6 +2,8 @@ # # Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) +import pickle + import kaldifeat @@ -94,12 +96,23 @@ def test_from_dict_full_and_as_dict(): assert opts3.window_type == "hanning" +def test_pickle(): + opts = kaldifeat.FrameExtractionOptions() + opts.samp_freq = 44100 + opts.dither = 5.5 + data = pickle.dumps(opts) + + opts2 = pickle.loads(data) + assert str(opts) == str(opts2) + + def main(): test_default() test_set_get() test_from_empty_dict() test_from_dict_partial() test_from_dict_full_and_as_dict() + test_pickle() if __name__ == "__main__":