diff --git a/README.md b/README.md index 22a4cb6..e3f456e 100644 --- a/README.md +++ b/README.md @@ -199,7 +199,12 @@ Please refer to - [kaldifeat/python/tests/test_mfcc.py](kaldifeat/python/tests/test_mfcc.py) - [kaldifeat/python/tests/test_plp.py](kaldifeat/python/tests/test_plp.py) - [kaldifeat/python/tests/test_spectrogram.py](kaldifeat/python/tests/test_spectrogram.py) - - [kaldifeat/python/tests/test_options.py](kaldifeat/python/tests/test_options.py) + - [kaldifeat/python/tests/test_frame_extraction_options.py](kaldifeat/python/tests/test_frame_extraction_options.py) + - [kaldifeat/python/tests/test_mel_bank_options.py](kaldifeat/python/tests/test_mel_bank_options.py) + - [kaldifeat/python/tests/test_fbank_options.py](kaldifeat/python/tests/test_fbank_options.py) + - [kaldifeat/python/tests/test_mfcc_options.py](kaldifeat/python/tests/test_mfcc_options.py) + - [kaldifeat/python/tests/test_spectrogram_options.py](kaldifeat/python/tests/test_spectrogram_options.py) + - [kaldifeat/python/tests/test_plp_options.py](kaldifeat/python/tests/test_plp_options.py) for more examples. 
diff --git a/kaldifeat/python/csrc/CMakeLists.txt b/kaldifeat/python/csrc/CMakeLists.txt index bd1ef25..affb69c 100644 --- a/kaldifeat/python/csrc/CMakeLists.txt +++ b/kaldifeat/python/csrc/CMakeLists.txt @@ -7,6 +7,7 @@ pybind11_add_module(_kaldifeat feature-window.cc kaldifeat.cc mel-computations.cc + utils.cc ) target_link_libraries(_kaldifeat PRIVATE kaldifeat_core) if(UNIX AND NOT APPLE) diff --git a/kaldifeat/python/csrc/feature-fbank.cc b/kaldifeat/python/csrc/feature-fbank.cc index 1cca393..8d8e8f0 100644 --- a/kaldifeat/python/csrc/feature-fbank.cc +++ b/kaldifeat/python/csrc/feature-fbank.cc @@ -7,6 +7,7 @@ #include #include "kaldifeat/csrc/feature-fbank.h" +#include "kaldifeat/python/csrc/utils.h" namespace kaldifeat { @@ -33,7 +34,12 @@ static void PybindFbankOptions(py::module &m) { self.device = torch::Device(s); }) .def("__str__", - [](const PyClass &self) -> std::string { return self.ToString(); }); + [](const PyClass &self) -> std::string { return self.ToString(); }) + .def("as_dict", + [](const PyClass &self) -> py::dict { return AsDict(self); }) + .def_static("from_dict", [](py::dict dict) -> PyClass { + return FbankOptionsFromDict(dict); + }); } static void PybindFbank(py::module &m) { diff --git a/kaldifeat/python/csrc/feature-mfcc.cc b/kaldifeat/python/csrc/feature-mfcc.cc index 2d48481..b43930c 100644 --- a/kaldifeat/python/csrc/feature-mfcc.cc +++ b/kaldifeat/python/csrc/feature-mfcc.cc @@ -7,6 +7,7 @@ #include #include "kaldifeat/csrc/feature-mfcc.h" +#include "kaldifeat/python/csrc/utils.h" namespace kaldifeat { @@ -33,7 +34,12 @@ void PybindMfccOptions(py::module &m) { self.device = torch::Device(s); }) .def("__str__", - [](const PyClass &self) -> std::string { return self.ToString(); }); + [](const PyClass &self) -> std::string { return self.ToString(); }) + .def("as_dict", + [](const PyClass &self) -> py::dict { return AsDict(self); }) + .def_static("from_dict", [](py::dict dict) -> PyClass { + return MfccOptionsFromDict(dict); + }); } 
static void PybindMfcc(py::module &m) { diff --git a/kaldifeat/python/csrc/feature-plp.cc b/kaldifeat/python/csrc/feature-plp.cc index d80ea03..ef68e2c 100644 --- a/kaldifeat/python/csrc/feature-plp.cc +++ b/kaldifeat/python/csrc/feature-plp.cc @@ -7,6 +7,7 @@ #include #include "kaldifeat/csrc/feature-plp.h" +#include "kaldifeat/python/csrc/utils.h" namespace kaldifeat { @@ -36,7 +37,12 @@ void PybindPlpOptions(py::module &m) { self.device = torch::Device(s); }) .def("__str__", - [](const PyClass &self) -> std::string { return self.ToString(); }); + [](const PyClass &self) -> std::string { return self.ToString(); }) + .def("as_dict", + [](const PyClass &self) -> py::dict { return AsDict(self); }) + .def_static("from_dict", [](py::dict dict) -> PyClass { + return PlpOptionsFromDict(dict); + }); } static void PybindPlp(py::module &m) { diff --git a/kaldifeat/python/csrc/feature-spectrogram.cc b/kaldifeat/python/csrc/feature-spectrogram.cc index ef39338..f752ebe 100644 --- a/kaldifeat/python/csrc/feature-spectrogram.cc +++ b/kaldifeat/python/csrc/feature-spectrogram.cc @@ -7,6 +7,7 @@ #include #include "kaldifeat/csrc/feature-spectrogram.h" +#include "kaldifeat/python/csrc/utils.h" namespace kaldifeat { @@ -30,7 +31,12 @@ static void PybindSpectrogramOptions(py::module &m) { self.device = torch::Device(s); }) .def("__str__", - [](const PyClass &self) -> std::string { return self.ToString(); }); + [](const PyClass &self) -> std::string { return self.ToString(); }) + .def("as_dict", + [](const PyClass &self) -> py::dict { return AsDict(self); }) + .def_static("from_dict", [](py::dict dict) -> PyClass { + return SpectrogramOptionsFromDict(dict); + }); } static void PybindSpectrogram(py::module &m) { diff --git a/kaldifeat/python/csrc/feature-window.cc b/kaldifeat/python/csrc/feature-window.cc index 5d76688..81cc275 100644 --- a/kaldifeat/python/csrc/feature-window.cc +++ b/kaldifeat/python/csrc/feature-window.cc @@ -7,6 +7,7 @@ #include #include 
"kaldifeat/csrc/feature-window.h" +#include "kaldifeat/python/csrc/utils.h" namespace kaldifeat { @@ -26,6 +27,14 @@ static void PybindFrameExtractionOptions(py::module &m) { &FrameExtractionOptions::round_to_power_of_two) .def_readwrite("blackman_coeff", &FrameExtractionOptions::blackman_coeff) .def_readwrite("snip_edges", &FrameExtractionOptions::snip_edges) + .def("as_dict", + [](const FrameExtractionOptions &self) -> py::dict { + return AsDict(self); + }) + .def_static("from_dict", + [](py::dict dict) -> FrameExtractionOptions { + return FrameExtractionOptionsFromDict(dict); + }) #if 0 .def_readwrite("allow_downsample", &FrameExtractionOptions::allow_downsample) diff --git a/kaldifeat/python/csrc/mel-computations.cc b/kaldifeat/python/csrc/mel-computations.cc index 572f0ba..77e692b 100644 --- a/kaldifeat/python/csrc/mel-computations.cc +++ b/kaldifeat/python/csrc/mel-computations.cc @@ -7,6 +7,7 @@ #include #include "kaldifeat/csrc/mel-computations.h" +#include "kaldifeat/python/csrc/utils.h" namespace kaldifeat { @@ -22,7 +23,12 @@ static void PybindMelBanksOptions(py::module &m) { .def_readwrite("debug_mel", &PyClass::debug_mel) .def_readwrite("htk_mode", &PyClass::htk_mode) .def("__str__", - [](const PyClass &self) -> std::string { return self.ToString(); }); + [](const PyClass &self) -> std::string { return self.ToString(); }) + .def("as_dict", + [](const PyClass &self) -> py::dict { return AsDict(self); }) + .def_static("from_dict", [](py::dict dict) -> PyClass { + return MelBanksOptionsFromDict(dict); + }); } void PybindMelComputations(py::module &m) { PybindMelBanksOptions(m); } diff --git a/kaldifeat/python/csrc/utils.cc b/kaldifeat/python/csrc/utils.cc new file mode 100644 index 0000000..76f47aa --- /dev/null +++ b/kaldifeat/python/csrc/utils.cc @@ -0,0 +1,253 @@ +// kaldifeat/python/csrc/utils.cc +// +// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +#include "kaldifeat/python/csrc/utils.h" + +#include + +#include 
"kaldifeat/csrc/feature-window.h" + +#define FROM_DICT(type, key) \ + if (dict.contains(#key)) { \ + opts.key = py::type(dict[#key]); \ + } + +#define AS_DICT(key) dict[#key] = opts.key + +namespace kaldifeat { + +FrameExtractionOptions FrameExtractionOptionsFromDict(py::dict dict) { + FrameExtractionOptions opts; + + FROM_DICT(float_, samp_freq); + FROM_DICT(float_, frame_shift_ms); + FROM_DICT(float_, frame_length_ms); + FROM_DICT(float_, dither); + FROM_DICT(float_, preemph_coeff); + FROM_DICT(bool_, remove_dc_offset); + FROM_DICT(str, window_type); + FROM_DICT(bool_, round_to_power_of_two); + FROM_DICT(float_, blackman_coeff); + FROM_DICT(bool_, snip_edges); + + return opts; +} + +py::dict AsDict(const FrameExtractionOptions &opts) { + py::dict dict; + + AS_DICT(samp_freq); + AS_DICT(frame_shift_ms); + AS_DICT(frame_length_ms); + AS_DICT(dither); + AS_DICT(preemph_coeff); + AS_DICT(remove_dc_offset); + AS_DICT(window_type); + AS_DICT(round_to_power_of_two); + AS_DICT(blackman_coeff); + AS_DICT(snip_edges); + + return dict; +} + +MelBanksOptions MelBanksOptionsFromDict(py::dict dict) { + MelBanksOptions opts; + + FROM_DICT(int_, num_bins); + FROM_DICT(float_, low_freq); + FROM_DICT(float_, high_freq); + FROM_DICT(float_, vtln_low); + FROM_DICT(float_, vtln_high); + FROM_DICT(bool_, debug_mel); + FROM_DICT(bool_, htk_mode); + + return opts; +} +py::dict AsDict(const MelBanksOptions &opts) { + py::dict dict; + + AS_DICT(num_bins); + AS_DICT(low_freq); + AS_DICT(high_freq); + AS_DICT(vtln_low); + AS_DICT(vtln_high); + AS_DICT(debug_mel); + AS_DICT(htk_mode); + + return dict; +} + +FbankOptions FbankOptionsFromDict(py::dict dict) { + FbankOptions opts; + + if (dict.contains("frame_opts")) { + opts.frame_opts = FrameExtractionOptionsFromDict(dict["frame_opts"]); + } + + if (dict.contains("mel_opts")) { + opts.mel_opts = MelBanksOptionsFromDict(dict["mel_opts"]); + } + + FROM_DICT(bool_, use_energy); + FROM_DICT(float_, energy_floor); + FROM_DICT(bool_, raw_energy); + 
FROM_DICT(bool_, htk_compat); + FROM_DICT(bool_, use_log_fbank); + FROM_DICT(bool_, use_power); + + if (dict.contains("device")) { + opts.device = torch::Device(std::string(py::str(dict["device"]))); + } + + return opts; +} + +py::dict AsDict(const FbankOptions &opts) { + py::dict dict; + + dict["frame_opts"] = AsDict(opts.frame_opts); + dict["mel_opts"] = AsDict(opts.mel_opts); + AS_DICT(use_energy); + AS_DICT(energy_floor); + AS_DICT(raw_energy); + AS_DICT(htk_compat); + AS_DICT(use_log_fbank); + AS_DICT(use_power); + + auto torch_device = py::module_::import("torch").attr("device"); + dict["device"] = torch_device(opts.device.str()); + + return dict; +} + +MfccOptions MfccOptionsFromDict(py::dict dict) { + MfccOptions opts; + + if (dict.contains("frame_opts")) { + opts.frame_opts = FrameExtractionOptionsFromDict(dict["frame_opts"]); + } + + if (dict.contains("mel_opts")) { + opts.mel_opts = MelBanksOptionsFromDict(dict["mel_opts"]); + } + + FROM_DICT(int_, num_ceps); + FROM_DICT(bool_, use_energy); + FROM_DICT(float_, energy_floor); + FROM_DICT(bool_, raw_energy); + FROM_DICT(float_, cepstral_lifter); + FROM_DICT(bool_, htk_compat); + + if (dict.contains("device")) { + opts.device = torch::Device(std::string(py::str(dict["device"]))); + } + + return opts; +} + +py::dict AsDict(const MfccOptions &opts) { + py::dict dict; + + dict["frame_opts"] = AsDict(opts.frame_opts); + dict["mel_opts"] = AsDict(opts.mel_opts); + + AS_DICT(num_ceps); + AS_DICT(use_energy); + AS_DICT(energy_floor); + AS_DICT(raw_energy); + AS_DICT(cepstral_lifter); + AS_DICT(htk_compat); + + auto torch_device = py::module_::import("torch").attr("device"); + dict["device"] = torch_device(opts.device.str()); + + return dict; +} + +SpectrogramOptions SpectrogramOptionsFromDict(py::dict dict) { + SpectrogramOptions opts; + + if (dict.contains("frame_opts")) { + opts.frame_opts = FrameExtractionOptionsFromDict(dict["frame_opts"]); + } + + FROM_DICT(float_, energy_floor); + FROM_DICT(bool_, 
raw_energy); + // FROM_DICT(bool_, return_raw_fft); + + if (dict.contains("device")) { + opts.device = torch::Device(std::string(py::str(dict["device"]))); + } + + return opts; +} + +py::dict AsDict(const SpectrogramOptions &opts) { + py::dict dict; + + dict["frame_opts"] = AsDict(opts.frame_opts); + + AS_DICT(energy_floor); + AS_DICT(raw_energy); + + auto torch_device = py::module_::import("torch").attr("device"); + dict["device"] = torch_device(opts.device.str()); + + return dict; +} + +PlpOptions PlpOptionsFromDict(py::dict dict) { + PlpOptions opts; + + if (dict.contains("frame_opts")) { + opts.frame_opts = FrameExtractionOptionsFromDict(dict["frame_opts"]); + } + + if (dict.contains("mel_opts")) { + opts.mel_opts = MelBanksOptionsFromDict(dict["mel_opts"]); + } + + FROM_DICT(int_, lpc_order); + FROM_DICT(int_, num_ceps); + FROM_DICT(bool_, use_energy); + FROM_DICT(float_, energy_floor); + FROM_DICT(bool_, raw_energy); + FROM_DICT(float_, compress_factor); + FROM_DICT(int_, cepstral_lifter); + FROM_DICT(float_, cepstral_scale); + FROM_DICT(bool_, htk_compat); + + if (dict.contains("device")) { + opts.device = torch::Device(std::string(py::str(dict["device"]))); + } + + return opts; +} + +py::dict AsDict(const PlpOptions &opts) { + py::dict dict; + + dict["frame_opts"] = AsDict(opts.frame_opts); + dict["mel_opts"] = AsDict(opts.mel_opts); + + AS_DICT(lpc_order); + AS_DICT(num_ceps); + AS_DICT(use_energy); + AS_DICT(energy_floor); + AS_DICT(raw_energy); + AS_DICT(compress_factor); + AS_DICT(cepstral_lifter); + AS_DICT(cepstral_scale); + AS_DICT(htk_compat); + + auto torch_device = py::module_::import("torch").attr("device"); + dict["device"] = torch_device(opts.device.str()); + + return dict; +} + +#undef FROM_DICT +#undef AS_DICT + +} // namespace kaldifeat diff --git a/kaldifeat/python/csrc/utils.h b/kaldifeat/python/csrc/utils.h new file mode 100644 index 0000000..472d7c9 --- /dev/null +++ b/kaldifeat/python/csrc/utils.h @@ -0,0 +1,50 @@ +// 
kaldifeat/python/csrc/utils.h +// +// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +#ifndef KALDIFEAT_PYTHON_CSRC_UTILS_H_ +#define KALDIFEAT_PYTHON_CSRC_UTILS_H_ + +#include "kaldifeat/csrc/feature-fbank.h" +#include "kaldifeat/csrc/feature-mfcc.h" +#include "kaldifeat/csrc/feature-plp.h" +#include "kaldifeat/csrc/feature-spectrogram.h" +#include "kaldifeat/csrc/feature-window.h" +#include "kaldifeat/csrc/mel-computations.h" +#include "kaldifeat/python/csrc/kaldifeat.h" + +/* + * This file contains code about `from_dict` and + * `as_dict` for various options in kaldifeat. + * + * Regarding `from_dict`, users don't need to provide + * all the fields in the options. If some fields + * are not provided, it just uses the default one. + * + * If the provided dict in `from_dict` is empty, + * all fields use their default values. + */ + +namespace kaldifeat { + +FrameExtractionOptions FrameExtractionOptionsFromDict(py::dict dict); +py::dict AsDict(const FrameExtractionOptions &opts); + +MelBanksOptions MelBanksOptionsFromDict(py::dict dict); +py::dict AsDict(const MelBanksOptions &opts); + +FbankOptions FbankOptionsFromDict(py::dict dict); +py::dict AsDict(const FbankOptions &opts); + +MfccOptions MfccOptionsFromDict(py::dict dict); +py::dict AsDict(const MfccOptions &opts); + +SpectrogramOptions SpectrogramOptionsFromDict(py::dict dict); +py::dict AsDict(const SpectrogramOptions &opts); + +PlpOptions PlpOptionsFromDict(py::dict dict); +py::dict AsDict(const PlpOptions &opts); + +} // namespace kaldifeat + +#endif // KALDIFEAT_PYTHON_CSRC_UTILS_H_ diff --git a/kaldifeat/python/tests/CMakeLists.txt b/kaldifeat/python/tests/CMakeLists.txt index 9961b14..4ccc891 100644 --- a/kaldifeat/python/tests/CMakeLists.txt +++ b/kaldifeat/python/tests/CMakeLists.txt @@ -18,10 +18,15 @@ endfunction() # please sort the files in alphabetic order set(py_test_files test_fbank.py + test_fbank_options.py + test_frame_extraction_options.py + test_mel_bank_options.py 
test_mfcc.py + test_mfcc_options.py test_plp.py + test_plp_options.py test_spectrogram.py - test_options.py + test_spectrogram_options.py ) foreach(source IN LISTS py_test_files) diff --git a/kaldifeat/python/tests/test_fbank_options.py b/kaldifeat/python/tests/test_fbank_options.py new file mode 100755 index 0000000..fb00054 --- /dev/null +++ b/kaldifeat/python/tests/test_fbank_options.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + + +import torch + +import kaldifeat + + +def test_default(): + opts = kaldifeat.FbankOptions() + assert opts.frame_opts.samp_freq == 16000 + assert opts.frame_opts.frame_shift_ms == 10.0 + assert opts.frame_opts.frame_length_ms == 25.0 + assert opts.frame_opts.dither == 1.0 + assert abs(opts.frame_opts.preemph_coeff - 0.97) < 1e-6 + assert opts.frame_opts.remove_dc_offset is True + assert opts.frame_opts.window_type == "povey" + assert opts.frame_opts.round_to_power_of_two is True + assert abs(opts.frame_opts.blackman_coeff - 0.42) < 1e-6 + assert opts.frame_opts.snip_edges is True + + assert opts.mel_opts.num_bins == 23 + assert opts.mel_opts.low_freq == 20 + assert opts.mel_opts.high_freq == 0 + assert opts.mel_opts.vtln_low == 100 + assert opts.mel_opts.vtln_high == -500 + assert opts.mel_opts.debug_mel is False + assert opts.mel_opts.htk_mode is False + + assert opts.use_energy is False + assert opts.energy_floor == 0.0 + assert opts.raw_energy is True + assert opts.htk_compat is False + assert opts.use_log_fbank is True + assert opts.use_power is True + assert opts.device.type == "cpu" + + +def test_set_get(): + opts = kaldifeat.FbankOptions() + opts.use_energy = True + assert opts.use_energy is True + + opts.energy_floor = 1 + assert opts.energy_floor == 1 + + opts.raw_energy = False + assert opts.raw_energy is False + + opts.htk_compat = True + assert opts.htk_compat is True + + opts.use_log_fbank = False + assert opts.use_log_fbank is False + + opts.use_power = 
False + assert opts.use_power is False + + opts.device = torch.device("cuda", 1) + assert opts.device.type == "cuda" + assert opts.device.index == 1 + + +def test_set_get_frame_opts(): + opts = kaldifeat.FbankOptions() + + opts.frame_opts.samp_freq = 44100 + assert opts.frame_opts.samp_freq == 44100 + + opts.frame_opts.frame_shift_ms = 20.5 + assert opts.frame_opts.frame_shift_ms == 20.5 + + opts.frame_opts.frame_length_ms = 1 + assert opts.frame_opts.frame_length_ms == 1 + + opts.frame_opts.dither = 0.5 + assert opts.frame_opts.dither == 0.5 + + opts.frame_opts.preemph_coeff = 0.25 + assert opts.frame_opts.preemph_coeff == 0.25 + + opts.frame_opts.remove_dc_offset = False + assert opts.frame_opts.remove_dc_offset is False + + opts.frame_opts.window_type = "hanning" + assert opts.frame_opts.window_type == "hanning" + + opts.frame_opts.round_to_power_of_two = False + assert opts.frame_opts.round_to_power_of_two is False + + opts.frame_opts.blackman_coeff = 0.25 + assert opts.frame_opts.blackman_coeff == 0.25 + + opts.frame_opts.snip_edges = False + assert opts.frame_opts.snip_edges is False + + +def test_set_get_mel_opts(): + opts = kaldifeat.FbankOptions() + + opts.mel_opts.num_bins = 100 + assert opts.mel_opts.num_bins == 100 + + opts.mel_opts.low_freq = 22 + assert opts.mel_opts.low_freq == 22 + + opts.mel_opts.high_freq = 1 + assert opts.mel_opts.high_freq == 1 + + opts.mel_opts.vtln_low = 101 + assert opts.mel_opts.vtln_low == 101 + + opts.mel_opts.vtln_high = -100 + assert opts.mel_opts.vtln_high == -100 + + opts.mel_opts.debug_mel = True + assert opts.mel_opts.debug_mel is True + + opts.mel_opts.htk_mode = True + assert opts.mel_opts.htk_mode is True + + +def test_from_empty_dict(): + opts = kaldifeat.FbankOptions.from_dict({}) + opts2 = kaldifeat.FbankOptions() + + assert str(opts) == str(opts2) + + +def test_from_dict_partial(): + d = { + "energy_floor": 10.5, + "htk_compat": True, + "mel_opts": {"num_bins": 80, "vtln_low": 1}, + "frame_opts": 
{"window_type": "hanning"}, + } + opts = kaldifeat.FbankOptions.from_dict(d) + assert opts.energy_floor == 10.5 + assert opts.htk_compat is True + assert opts.mel_opts.num_bins == 80 + assert opts.mel_opts.vtln_low == 1 + assert opts.frame_opts.window_type == "hanning" + + mel_opts = kaldifeat.MelBanksOptions.from_dict(d["mel_opts"]) + assert str(opts.mel_opts) == str(mel_opts) + + +def test_from_dict_full_and_as_dict(): + opts = kaldifeat.FbankOptions() + opts.htk_compat = True + opts.mel_opts.num_bins = 80 + opts.frame_opts.samp_freq = 10 + + d = opts.as_dict() + assert d["htk_compat"] is True + assert d["mel_opts"]["num_bins"] == 80 + assert d["frame_opts"]["samp_freq"] == 10 + + mel_opts = kaldifeat.MelBanksOptions() + mel_opts.num_bins = 80 + assert d["mel_opts"] == mel_opts.as_dict() + + frame_opts = kaldifeat.FrameExtractionOptions() + frame_opts.samp_freq = 10 + assert d["frame_opts"] == frame_opts.as_dict() + + opts2 = kaldifeat.FbankOptions.from_dict(d) + assert str(opts2) == str(opts) + + d["htk_compat"] = False + d["device"] = torch.device("cuda", 2) + opts3 = kaldifeat.FbankOptions.from_dict(d) + assert opts3.htk_compat is False + assert opts3.device == torch.device("cuda", 2) + + +def main(): + test_default() + test_set_get() + test_set_get_frame_opts() + test_set_get_mel_opts() + test_from_empty_dict() + test_from_dict_partial() + test_from_dict_full_and_as_dict() + + +if __name__ == "__main__": + main() diff --git a/kaldifeat/python/tests/test_frame_extraction_options.py b/kaldifeat/python/tests/test_frame_extraction_options.py new file mode 100755 index 0000000..4fa90d9 --- /dev/null +++ b/kaldifeat/python/tests/test_frame_extraction_options.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +import kaldifeat + + +def test_default(): + opts = kaldifeat.FrameExtractionOptions() + assert opts.samp_freq == 16000 + assert opts.frame_shift_ms == 10.0 + assert opts.frame_length_ms == 25.0 + 
assert opts.dither == 1.0 + assert abs(opts.preemph_coeff - 0.97) < 1e-6 + assert opts.remove_dc_offset is True + assert opts.window_type == "povey" + assert opts.round_to_power_of_two is True + assert abs(opts.blackman_coeff - 0.42) < 1e-6 + assert opts.snip_edges is True + + +def test_set_get(): + opts = kaldifeat.FrameExtractionOptions() + opts.samp_freq = 44100 + assert opts.samp_freq == 44100 + + opts.frame_shift_ms = 20.5 + assert opts.frame_shift_ms == 20.5 + + opts.frame_length_ms = 1 + assert opts.frame_length_ms == 1 + + opts.dither = 0.5 + assert opts.dither == 0.5 + + opts.preemph_coeff = 0.25 + assert opts.preemph_coeff == 0.25 + + opts.remove_dc_offset = False + assert opts.remove_dc_offset is False + + opts.window_type = "hanning" + assert opts.window_type == "hanning" + + opts.round_to_power_of_two = False + assert opts.round_to_power_of_two is False + + opts.blackman_coeff = 0.25 + assert opts.blackman_coeff == 0.25 + + opts.snip_edges = False + assert opts.snip_edges is False + + +def test_from_empty_dict(): + opts = kaldifeat.FrameExtractionOptions.from_dict({}) + opts2 = kaldifeat.FrameExtractionOptions() + + assert str(opts) == str(opts2) + + +def test_from_dict_partial(): + d = {"samp_freq": 10, "frame_shift_ms": 2} + + opts = kaldifeat.FrameExtractionOptions.from_dict(d) + + opts2 = kaldifeat.FrameExtractionOptions() + assert str(opts) != str(opts2) + + opts2.samp_freq = 10 + assert str(opts) != str(opts2) + + opts2.frame_shift_ms = 2 + assert str(opts) == str(opts2) + + opts2.frame_shift_ms = 3 + assert str(opts) != str(opts2) + + +def test_from_dict_full_and_as_dict(): + opts = kaldifeat.FrameExtractionOptions() + opts.samp_freq = 20 + opts.frame_length_ms = 100 + + d = opts.as_dict() + for key, value in d.items(): + assert value == getattr(opts, key) + + opts2 = kaldifeat.FrameExtractionOptions.from_dict(d) + assert str(opts2) == str(opts) + + d["window_type"] = "hanning" + opts3 = kaldifeat.FrameExtractionOptions.from_dict(d) + assert 
opts3.window_type == "hanning" + + +def main(): + test_default() + test_set_get() + test_from_empty_dict() + test_from_dict_partial() + test_from_dict_full_and_as_dict() + + +if __name__ == "__main__": + main() diff --git a/kaldifeat/python/tests/test_mel_bank_options.py b/kaldifeat/python/tests/test_mel_bank_options.py new file mode 100755 index 0000000..bb2924f --- /dev/null +++ b/kaldifeat/python/tests/test_mel_bank_options.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +import kaldifeat + + +def test_default(): + opts = kaldifeat.MelBanksOptions() + assert opts.num_bins == 25 + assert opts.low_freq == 20 + assert opts.high_freq == 0 + assert opts.vtln_low == 100 + assert opts.vtln_high == -500 + assert opts.debug_mel is False + assert opts.htk_mode is False + + +def test_set_get(): + opts = kaldifeat.MelBanksOptions() + opts.num_bins = 100 + assert opts.num_bins == 100 + + opts.low_freq = 22 + assert opts.low_freq == 22 + + opts.high_freq = 1 + assert opts.high_freq == 1 + + opts.vtln_low = 101 + assert opts.vtln_low == 101 + + opts.vtln_high = -100 + assert opts.vtln_high == -100 + + opts.debug_mel = True + assert opts.debug_mel is True + + opts.htk_mode = True + assert opts.htk_mode is True + + +def test_from_empty_dict(): + opts = kaldifeat.MelBanksOptions.from_dict({}) + opts2 = kaldifeat.MelBanksOptions() + + assert str(opts) == str(opts2) + + +def test_from_dict_partial(): + d = {"num_bins": 10, "debug_mel": True} + + opts = kaldifeat.MelBanksOptions.from_dict(d) + + opts2 = kaldifeat.MelBanksOptions() + assert str(opts) != str(opts2) + + opts2.num_bins = 10 + assert str(opts) != str(opts2) + + opts2.debug_mel = True + assert str(opts) == str(opts2) + + opts2.debug_mel = False + assert str(opts) != str(opts2) + + +def test_from_dict_full_and_as_dict(): + opts = kaldifeat.MelBanksOptions() + opts.num_bins = 80 + opts.vtln_high = 2 + + d = opts.as_dict() + for key, value in d.items(): + 
assert value == getattr(opts, key) + + opts2 = kaldifeat.MelBanksOptions.from_dict(d) + assert str(opts2) == str(opts) + + d["htk_mode"] = True + opts3 = kaldifeat.MelBanksOptions.from_dict(d) + assert opts3.htk_mode is True + + +def main(): + test_default() + test_set_get() + test_from_empty_dict() + test_from_dict_partial() + test_from_dict_full_and_as_dict() + + +if __name__ == "__main__": + main() diff --git a/kaldifeat/python/tests/test_mfcc_options.py b/kaldifeat/python/tests/test_mfcc_options.py new file mode 100755 index 0000000..310ff93 --- /dev/null +++ b/kaldifeat/python/tests/test_mfcc_options.py @@ -0,0 +1,194 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + + +import torch + +import kaldifeat + + +def test_default(): + opts = kaldifeat.MfccOptions() + + assert opts.frame_opts.samp_freq == 16000 + assert opts.frame_opts.frame_shift_ms == 10.0 + assert opts.frame_opts.frame_length_ms == 25.0 + assert opts.frame_opts.dither == 1.0 + assert abs(opts.frame_opts.preemph_coeff - 0.97) < 1e-6 + assert opts.frame_opts.remove_dc_offset is True + assert opts.frame_opts.window_type == "povey" + assert opts.frame_opts.round_to_power_of_two is True + assert abs(opts.frame_opts.blackman_coeff - 0.42) < 1e-6 + assert opts.frame_opts.snip_edges is True + + assert opts.mel_opts.num_bins == 23 + assert opts.mel_opts.low_freq == 20 + assert opts.mel_opts.high_freq == 0 + assert opts.mel_opts.vtln_low == 100 + assert opts.mel_opts.vtln_high == -500 + assert opts.mel_opts.debug_mel is False + assert opts.mel_opts.htk_mode is False + + assert opts.num_ceps == 13 + assert opts.use_energy is True + assert opts.energy_floor == 0 + assert opts.raw_energy is True + assert opts.cepstral_lifter == 22.0 + assert opts.htk_compat is False + + assert opts.device.type == "cpu" + + +def test_set_get(): + opts = kaldifeat.MfccOptions() + opts.num_ceps = 22 + assert opts.num_ceps == 22 + + opts.use_energy = False + assert opts.use_energy is 
False + + opts.energy_floor = 1 + assert opts.energy_floor == 1 + + opts.raw_energy = False + assert opts.raw_energy is False + + opts.cepstral_lifter = 21 + assert opts.cepstral_lifter == 21 + + opts.htk_compat = True + assert opts.htk_compat is True + + opts.device = torch.device("cuda", 1) + assert opts.device.type == "cuda" + assert opts.device.index == 1 + + +def test_set_get_frame_opts(): + opts = kaldifeat.MfccOptions() + + opts.frame_opts.samp_freq = 44100 + assert opts.frame_opts.samp_freq == 44100 + + opts.frame_opts.frame_shift_ms = 20.5 + assert opts.frame_opts.frame_shift_ms == 20.5 + + opts.frame_opts.frame_length_ms = 1 + assert opts.frame_opts.frame_length_ms == 1 + + opts.frame_opts.dither = 0.5 + assert opts.frame_opts.dither == 0.5 + + opts.frame_opts.preemph_coeff = 0.25 + assert opts.frame_opts.preemph_coeff == 0.25 + + opts.frame_opts.remove_dc_offset = False + assert opts.frame_opts.remove_dc_offset is False + + opts.frame_opts.window_type = "hanning" + assert opts.frame_opts.window_type == "hanning" + + opts.frame_opts.round_to_power_of_two = False + assert opts.frame_opts.round_to_power_of_two is False + + opts.frame_opts.blackman_coeff = 0.25 + assert opts.frame_opts.blackman_coeff == 0.25 + + opts.frame_opts.snip_edges = False + assert opts.frame_opts.snip_edges is False + + +def test_set_get_mel_opts(): + opts = kaldifeat.MfccOptions() + + opts.mel_opts.num_bins = 100 + assert opts.mel_opts.num_bins == 100 + + opts.mel_opts.low_freq = 22 + assert opts.mel_opts.low_freq == 22 + + opts.mel_opts.high_freq = 1 + assert opts.mel_opts.high_freq == 1 + + opts.mel_opts.vtln_low = 101 + assert opts.mel_opts.vtln_low == 101 + + opts.mel_opts.vtln_high = -100 + assert opts.mel_opts.vtln_high == -100 + + opts.mel_opts.debug_mel = True + assert opts.mel_opts.debug_mel is True + + opts.mel_opts.htk_mode = True + assert opts.mel_opts.htk_mode is True + + +def test_from_empty_dict(): + opts = kaldifeat.MfccOptions.from_dict({}) + opts2 = 
kaldifeat.MfccOptions() + + assert str(opts) == str(opts2) + + +def test_from_dict_partial(): + d = { + "energy_floor": 10.5, + "htk_compat": True, + "mel_opts": {"num_bins": 80, "vtln_low": 1}, + "frame_opts": {"window_type": "hanning"}, + "device": "cuda:2", + } + opts = kaldifeat.MfccOptions.from_dict(d) + assert opts.energy_floor == 10.5 + assert opts.htk_compat is True + assert opts.device == torch.device("cuda", 2) + assert opts.mel_opts.num_bins == 80 + assert opts.mel_opts.vtln_low == 1 + assert opts.frame_opts.window_type == "hanning" + + mel_opts = kaldifeat.MelBanksOptions.from_dict(d["mel_opts"]) + assert str(opts.mel_opts) == str(mel_opts) + + +def test_from_dict_full_and_as_dict(): + opts = kaldifeat.MfccOptions() + opts.htk_compat = True + opts.mel_opts.num_bins = 80 + opts.frame_opts.samp_freq = 10 + + d = opts.as_dict() + assert d["htk_compat"] is True + assert d["mel_opts"]["num_bins"] == 80 + assert d["frame_opts"]["samp_freq"] == 10 + + mel_opts = kaldifeat.MelBanksOptions() + mel_opts.num_bins = 80 + assert d["mel_opts"] == mel_opts.as_dict() + + frame_opts = kaldifeat.FrameExtractionOptions() + frame_opts.samp_freq = 10 + assert d["frame_opts"] == frame_opts.as_dict() + + opts2 = kaldifeat.MfccOptions.from_dict(d) + assert str(opts2) == str(opts) + + d["htk_compat"] = False + d["device"] = torch.device("cuda", 10) + opts3 = kaldifeat.MfccOptions.from_dict(d) + assert opts3.htk_compat is False + assert opts3.device == torch.device("cuda", 10) + + +def main(): + test_default() + test_set_get() + test_set_get_frame_opts() + test_set_get_mel_opts() + test_from_empty_dict() + test_from_dict_partial() + test_from_dict_full_and_as_dict() + + +if __name__ == "__main__": + main() diff --git a/kaldifeat/python/tests/test_options.py b/kaldifeat/python/tests/test_options.py deleted file mode 100755 index e048629..0000000 --- a/kaldifeat/python/tests/test_options.py +++ /dev/null @@ -1,180 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright (c) 2021 Xiaomi 
Corporation (authors: Fangjun Kuang) - -import sys -from pathlib import Path - -cur_dir = Path(__file__).resolve().parent -kaldi_feat_dir = cur_dir.parent.parent.parent - - -import torch - -sys.path.insert(0, f"{kaldi_feat_dir}/build/lib") - -import kaldifeat - - -def test_frame_extraction_options(): - opts = kaldifeat.FrameExtractionOptions() - opts.samp_freq = 220500 - opts.frame_shift_ms = 15 - opts.frame_length_ms = 40 - opts.dither = 0.1 - opts.preemph_coeff = 0.98 - opts.remove_dc_offset = False - opts.window_type = "hanning" - opts.round_to_power_of_two = False - opts.blackman_coeff = 0.422 - opts.snip_edges = False - print(opts) - - -def test_mel_banks_options(): - opts = kaldifeat.MelBanksOptions() - opts.num_bins = 23 - opts.low_freq = 21 - opts.high_freq = 8000 - opts.vtln_low = 101 - opts.vtln_high = -501 - opts.debug_mel = True - opts.htk_mode = True - print(opts) - - -def test_fbank_options(): - opts = kaldifeat.FbankOptions() - frame_opts = opts.frame_opts - mel_opts = opts.mel_opts - - opts.energy_floor = 0 - opts.htk_compat = False - opts.raw_energy = True - opts.use_energy = False - opts.use_log_fbank = True - opts.use_power = True - opts.device = torch.device("cuda", 0) - - frame_opts.blackman_coeff = 0.42 - frame_opts.dither = 1 - frame_opts.frame_length_ms = 25 - frame_opts.frame_shift_ms = 10 - frame_opts.preemph_coeff = 0.97 - frame_opts.remove_dc_offset = True - frame_opts.round_to_power_of_two = True - frame_opts.samp_freq = 16000 - frame_opts.snip_edges = True - frame_opts.window_type = "povey" - - mel_opts.debug_mel = True - mel_opts.high_freq = 0 - mel_opts.low_freq = 20 - mel_opts.num_bins = 23 - mel_opts.vtln_high = -500 - mel_opts.vtln_low = 100 - - print(opts) - - -def test_mfcc_options(): - opts = kaldifeat.MfccOptions() - frame_opts = opts.frame_opts - mel_opts = opts.mel_opts - - opts.num_ceps = 10 - opts.use_energy = False - opts.energy_floor = 0.0 - opts.raw_energy = True - opts.cepstral_lifter = 22.0 - opts.htk_compat = False - 
opts.device = torch.device("cpu") - - frame_opts.blackman_coeff = 0.42 - frame_opts.dither = 1 - frame_opts.frame_length_ms = 25 - frame_opts.frame_shift_ms = 10 - frame_opts.preemph_coeff = 0.97 - frame_opts.remove_dc_offset = True - frame_opts.round_to_power_of_two = True - frame_opts.samp_freq = 16000 - frame_opts.snip_edges = True - frame_opts.window_type = "povey" - - mel_opts.debug_mel = True - mel_opts.high_freq = 0 - mel_opts.low_freq = 20 - mel_opts.num_bins = 23 - mel_opts.vtln_high = -500 - mel_opts.vtln_low = 100 - - print(opts) - - -def test_spectogram_options(): - opts = kaldifeat.SpectrogramOptions() - opts.energy_floor = 0.0 - opts.raw_energy = True - - frame_opts = opts.frame_opts - frame_opts.blackman_coeff = 0.42 - frame_opts.dither = 1 - frame_opts.frame_length_ms = 25 - frame_opts.frame_shift_ms = 10 - frame_opts.preemph_coeff = 0.97 - frame_opts.remove_dc_offset = True - frame_opts.round_to_power_of_two = True - frame_opts.samp_freq = 16000 - frame_opts.snip_edges = True - frame_opts.window_type = "povey" - - print(opts) - - -def test_plp_options(): - opts = kaldifeat.PlpOptions() - opts.lpc_order = 12 - opts.num_ceps = 13 - opts.use_energy = True - opts.energy_floor = 0.0 - opts.raw_energy = True - opts.compress_factor = 0.33333 - opts.cepstral_lifter = 22 - opts.cepstral_scale = 1.0 - opts.htk_compat = False - opts.device = torch.device("cpu") - - frame_opts = opts.frame_opts - frame_opts.blackman_coeff = 0.42 - frame_opts.dither = 1 - frame_opts.frame_length_ms = 25 - frame_opts.frame_shift_ms = 10 - frame_opts.preemph_coeff = 0.97 - frame_opts.remove_dc_offset = True - frame_opts.round_to_power_of_two = True - frame_opts.samp_freq = 16000 - frame_opts.snip_edges = True - frame_opts.window_type = "povey" - - mel_opts = opts.mel_opts - mel_opts.debug_mel = True - mel_opts.high_freq = 0 - mel_opts.low_freq = 20 - mel_opts.num_bins = 23 - mel_opts.vtln_high = -500 - mel_opts.vtln_low = 100 - - print(opts) - - -def main(): - 
test_frame_extraction_options() - test_mel_banks_options() - test_fbank_options() - test_mfcc_options() - test_spectogram_options() - test_plp_options() - - -if __name__ == "__main__": - main() diff --git a/kaldifeat/python/tests/test_plp_options.py b/kaldifeat/python/tests/test_plp_options.py new file mode 100755 index 0000000..dc87045 --- /dev/null +++ b/kaldifeat/python/tests/test_plp_options.py @@ -0,0 +1,205 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + + +import torch + +import kaldifeat + + +def test_default(): + opts = kaldifeat.PlpOptions() + assert opts.frame_opts.samp_freq == 16000 + assert opts.frame_opts.frame_shift_ms == 10.0 + assert opts.frame_opts.frame_length_ms == 25.0 + assert opts.frame_opts.dither == 1.0 + assert abs(opts.frame_opts.preemph_coeff - 0.97) < 1e-6 + assert opts.frame_opts.remove_dc_offset is True + assert opts.frame_opts.window_type == "povey" + assert opts.frame_opts.round_to_power_of_two is True + assert abs(opts.frame_opts.blackman_coeff - 0.42) < 1e-6 + assert opts.frame_opts.snip_edges is True + + assert opts.mel_opts.num_bins == 23 + assert opts.mel_opts.low_freq == 20 + assert opts.mel_opts.high_freq == 0 + assert opts.mel_opts.vtln_low == 100 + assert opts.mel_opts.vtln_high == -500 + assert opts.mel_opts.debug_mel is False + assert opts.mel_opts.htk_mode is False + + assert opts.lpc_order == 12 + assert opts.num_ceps == 13 + assert opts.use_energy is True + assert opts.energy_floor == 0.0 + assert opts.raw_energy is True + assert abs(opts.compress_factor - 0.33333) < 1e-6 + assert opts.cepstral_lifter == 22 + assert opts.cepstral_scale == 1.0 + assert opts.htk_compat is False + assert opts.device.type == "cpu" + + +def test_set_get(): + opts = kaldifeat.PlpOptions() + + opts.lpc_order = 11 + assert opts.lpc_order == 11 + + opts.num_ceps = 1 + assert opts.num_ceps == 1 + + opts.use_energy = False + assert opts.use_energy is False + + opts.energy_floor = 1 + assert 
opts.energy_floor == 1 + + opts.raw_energy = False + assert opts.raw_energy is False + + opts.compress_factor = 0.5 + assert opts.compress_factor == 0.5 + + opts.cepstral_lifter = 2 + assert opts.cepstral_lifter == 2 + + opts.cepstral_scale = 3 + assert opts.cepstral_scale == 3 + + opts.htk_compat = True + assert opts.htk_compat is True + + opts.device = "cuda:10" + assert opts.device == torch.device("cuda", 10) + + +def test_set_get_frame_opts(): + opts = kaldifeat.PlpOptions() + + opts.frame_opts.samp_freq = 44100 + assert opts.frame_opts.samp_freq == 44100 + + opts.frame_opts.frame_shift_ms = 20.5 + assert opts.frame_opts.frame_shift_ms == 20.5 + + opts.frame_opts.frame_length_ms = 1 + assert opts.frame_opts.frame_length_ms == 1 + + opts.frame_opts.dither = 0.5 + assert opts.frame_opts.dither == 0.5 + + opts.frame_opts.preemph_coeff = 0.25 + assert opts.frame_opts.preemph_coeff == 0.25 + + opts.frame_opts.remove_dc_offset = False + assert opts.frame_opts.remove_dc_offset is False + + opts.frame_opts.window_type = "hanning" + assert opts.frame_opts.window_type == "hanning" + + opts.frame_opts.round_to_power_of_two = False + assert opts.frame_opts.round_to_power_of_two is False + + opts.frame_opts.blackman_coeff = 0.25 + assert opts.frame_opts.blackman_coeff == 0.25 + + opts.frame_opts.snip_edges = False + assert opts.frame_opts.snip_edges is False + + +def test_set_get_mel_opts(): + opts = kaldifeat.PlpOptions() + + opts.mel_opts.num_bins = 100 + assert opts.mel_opts.num_bins == 100 + + opts.mel_opts.low_freq = 22 + assert opts.mel_opts.low_freq == 22 + + opts.mel_opts.high_freq = 1 + assert opts.mel_opts.high_freq == 1 + + opts.mel_opts.vtln_low = 101 + assert opts.mel_opts.vtln_low == 101 + + opts.mel_opts.vtln_high = -100 + assert opts.mel_opts.vtln_high == -100 + + opts.mel_opts.debug_mel = True + assert opts.mel_opts.debug_mel is True + + opts.mel_opts.htk_mode = True + assert opts.mel_opts.htk_mode is True + + +def test_from_empty_dict(): + opts = 
kaldifeat.PlpOptions.from_dict({}) + opts2 = kaldifeat.PlpOptions() + + assert str(opts) == str(opts2) + + +def test_from_dict_partial(): + d = { + "energy_floor": 10.5, + "htk_compat": True, + "mel_opts": {"num_bins": 80, "vtln_low": 1}, + "frame_opts": {"window_type": "hanning"}, + } + opts = kaldifeat.PlpOptions.from_dict(d) + assert opts.energy_floor == 10.5 + assert opts.htk_compat is True + assert opts.mel_opts.num_bins == 80 + assert opts.mel_opts.vtln_low == 1 + assert opts.frame_opts.window_type == "hanning" + + mel_opts = kaldifeat.MelBanksOptions.from_dict(d["mel_opts"]) + assert str(opts.mel_opts) == str(mel_opts) + + frame_opts = kaldifeat.FrameExtractionOptions.from_dict(d["frame_opts"]) + assert str(opts.frame_opts) == str(frame_opts) + + +def test_from_dict_full_and_as_dict(): + opts = kaldifeat.PlpOptions() + opts.htk_compat = True + opts.mel_opts.num_bins = 80 + opts.frame_opts.samp_freq = 10 + + d = opts.as_dict() + assert d["htk_compat"] is True + assert d["mel_opts"]["num_bins"] == 80 + assert d["frame_opts"]["samp_freq"] == 10 + + mel_opts = kaldifeat.MelBanksOptions() + mel_opts.num_bins = 80 + assert d["mel_opts"] == mel_opts.as_dict() + + frame_opts = kaldifeat.FrameExtractionOptions() + frame_opts.samp_freq = 10 + assert d["frame_opts"] == frame_opts.as_dict() + + opts2 = kaldifeat.PlpOptions.from_dict(d) + assert str(opts2) == str(opts) + + d["htk_compat"] = False + d["device"] = torch.device("cuda", 2) + opts3 = kaldifeat.PlpOptions.from_dict(d) + assert opts3.htk_compat is False + assert opts3.device == torch.device("cuda", 2) + + +def main(): + test_default() + test_set_get() + test_set_get_frame_opts() + test_set_get_mel_opts() + test_from_empty_dict() + test_from_dict_partial() + test_from_dict_full_and_as_dict() + + +if __name__ == "__main__": + main() diff --git a/kaldifeat/python/tests/test_spectrogram_options.py b/kaldifeat/python/tests/test_spectrogram_options.py new file mode 100755 index 0000000..d830300 --- /dev/null +++ 
b/kaldifeat/python/tests/test_spectrogram_options.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + + +import torch + +import kaldifeat + + +def test_default(): + opts = kaldifeat.SpectrogramOptions() + + assert opts.frame_opts.samp_freq == 16000 + assert opts.frame_opts.frame_shift_ms == 10.0 + assert opts.frame_opts.frame_length_ms == 25.0 + assert opts.frame_opts.dither == 1.0 + assert abs(opts.frame_opts.preemph_coeff - 0.97) < 1e-6 + assert opts.frame_opts.remove_dc_offset is True + assert opts.frame_opts.window_type == "povey" + assert opts.frame_opts.round_to_power_of_two is True + assert abs(opts.frame_opts.blackman_coeff - 0.42) < 1e-6 + assert opts.frame_opts.snip_edges is True + + assert opts.energy_floor == 0 + assert opts.raw_energy is True + assert opts.device.type == "cpu" + + +def test_set_get(): + opts = kaldifeat.SpectrogramOptions() + + opts.energy_floor = 1 + assert opts.energy_floor == 1 + + opts.raw_energy = False + assert opts.raw_energy is False + + opts.device = torch.device("cuda", 1) + assert opts.device.type == "cuda" + assert opts.device.index == 1 + + +def test_set_get_frame_opts(): + opts = kaldifeat.SpectrogramOptions() + + opts.frame_opts.samp_freq = 44100 + assert opts.frame_opts.samp_freq == 44100 + + opts.frame_opts.frame_shift_ms = 20.5 + assert opts.frame_opts.frame_shift_ms == 20.5 + + opts.frame_opts.frame_length_ms = 1 + assert opts.frame_opts.frame_length_ms == 1 + + opts.frame_opts.dither = 0.5 + assert opts.frame_opts.dither == 0.5 + + opts.frame_opts.preemph_coeff = 0.25 + assert opts.frame_opts.preemph_coeff == 0.25 + + opts.frame_opts.remove_dc_offset = False + assert opts.frame_opts.remove_dc_offset is False + + opts.frame_opts.window_type = "hanning" + assert opts.frame_opts.window_type == "hanning" + + opts.frame_opts.round_to_power_of_two = False + assert opts.frame_opts.round_to_power_of_two is False + + opts.frame_opts.blackman_coeff = 0.25 + 
assert opts.frame_opts.blackman_coeff == 0.25 + + opts.frame_opts.snip_edges = False + assert opts.frame_opts.snip_edges is False + + +def test_from_empty_dict(): + opts = kaldifeat.SpectrogramOptions.from_dict({}) + opts2 = kaldifeat.SpectrogramOptions() + + assert str(opts) == str(opts2) + + +def test_from_dict_partial(): + d = { + "energy_floor": 10.5, + "frame_opts": {"window_type": "hanning"}, + "device": "cuda:2", + } + opts = kaldifeat.SpectrogramOptions.from_dict(d) + assert opts.energy_floor == 10.5 + assert opts.device == torch.device("cuda", 2) + assert opts.frame_opts.window_type == "hanning" + + frame_opts = kaldifeat.FrameExtractionOptions.from_dict(d["frame_opts"]) + assert str(opts.frame_opts) == str(frame_opts) + + +def test_from_dict_full_and_as_dict(): + opts = kaldifeat.SpectrogramOptions() + opts.frame_opts.samp_freq = 12 + opts.device = "cuda:3" + + d = opts.as_dict() + assert d["frame_opts"]["samp_freq"] == 12 + assert d["device"] == torch.device("cuda:3") + + frame_opts = kaldifeat.FrameExtractionOptions() + frame_opts.samp_freq = 12 + assert d["frame_opts"] == frame_opts.as_dict() + + opts2 = kaldifeat.SpectrogramOptions.from_dict(d) + assert str(opts2) == str(opts) + + d["device"] = torch.device("cuda", 10) + opts3 = kaldifeat.SpectrogramOptions.from_dict(d) + assert opts3.device == torch.device("cuda", 10) + + opts.device = "cuda:10" + assert str(opts3) == str(opts) + + +def main(): + test_default() + test_set_get() + test_set_get_frame_opts() + test_from_empty_dict() + test_from_dict_partial() + test_from_dict_full_and_as_dict() + + +if __name__ == "__main__": + main()