diff --git a/kaldifeat/python/csrc/feature-fbank.cc b/kaldifeat/python/csrc/feature-fbank.cc index 8d8e8f0..7b3e31a 100644 --- a/kaldifeat/python/csrc/feature-fbank.cc +++ b/kaldifeat/python/csrc/feature-fbank.cc @@ -37,9 +37,12 @@ static void PybindFbankOptions(py::module &m) { [](const PyClass &self) -> std::string { return self.ToString(); }) .def("as_dict", [](const PyClass &self) -> py::dict { return AsDict(self); }) - .def_static("from_dict", [](py::dict dict) -> PyClass { - return FbankOptionsFromDict(dict); - }); + .def_static( + "from_dict", + [](py::dict dict) -> PyClass { return FbankOptionsFromDict(dict); }) + .def(py::pickle( + [](const PyClass &self) -> py::dict { return AsDict(self); }, + [](py::dict dict) -> PyClass { return FbankOptionsFromDict(dict); })); } static void PybindFbank(py::module &m) { @@ -49,7 +52,14 @@ static void PybindFbank(py::module &m) { .def("dim", &PyClass::Dim) .def_property_readonly("options", &PyClass::GetOptions) .def("compute_features", &PyClass::ComputeFeatures, py::arg("wave"), - py::arg("vtln_warp")); + py::arg("vtln_warp")) + .def(py::pickle( + [](const PyClass &self) -> py::dict { + return AsDict(self.GetOptions()); + }, + [](py::dict dict) -> std::unique_ptr { + return std::make_unique(FbankOptionsFromDict(dict)); + })); } void PybindFeatureFbank(py::module &m) { diff --git a/kaldifeat/python/csrc/feature-mfcc.cc b/kaldifeat/python/csrc/feature-mfcc.cc index b43930c..efa7b7b 100644 --- a/kaldifeat/python/csrc/feature-mfcc.cc +++ b/kaldifeat/python/csrc/feature-mfcc.cc @@ -37,9 +37,12 @@ void PybindMfccOptions(py::module &m) { [](const PyClass &self) -> std::string { return self.ToString(); }) .def("as_dict", [](const PyClass &self) -> py::dict { return AsDict(self); }) - .def_static("from_dict", [](py::dict dict) -> PyClass { - return MfccOptionsFromDict(dict); - }); + .def_static( + "from_dict", + [](py::dict dict) -> PyClass { return MfccOptionsFromDict(dict); }) + .def(py::pickle( + [](const PyClass &self) -> py::dict { return AsDict(self); }, + [](py::dict dict) -> PyClass { return MfccOptionsFromDict(dict); })); } static void PybindMfcc(py::module &m) { @@ -49,7 +52,14 @@ static void PybindMfcc(py::module &m) { .def("dim", &PyClass::Dim) .def_property_readonly("options", &PyClass::GetOptions) .def("compute_features", &PyClass::ComputeFeatures, py::arg("wave"), - py::arg("vtln_warp")); + py::arg("vtln_warp")) + .def(py::pickle( + [](const PyClass &self) -> py::dict { + return AsDict(self.GetOptions()); + }, + [](py::dict dict) -> std::unique_ptr { + return std::make_unique(MfccOptionsFromDict(dict)); + })); } void PybindFeatureMfcc(py::module &m) { diff --git a/kaldifeat/python/csrc/feature-plp.cc b/kaldifeat/python/csrc/feature-plp.cc index ef68e2c..5db6417 100644 --- a/kaldifeat/python/csrc/feature-plp.cc +++ b/kaldifeat/python/csrc/feature-plp.cc @@ -40,9 +40,12 @@ void PybindPlpOptions(py::module &m) { [](const PyClass &self) -> std::string { return self.ToString(); }) .def("as_dict", [](const PyClass &self) -> py::dict { return AsDict(self); }) - .def_static("from_dict", [](py::dict dict) -> PyClass { - return PlpOptionsFromDict(dict); - }); + .def_static( + "from_dict", + [](py::dict dict) -> PyClass { return PlpOptionsFromDict(dict); }) + .def(py::pickle( + [](const PyClass &self) -> py::dict { return AsDict(self); }, + [](py::dict dict) -> PyClass { return PlpOptionsFromDict(dict); })); } static void PybindPlp(py::module &m) { @@ -52,7 +55,14 @@ static void PybindPlp(py::module &m) { .def("dim", &PyClass::Dim) .def_property_readonly("options", &PyClass::GetOptions) .def("compute_features", &PyClass::ComputeFeatures, py::arg("wave"), - py::arg("vtln_warp")); + py::arg("vtln_warp")) + .def(py::pickle( + [](const PyClass &self) -> py::dict { + return AsDict(self.GetOptions()); + }, + [](py::dict dict) -> std::unique_ptr { + return std::make_unique(PlpOptionsFromDict(dict)); + })); } void PybindFeaturePlp(py::module &m) { diff --git a/kaldifeat/python/csrc/feature-spectrogram.cc b/kaldifeat/python/csrc/feature-spectrogram.cc index f752ebe..9e68529 100644 --- a/kaldifeat/python/csrc/feature-spectrogram.cc +++ b/kaldifeat/python/csrc/feature-spectrogram.cc @@ -34,9 +34,15 @@ static void PybindSpectrogramOptions(py::module &m) { [](const PyClass &self) -> std::string { return self.ToString(); }) .def("as_dict", [](const PyClass &self) -> py::dict { return AsDict(self); }) - .def_static("from_dict", [](py::dict dict) -> PyClass { - return SpectrogramOptionsFromDict(dict); - }); + .def_static("from_dict", + [](py::dict dict) -> PyClass { + return SpectrogramOptionsFromDict(dict); + }) + .def(py::pickle( + [](const PyClass &self) -> py::dict { return AsDict(self); }, + [](py::dict dict) -> PyClass { + return SpectrogramOptionsFromDict(dict); + })); } static void PybindSpectrogram(py::module &m) { @@ -46,7 +52,14 @@ static void PybindSpectrogram(py::module &m) { .def("dim", &PyClass::Dim) .def_property_readonly("options", &PyClass::GetOptions) .def("compute_features", &PyClass::ComputeFeatures, py::arg("wave"), - py::arg("vtln_warp")); + py::arg("vtln_warp")) + .def(py::pickle( + [](const PyClass &self) -> py::dict { + return AsDict(self.GetOptions()); + }, + [](py::dict dict) -> std::unique_ptr { + return std::make_unique(SpectrogramOptionsFromDict(dict)); + })); } void PybindFeatureSpectrogram(py::module &m) { diff --git a/kaldifeat/python/tests/test_fbank.py b/kaldifeat/python/tests/test_fbank.py index 2a9294c..9e6f8a3 100755 --- a/kaldifeat/python/tests/test_fbank.py +++ b/kaldifeat/python/tests/test_fbank.py @@ -2,6 +2,7 @@ # Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang) +import pickle from pathlib import Path import torch @@ -156,6 +157,20 @@ def test_fbank_batch(): assert torch.allclose(features[1], features1) +def test_pickle(): + for device in get_devices(): + opts = kaldifeat.FbankOptions() + opts.use_energy = True + opts.use_power = False + opts.device = device + + fbank = kaldifeat.Fbank(opts) + data = pickle.dumps(fbank) + fbank2 = pickle.loads(data) + + assert str(fbank.opts) == str(fbank2.opts) + + if __name__ == "__main__": test_fbank_default() test_fbank_htk() @@ -164,3 +179,4 @@ if __name__ == "__main__": test_fbank_40_bins_no_snip_edges() test_fbank_chunk() test_fbank_batch() + test_pickle() diff --git a/kaldifeat/python/tests/test_fbank_options.py b/kaldifeat/python/tests/test_fbank_options.py index fb00054..f2fffdc 100755 --- a/kaldifeat/python/tests/test_fbank_options.py +++ b/kaldifeat/python/tests/test_fbank_options.py @@ -3,6 +3,8 @@ # Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) +import pickle + import torch import kaldifeat @@ -176,6 +178,21 @@ def test_from_dict_full_and_as_dict(): assert opts3.device == torch.device("cuda", 2) +def test_pickle(): + opts = kaldifeat.FbankOptions() + opts.use_energy = True + opts.use_power = False + opts.device = torch.device("cuda", 1) + + opts.frame_opts.samp_freq = 44100 + opts.mel_opts.num_bins = 100 + + data = pickle.dumps(opts) + + opts2 = pickle.loads(data) + assert str(opts) == str(opts2) + + def main(): test_default() test_set_get() @@ -184,6 +201,7 @@ def main(): test_from_empty_dict() test_from_dict_partial() test_from_dict_full_and_as_dict() + test_pickle() if __name__ == "__main__": diff --git a/kaldifeat/python/tests/test_mfcc.py b/kaldifeat/python/tests/test_mfcc.py index 8d1797e..33407b5 100755 --- a/kaldifeat/python/tests/test_mfcc.py +++ b/kaldifeat/python/tests/test_mfcc.py @@ -2,6 +2,7 @@ # Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang) +import pickle from pathlib import Path import torch @@ -46,6 +47,21 @@ def test_mfcc_no_snip_edges(): assert torch.allclose(features.cpu(), gt, rtol=1e-1) +def test_pickle(): + for device in get_devices(): + opts = kaldifeat.MfccOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + + mfcc = kaldifeat.Mfcc(opts) + data = pickle.dumps(mfcc) + mfcc2 = pickle.loads(data) + + assert str(mfcc.opts) == str(mfcc2.opts) + + if __name__ == "__main__": test_mfcc_default() test_mfcc_no_snip_edges() + test_pickle() diff --git a/kaldifeat/python/tests/test_mfcc_options.py b/kaldifeat/python/tests/test_mfcc_options.py index 310ff93..cef03ab 100755 --- a/kaldifeat/python/tests/test_mfcc_options.py +++ b/kaldifeat/python/tests/test_mfcc_options.py @@ -3,6 +3,8 @@ # Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) +import pickle + import torch import kaldifeat @@ -180,6 +182,28 @@ def test_from_dict_full_and_as_dict(): assert opts3.device == torch.device("cuda", 10) +def test_pickle(): + opts = kaldifeat.MfccOptions() + opts.num_ceps = 222 + opts.use_energy = False + opts.cepstral_lifter = 21 + opts.htk_compat = True + opts.device = torch.device("cuda", 3) + + opts.frame_opts.samp_freq = 44100 + opts.frame_opts.frame_length_ms = 1 + opts.frame_opts.dither = 0.5 + + opts.mel_opts.num_bins = 100 + opts.mel_opts.low_freq = 22 + opts.mel_opts.vtln_high = -100 + + data = pickle.dumps(opts) + + opts2 = pickle.loads(data) + assert str(opts) == str(opts2) + + def main(): test_default() test_set_get() @@ -188,6 +212,7 @@ def main(): test_from_empty_dict() test_from_dict_partial() test_from_dict_full_and_as_dict() + test_pickle() if __name__ == "__main__": diff --git a/kaldifeat/python/tests/test_plp.py b/kaldifeat/python/tests/test_plp.py index 7b47e5f..4f20452 100755 --- a/kaldifeat/python/tests/test_plp.py +++ b/kaldifeat/python/tests/test_plp.py @@ -2,6 +2,7 @@ # Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang) +import pickle from pathlib import Path import torch @@ -65,7 +66,22 @@ def test_plp_htk_10_ceps(): assert torch.allclose(features.cpu(), gt, atol=1e-1) +def test_pickle(): + for device in get_devices(): + opts = kaldifeat.PlpOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + + plp = kaldifeat.Plp(opts) + data = pickle.dumps(plp) + plp2 = pickle.loads(data) + + assert str(plp.opts) == str(plp2.opts) + + if __name__ == "__main__": test_plp_default() test_plp_no_snip_edges() test_plp_htk_10_ceps() + test_pickle() diff --git a/kaldifeat/python/tests/test_plp_options.py b/kaldifeat/python/tests/test_plp_options.py index dc87045..c30dd64 100755 --- a/kaldifeat/python/tests/test_plp_options.py +++ b/kaldifeat/python/tests/test_plp_options.py @@ -3,6 +3,8 @@ # Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) +import pickle + import torch import kaldifeat @@ -191,6 +193,27 @@ def test_from_dict_full_and_as_dict(): assert opts3.device == torch.device("cuda", 2) +def test_pickle(): + opts = kaldifeat.PlpOptions() + opts.lpc_order = 11 + opts.num_ceps = 1 + opts.use_energy = False + opts.compress_factor = 0.5 + opts.cepstral_lifter = 2 + opts.device = torch.device("cuda", 1) + + opts.frame_opts.samp_freq = 44100 + opts.frame_opts.snip_edges = False + + opts.mel_opts.num_bins = 100 + opts.mel_opts.high_freq = 1 + + data = pickle.dumps(opts) + + opts2 = pickle.loads(data) + assert str(opts) == str(opts2) + + def main(): test_default() test_set_get() @@ -199,6 +222,7 @@ def main(): test_from_empty_dict() test_from_dict_partial() test_from_dict_full_and_as_dict() + test_pickle() if __name__ == "__main__": diff --git a/kaldifeat/python/tests/test_spectrogram.py b/kaldifeat/python/tests/test_spectrogram.py index 387a6c4..7e5a106 100755 --- a/kaldifeat/python/tests/test_spectrogram.py +++ b/kaldifeat/python/tests/test_spectrogram.py @@ -2,6 +2,7 @@ # Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang) +import pickle from pathlib import Path from utils import get_devices, read_ark_txt, read_wave @@ -50,6 +51,21 @@ def test_spectrogram_no_snip_edges(): print(features[1, 145:148], gt[1, 145:148]) # they are different +def test_pickle(): + for device in get_devices(): + opts = kaldifeat.SpectrogramOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + + spec = kaldifeat.Spectrogram(opts) + data = pickle.dumps(spec) + spec2 = pickle.loads(data) + + assert str(spec.opts) == str(spec2.opts) + + if __name__ == "__main__": test_spectrogram_default() test_spectrogram_no_snip_edges() + test_pickle() diff --git a/kaldifeat/python/tests/test_spectrogram_options.py b/kaldifeat/python/tests/test_spectrogram_options.py index d830300..34c8849 100755 --- a/kaldifeat/python/tests/test_spectrogram_options.py +++ b/kaldifeat/python/tests/test_spectrogram_options.py @@ -3,6 +3,8 @@ # Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) +import pickle + import torch import kaldifeat @@ -121,6 +123,21 @@ def test_from_dict_full_and_as_dict(): assert str(opts3) == str(opts) +def test_pickle(): + opts = kaldifeat.SpectrogramOptions() + opts.energy_floor = 1 + opts.raw_energy = False + opts.device = torch.device("cuda", 1) + + opts.frame_opts.samp_freq = 44100 + opts.frame_opts.snip_edges = False + + data = pickle.dumps(opts) + + opts2 = pickle.loads(data) + assert str(opts) == str(opts2) + + def main(): test_default() test_set_get() @@ -128,6 +145,7 @@ def main(): test_from_empty_dict() test_from_dict_partial() test_from_dict_full_and_as_dict() + test_pickle() if __name__ == "__main__":