From 6eb7a3b243819ec47368ac41e4bb6b45384b0bd3 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 4 Nov 2021 10:40:58 +0800 Subject: [PATCH] Add pickle support to MelBanksOptions. --- kaldifeat/python/csrc/feature-window.cc | 9 +++++---- kaldifeat/python/csrc/mel-computations.cc | 12 +++++++++--- kaldifeat/python/tests/test_mel_bank_options.py | 13 +++++++++++++ 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/kaldifeat/python/csrc/feature-window.cc b/kaldifeat/python/csrc/feature-window.cc index d7af62c..0fd2ea8 100644 --- a/kaldifeat/python/csrc/feature-window.cc +++ b/kaldifeat/python/csrc/feature-window.cc @@ -40,10 +40,11 @@ static void PybindFrameExtractionOptions(py::module &m) { #endif .def("__str__", [](const PyClass &self) -> std::string { return self.ToString(); }) - .def(py::pickle([](const PyClass &self) { return AsDict(self); }, - [](py::dict dict) -> PyClass { - return FrameExtractionOptionsFromDict(dict); - })); + .def(py::pickle( + [](const PyClass &self) -> py::dict { return AsDict(self); }, + [](py::dict dict) -> PyClass { + return FrameExtractionOptionsFromDict(dict); + })); m.def("num_frames", &NumFrames, py::arg("num_samples"), py::arg("opts"), py::arg("flush") = true); diff --git a/kaldifeat/python/csrc/mel-computations.cc b/kaldifeat/python/csrc/mel-computations.cc index 77e692b..e8f1c31 100644 --- a/kaldifeat/python/csrc/mel-computations.cc +++ b/kaldifeat/python/csrc/mel-computations.cc @@ -26,9 +26,15 @@ static void PybindMelBanksOptions(py::module &m) { [](const PyClass &self) -> std::string { return self.ToString(); }) .def("as_dict", [](const PyClass &self) -> py::dict { return AsDict(self); }) - .def_static("from_dict", [](py::dict dict) -> PyClass { - return MelBanksOptionsFromDict(dict); - }); + .def_static("from_dict", + [](py::dict dict) -> PyClass { + return MelBanksOptionsFromDict(dict); + }) + .def(py::pickle( + [](const PyClass &self) -> py::dict { return AsDict(self); }, + [](py::dict dict) -> PyClass { + return MelBanksOptionsFromDict(dict); + })); } void PybindMelComputations(py::module &m) { PybindMelBanksOptions(m); } diff --git a/kaldifeat/python/tests/test_mel_bank_options.py b/kaldifeat/python/tests/test_mel_bank_options.py index bb2924f..70624a1 100755 --- a/kaldifeat/python/tests/test_mel_bank_options.py +++ b/kaldifeat/python/tests/test_mel_bank_options.py @@ -2,6 +2,8 @@ # # Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) +import pickle + import kaldifeat @@ -82,12 +84,23 @@ def test_from_dict_full_and_as_dict(): assert opts3.htk_mode is True +def test_pickle(): + opts = kaldifeat.MelBanksOptions() + opts.num_bins = 100 + opts.low_freq = 22 + data = pickle.dumps(opts) + + opts2 = pickle.loads(data) + assert str(opts) == str(opts2) + + def main(): test_default() test_set_get() test_from_empty_dict() test_from_dict_partial() test_from_dict_full_and_as_dict() + test_pickle() if __name__ == "__main__":