From 518846a74e866e5cb37e5866600385930bfa6fa2 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 15 Oct 2021 17:36:44 +0800 Subject: [PATCH] Add SpectrogramOptions. --- kaldifeat/python/csrc/feature-spectrogram.cc | 8 +- kaldifeat/python/csrc/utils.cc | 33 +++++ kaldifeat/python/csrc/utils.h | 4 + kaldifeat/python/tests/CMakeLists.txt | 1 + .../python/tests/test_spectrogram_options.py | 134 ++++++++++++++++++ 5 files changed, 179 insertions(+), 1 deletion(-) create mode 100755 kaldifeat/python/tests/test_spectrogram_options.py diff --git a/kaldifeat/python/csrc/feature-spectrogram.cc b/kaldifeat/python/csrc/feature-spectrogram.cc index ef39338..f752ebe 100644 --- a/kaldifeat/python/csrc/feature-spectrogram.cc +++ b/kaldifeat/python/csrc/feature-spectrogram.cc @@ -7,6 +7,7 @@ #include #include "kaldifeat/csrc/feature-spectrogram.h" +#include "kaldifeat/python/csrc/utils.h" namespace kaldifeat { @@ -30,7 +31,12 @@ static void PybindSpectrogramOptions(py::module &m) { self.device = torch::Device(s); }) .def("__str__", - [](const PyClass &self) -> std::string { return self.ToString(); }); + [](const PyClass &self) -> std::string { return self.ToString(); }) + .def("as_dict", + [](const PyClass &self) -> py::dict { return AsDict(self); }) + .def_static("from_dict", [](py::dict dict) -> PyClass { + return SpectrogramOptionsFromDict(dict); + }); } static void PybindSpectrogram(py::module &m) { diff --git a/kaldifeat/python/csrc/utils.cc b/kaldifeat/python/csrc/utils.cc index 85f9787..19d96dd 100644 --- a/kaldifeat/python/csrc/utils.cc +++ b/kaldifeat/python/csrc/utils.cc @@ -149,6 +149,7 @@ py::dict AsDict(const MfccOptions &opts) { dict["frame_opts"] = AsDict(opts.frame_opts); dict["mel_opts"] = AsDict(opts.mel_opts); + AS_DICT(num_ceps); AS_DICT(use_energy); AS_DICT(energy_floor); @@ -162,6 +163,38 @@ py::dict AsDict(const MfccOptions &opts) { return dict; } +SpectrogramOptions SpectrogramOptionsFromDict(py::dict dict) { + SpectrogramOptions opts; + + if (dict.contains("frame_opts")) { + opts.frame_opts = FrameExtractionOptionsFromDict(dict["frame_opts"]); + } + + FROM_DICT(float_, energy_floor); + FROM_DICT(bool_, raw_energy); + // FROM_DICT(bool_, return_raw_fft); + + if (dict.contains("device")) { + opts.device = torch::Device(std::string(py::str(dict["device"]))); + } + + return opts; +} + +py::dict AsDict(const SpectrogramOptions &opts) { + py::dict dict; + + dict["frame_opts"] = AsDict(opts.frame_opts); + + AS_DICT(energy_floor); + AS_DICT(raw_energy); + + auto torch_device = py::module_::import("torch").attr("device"); + dict["device"] = torch_device(opts.device.str()); + + return dict; +} + #undef FROM_DICT #undef AS_DICT diff --git a/kaldifeat/python/csrc/utils.h b/kaldifeat/python/csrc/utils.h index 5ddc0d7..6a9f271 100644 --- a/kaldifeat/python/csrc/utils.h +++ b/kaldifeat/python/csrc/utils.h @@ -7,6 +7,7 @@ #include "kaldifeat/csrc/feature-fbank.h" #include "kaldifeat/csrc/feature-mfcc.h" +#include "kaldifeat/csrc/feature-spectrogram.h" #include "kaldifeat/csrc/feature-window.h" #include "kaldifeat/csrc/mel-computations.h" #include "kaldifeat/python/csrc/kaldifeat.h" @@ -37,6 +38,9 @@ py::dict AsDict(const FbankOptions &opts); MfccOptions MfccOptionsFromDict(py::dict dict); py::dict AsDict(const MfccOptions &opts); +SpectrogramOptions SpectrogramOptionsFromDict(py::dict dict); +py::dict AsDict(const SpectrogramOptions &opts); + } // namespace kaldifeat #endif // KALDIFEAT_PYTHON_CSRC_UTILS_H_ diff --git a/kaldifeat/python/tests/CMakeLists.txt b/kaldifeat/python/tests/CMakeLists.txt index b5372c1..8a1133e 100644 --- a/kaldifeat/python/tests/CMakeLists.txt +++ b/kaldifeat/python/tests/CMakeLists.txt @@ -26,6 +26,7 @@ set(py_test_files test_options.py test_plp.py test_spectrogram.py + test_spectrogram_options.py ) foreach(source IN LISTS py_test_files) diff --git a/kaldifeat/python/tests/test_spectrogram_options.py b/kaldifeat/python/tests/test_spectrogram_options.py new file mode 100755 index 0000000..d830300 --- /dev/null +++ b/kaldifeat/python/tests/test_spectrogram_options.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + + +import torch + +import kaldifeat + + +def test_default(): + opts = kaldifeat.SpectrogramOptions() + + assert opts.frame_opts.samp_freq == 16000 + assert opts.frame_opts.frame_shift_ms == 10.0 + assert opts.frame_opts.frame_length_ms == 25.0 + assert opts.frame_opts.dither == 1.0 + assert abs(opts.frame_opts.preemph_coeff - 0.97) < 1e-6 + assert opts.frame_opts.remove_dc_offset is True + assert opts.frame_opts.window_type == "povey" + assert opts.frame_opts.round_to_power_of_two is True + assert abs(opts.frame_opts.blackman_coeff - 0.42) < 1e-6 + assert opts.frame_opts.snip_edges is True + + assert opts.energy_floor == 0 + assert opts.raw_energy is True + assert opts.device.type == "cpu" + + +def test_set_get(): + opts = kaldifeat.SpectrogramOptions() + + opts.energy_floor = 1 + assert opts.energy_floor == 1 + + opts.raw_energy = False + assert opts.raw_energy is False + + opts.device = torch.device("cuda", 1) + assert opts.device.type == "cuda" + assert opts.device.index == 1 + + +def test_set_get_frame_opts(): + opts = kaldifeat.SpectrogramOptions() + + opts.frame_opts.samp_freq = 44100 + assert opts.frame_opts.samp_freq == 44100 + + opts.frame_opts.frame_shift_ms = 20.5 + assert opts.frame_opts.frame_shift_ms == 20.5 + + opts.frame_opts.frame_length_ms = 1 + assert opts.frame_opts.frame_length_ms == 1 + + opts.frame_opts.dither = 0.5 + assert opts.frame_opts.dither == 0.5 + + opts.frame_opts.preemph_coeff = 0.25 + assert opts.frame_opts.preemph_coeff == 0.25 + + opts.frame_opts.remove_dc_offset = False + assert opts.frame_opts.remove_dc_offset is False + + opts.frame_opts.window_type = "hanning" + assert opts.frame_opts.window_type == "hanning" + + opts.frame_opts.round_to_power_of_two = False + assert opts.frame_opts.round_to_power_of_two is False + + opts.frame_opts.blackman_coeff = 0.25 + assert opts.frame_opts.blackman_coeff == 0.25 + + opts.frame_opts.snip_edges = False + assert opts.frame_opts.snip_edges is False + + +def test_from_empty_dict(): + opts = kaldifeat.SpectrogramOptions.from_dict({}) + opts2 = kaldifeat.SpectrogramOptions() + + assert str(opts) == str(opts2) + + +def test_from_dict_partial(): + d = { + "energy_floor": 10.5, + "frame_opts": {"window_type": "hanning"}, + "device": "cuda:2", + } + opts = kaldifeat.SpectrogramOptions.from_dict(d) + assert opts.energy_floor == 10.5 + assert opts.device == torch.device("cuda", 2) + assert opts.frame_opts.window_type == "hanning" + + frame_opts = kaldifeat.FrameExtractionOptions.from_dict(d["frame_opts"]) + assert str(opts.frame_opts) == str(frame_opts) + + +def test_from_dict_full_and_as_dict(): + opts = kaldifeat.SpectrogramOptions() + opts.frame_opts.samp_freq = 12 + opts.device = "cuda:3" + + d = opts.as_dict() + assert d["frame_opts"]["samp_freq"] == 12 + assert d["device"] == torch.device("cuda:3") + + frame_opts = kaldifeat.FrameExtractionOptions() + frame_opts.samp_freq = 12 + assert d["frame_opts"] == frame_opts.as_dict() + + opts2 = kaldifeat.SpectrogramOptions.from_dict(d) + assert str(opts2) == str(opts) + + d["device"] = torch.device("cuda", 10) + opts3 = kaldifeat.SpectrogramOptions.from_dict(d) + assert opts3.device == torch.device("cuda", 10) + + opts.device = "cuda:10" + assert str(opts3) == str(opts) + + +def main(): + test_default() + test_set_get() + test_set_get_frame_opts() + test_from_empty_dict() + test_from_dict_partial() + test_from_dict_full_and_as_dict() + + +if __name__ == "__main__": + main()