diff --git a/kaldifeat/python/csrc/feature-mfcc.cc b/kaldifeat/python/csrc/feature-mfcc.cc index 2d48481..b43930c 100644 --- a/kaldifeat/python/csrc/feature-mfcc.cc +++ b/kaldifeat/python/csrc/feature-mfcc.cc @@ -7,6 +7,7 @@ #include #include "kaldifeat/csrc/feature-mfcc.h" +#include "kaldifeat/python/csrc/utils.h" namespace kaldifeat { @@ -33,7 +34,12 @@ void PybindMfccOptions(py::module &m) { self.device = torch::Device(s); }) .def("__str__", - [](const PyClass &self) -> std::string { return self.ToString(); }); + [](const PyClass &self) -> std::string { return self.ToString(); }) + .def("as_dict", + [](const PyClass &self) -> py::dict { return AsDict(self); }) + .def_static("from_dict", [](py::dict dict) -> PyClass { + return MfccOptionsFromDict(dict); + }); } static void PybindMfcc(py::module &m) { diff --git a/kaldifeat/python/csrc/utils.cc b/kaldifeat/python/csrc/utils.cc index 81439b1..85f9787 100644 --- a/kaldifeat/python/csrc/utils.cc +++ b/kaldifeat/python/csrc/utils.cc @@ -119,6 +119,49 @@ py::dict AsDict(const FbankOptions &opts) { return dict; } +MfccOptions MfccOptionsFromDict(py::dict dict) { + MfccOptions opts; + + if (dict.contains("frame_opts")) { + opts.frame_opts = FrameExtractionOptionsFromDict(dict["frame_opts"]); + } + + if (dict.contains("mel_opts")) { + opts.mel_opts = MelBanksOptionsFromDict(dict["mel_opts"]); + } + + FROM_DICT(int_, num_ceps); + FROM_DICT(bool_, use_energy); + FROM_DICT(float_, energy_floor); + FROM_DICT(bool_, raw_energy); + FROM_DICT(float_, cepstral_lifter); + FROM_DICT(bool_, htk_compat); + + if (dict.contains("device")) { + opts.device = torch::Device(std::string(py::str(dict["device"]))); + } + + return opts; +} + +py::dict AsDict(const MfccOptions &opts) { + py::dict dict; + + dict["frame_opts"] = AsDict(opts.frame_opts); + dict["mel_opts"] = AsDict(opts.mel_opts); + AS_DICT(num_ceps); + AS_DICT(use_energy); + AS_DICT(energy_floor); + AS_DICT(raw_energy); + AS_DICT(cepstral_lifter); + AS_DICT(htk_compat); + + auto torch_device = py::module_::import("torch").attr("device"); + dict["device"] = torch_device(opts.device.str()); + + return dict; +} + #undef FROM_DICT #undef AS_DICT diff --git a/kaldifeat/python/csrc/utils.h b/kaldifeat/python/csrc/utils.h index b683e63..5ddc0d7 100644 --- a/kaldifeat/python/csrc/utils.h +++ b/kaldifeat/python/csrc/utils.h @@ -6,6 +6,7 @@ #define KALDIFEAT_PYTHON_CSRC_UTILS_H_ #include "kaldifeat/csrc/feature-fbank.h" +#include "kaldifeat/csrc/feature-mfcc.h" #include "kaldifeat/csrc/feature-window.h" #include "kaldifeat/csrc/mel-computations.h" #include "kaldifeat/python/csrc/kaldifeat.h" @@ -33,6 +34,9 @@ py::dict AsDict(const MelBanksOptions &opts); FbankOptions FbankOptionsFromDict(py::dict dict); py::dict AsDict(const FbankOptions &opts); +MfccOptions MfccOptionsFromDict(py::dict dict); +py::dict AsDict(const MfccOptions &opts); + } // namespace kaldifeat #endif // KALDIFEAT_PYTHON_CSRC_UTILS_H_ diff --git a/kaldifeat/python/tests/CMakeLists.txt b/kaldifeat/python/tests/CMakeLists.txt index 02bad90..b5372c1 100644 --- a/kaldifeat/python/tests/CMakeLists.txt +++ b/kaldifeat/python/tests/CMakeLists.txt @@ -22,6 +22,7 @@ set(py_test_files test_frame_extraction_options.py test_mel_bank_options.py test_mfcc.py + test_mfcc_options.py test_options.py test_plp.py test_spectrogram.py diff --git a/kaldifeat/python/tests/test_fbank_options.py b/kaldifeat/python/tests/test_fbank_options.py index 562bcac..fb00054 100755 --- a/kaldifeat/python/tests/test_fbank_options.py +++ b/kaldifeat/python/tests/test_fbank_options.py @@ -143,7 +143,8 @@ def test_from_dict_partial(): assert opts.mel_opts.vtln_low == 1 assert opts.frame_opts.window_type == "hanning" - opts = kaldifeat.MelBanksOptions.from_dict(d) + mel_opts = kaldifeat.MelBanksOptions.from_dict(d["mel_opts"]) + assert str(opts.mel_opts) == str(mel_opts) def test_from_dict_full_and_as_dict(): diff --git a/kaldifeat/python/tests/test_mfcc_options.py b/kaldifeat/python/tests/test_mfcc_options.py new file mode 100755 index 0000000..310ff93 --- /dev/null +++ b/kaldifeat/python/tests/test_mfcc_options.py @@ -0,0 +1,194 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + + +import torch + +import kaldifeat + + +def test_default(): + opts = kaldifeat.MfccOptions() + + assert opts.frame_opts.samp_freq == 16000 + assert opts.frame_opts.frame_shift_ms == 10.0 + assert opts.frame_opts.frame_length_ms == 25.0 + assert opts.frame_opts.dither == 1.0 + assert abs(opts.frame_opts.preemph_coeff - 0.97) < 1e-6 + assert opts.frame_opts.remove_dc_offset is True + assert opts.frame_opts.window_type == "povey" + assert opts.frame_opts.round_to_power_of_two is True + assert abs(opts.frame_opts.blackman_coeff - 0.42) < 1e-6 + assert opts.frame_opts.snip_edges is True + + assert opts.mel_opts.num_bins == 23 + assert opts.mel_opts.low_freq == 20 + assert opts.mel_opts.high_freq == 0 + assert opts.mel_opts.vtln_low == 100 + assert opts.mel_opts.vtln_high == -500 + assert opts.mel_opts.debug_mel is False + assert opts.mel_opts.htk_mode is False + + assert opts.num_ceps == 13 + assert opts.use_energy is True + assert opts.energy_floor == 0 + assert opts.raw_energy is True + assert opts.cepstral_lifter == 22.0 + assert opts.htk_compat is False + + assert opts.device.type == "cpu" + + +def test_set_get(): + opts = kaldifeat.MfccOptions() + opts.num_ceps = 22 + assert opts.num_ceps == 22 + + opts.use_energy = False + assert opts.use_energy is False + + opts.energy_floor = 1 + assert opts.energy_floor == 1 + + opts.raw_energy = False + assert opts.raw_energy is False + + opts.cepstral_lifter = 21 + assert opts.cepstral_lifter == 21 + + opts.htk_compat = True + assert opts.htk_compat is True + + opts.device = torch.device("cuda", 1) + assert opts.device.type == "cuda" + assert opts.device.index == 1 + + +def test_set_get_frame_opts(): + opts = kaldifeat.MfccOptions() + + opts.frame_opts.samp_freq = 44100 + assert opts.frame_opts.samp_freq == 44100 + + opts.frame_opts.frame_shift_ms = 20.5 + assert opts.frame_opts.frame_shift_ms == 20.5 + + opts.frame_opts.frame_length_ms = 1 + assert opts.frame_opts.frame_length_ms == 1 + + opts.frame_opts.dither = 0.5 + assert opts.frame_opts.dither == 0.5 + + opts.frame_opts.preemph_coeff = 0.25 + assert opts.frame_opts.preemph_coeff == 0.25 + + opts.frame_opts.remove_dc_offset = False + assert opts.frame_opts.remove_dc_offset is False + + opts.frame_opts.window_type = "hanning" + assert opts.frame_opts.window_type == "hanning" + + opts.frame_opts.round_to_power_of_two = False + assert opts.frame_opts.round_to_power_of_two is False + + opts.frame_opts.blackman_coeff = 0.25 + assert opts.frame_opts.blackman_coeff == 0.25 + + opts.frame_opts.snip_edges = False + assert opts.frame_opts.snip_edges is False + + +def test_set_get_mel_opts(): + opts = kaldifeat.MfccOptions() + + opts.mel_opts.num_bins = 100 + assert opts.mel_opts.num_bins == 100 + + opts.mel_opts.low_freq = 22 + assert opts.mel_opts.low_freq == 22 + + opts.mel_opts.high_freq = 1 + assert opts.mel_opts.high_freq == 1 + + opts.mel_opts.vtln_low = 101 + assert opts.mel_opts.vtln_low == 101 + + opts.mel_opts.vtln_high = -100 + assert opts.mel_opts.vtln_high == -100 + + opts.mel_opts.debug_mel = True + assert opts.mel_opts.debug_mel is True + + opts.mel_opts.htk_mode = True + assert opts.mel_opts.htk_mode is True + + +def test_from_empty_dict(): + opts = kaldifeat.MfccOptions.from_dict({}) + opts2 = kaldifeat.MfccOptions() + + assert str(opts) == str(opts2) + + +def test_from_dict_partial(): + d = { + "energy_floor": 10.5, + "htk_compat": True, + "mel_opts": {"num_bins": 80, "vtln_low": 1}, + "frame_opts": {"window_type": "hanning"}, + "device": "cuda:2", + } + opts = kaldifeat.MfccOptions.from_dict(d) + assert opts.energy_floor == 10.5 + assert opts.htk_compat is True + assert opts.device == torch.device("cuda", 2) + assert opts.mel_opts.num_bins == 80 + assert opts.mel_opts.vtln_low == 1 + assert opts.frame_opts.window_type == "hanning" + + mel_opts = kaldifeat.MelBanksOptions.from_dict(d["mel_opts"]) + assert str(opts.mel_opts) == str(mel_opts) + + +def test_from_dict_full_and_as_dict(): + opts = kaldifeat.MfccOptions() + opts.htk_compat = True + opts.mel_opts.num_bins = 80 + opts.frame_opts.samp_freq = 10 + + d = opts.as_dict() + assert d["htk_compat"] is True + assert d["mel_opts"]["num_bins"] == 80 + assert d["frame_opts"]["samp_freq"] == 10 + + mel_opts = kaldifeat.MelBanksOptions() + mel_opts.num_bins = 80 + assert d["mel_opts"] == mel_opts.as_dict() + + frame_opts = kaldifeat.FrameExtractionOptions() + frame_opts.samp_freq = 10 + assert d["frame_opts"] == frame_opts.as_dict() + + opts2 = kaldifeat.MfccOptions.from_dict(d) + assert str(opts2) == str(opts) + + d["htk_compat"] = False + d["device"] = torch.device("cuda", 10) + opts3 = kaldifeat.MfccOptions.from_dict(d) + assert opts3.htk_compat is False + assert opts3.device == torch.device("cuda", 10) + + +def main(): + test_default() + test_set_get() + test_set_get_frame_opts() + test_set_get_mel_opts() + test_from_empty_dict() + test_from_dict_partial() + test_from_dict_full_and_as_dict() + + +if __name__ == "__main__": + main() diff --git a/kaldifeat/python/tests/test_options.py b/kaldifeat/python/tests/test_options.py index 2016f3f..84e99cb 100755 --- a/kaldifeat/python/tests/test_options.py +++ b/kaldifeat/python/tests/test_options.py @@ -16,40 +16,6 @@ sys.path.insert(0, f"{kaldi_feat_dir}/build/lib") import kaldifeat -def test_mfcc_options(): - opts = kaldifeat.MfccOptions() - frame_opts = opts.frame_opts - mel_opts = opts.mel_opts - - opts.num_ceps = 10 - opts.use_energy = False - opts.energy_floor = 0.0 - opts.raw_energy = True - opts.cepstral_lifter = 22.0 - opts.htk_compat = False - opts.device = torch.device("cpu") - - frame_opts.blackman_coeff = 0.42 - frame_opts.dither = 1 - frame_opts.frame_length_ms = 25 - frame_opts.frame_shift_ms = 10 - frame_opts.preemph_coeff = 0.97 - frame_opts.remove_dc_offset = True - frame_opts.round_to_power_of_two = True - frame_opts.samp_freq = 16000 - frame_opts.snip_edges = True - frame_opts.window_type = "povey" - - mel_opts.debug_mel = True - mel_opts.high_freq = 0 - mel_opts.low_freq = 20 - mel_opts.num_bins = 23 - mel_opts.vtln_high = -500 - mel_opts.vtln_low = 100 - - print(opts) - - def test_spectogram_options(): opts = kaldifeat.SpectrogramOptions() opts.energy_floor = 0.0 @@ -107,7 +73,6 @@ def test_plp_options(): def main(): - test_mfcc_options() test_spectogram_options() test_plp_options()