From be45ac0ae6d7fa1a75b2d09a9ace4590202f8c01 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 15 Oct 2021 14:14:01 +0800 Subject: [PATCH] Support FbankOptions. --- kaldifeat/python/csrc/feature-fbank.cc | 8 +- kaldifeat/python/csrc/utils.cc | 44 +++++ kaldifeat/python/csrc/utils.h | 4 + kaldifeat/python/tests/CMakeLists.txt | 9 +- kaldifeat/python/tests/test_fbank_options.py | 189 +++++++++++++++++++ kaldifeat/python/tests/test_options.py | 35 ---- 6 files changed, 249 insertions(+), 40 deletions(-) create mode 100755 kaldifeat/python/tests/test_fbank_options.py diff --git a/kaldifeat/python/csrc/feature-fbank.cc b/kaldifeat/python/csrc/feature-fbank.cc index 1cca393..8d8e8f0 100644 --- a/kaldifeat/python/csrc/feature-fbank.cc +++ b/kaldifeat/python/csrc/feature-fbank.cc @@ -7,6 +7,7 @@ #include #include "kaldifeat/csrc/feature-fbank.h" +#include "kaldifeat/python/csrc/utils.h" namespace kaldifeat { @@ -33,7 +34,12 @@ static void PybindFbankOptions(py::module &m) { self.device = torch::Device(s); }) .def("__str__", - [](const PyClass &self) -> std::string { return self.ToString(); }); + [](const PyClass &self) -> std::string { return self.ToString(); }) + .def("as_dict", + [](const PyClass &self) -> py::dict { return AsDict(self); }) + .def_static("from_dict", [](py::dict dict) -> PyClass { + return FbankOptionsFromDict(dict); + }); } static void PybindFbank(py::module &m) { diff --git a/kaldifeat/python/csrc/utils.cc b/kaldifeat/python/csrc/utils.cc index b298e61..81439b1 100644 --- a/kaldifeat/python/csrc/utils.cc +++ b/kaldifeat/python/csrc/utils.cc @@ -45,6 +45,7 @@ py::dict AsDict(const FrameExtractionOptions &opts) { AS_DICT(round_to_power_of_two); AS_DICT(blackman_coeff); AS_DICT(snip_edges); + return dict; } @@ -75,6 +76,49 @@ py::dict AsDict(const MelBanksOptions &opts) { return dict; } +FbankOptions FbankOptionsFromDict(py::dict dict) { + FbankOptions opts; + + if (dict.contains("frame_opts")) { + opts.frame_opts = FrameExtractionOptionsFromDict(dict["frame_opts"]); + } + + if (dict.contains("mel_opts")) { + opts.mel_opts = MelBanksOptionsFromDict(dict["mel_opts"]); + } + + FROM_DICT(bool_, use_energy); + FROM_DICT(float_, energy_floor); + FROM_DICT(bool_, raw_energy); + FROM_DICT(bool_, htk_compat); + FROM_DICT(bool_, use_log_fbank); + FROM_DICT(bool_, use_power); + + if (dict.contains("device")) { + opts.device = torch::Device(std::string(py::str(dict["device"]))); + } + + return opts; +} + +py::dict AsDict(const FbankOptions &opts) { + py::dict dict; + + dict["frame_opts"] = AsDict(opts.frame_opts); + dict["mel_opts"] = AsDict(opts.mel_opts); + AS_DICT(use_energy); + AS_DICT(energy_floor); + AS_DICT(raw_energy); + AS_DICT(htk_compat); + AS_DICT(use_log_fbank); + AS_DICT(use_power); + + auto torch_device = py::module_::import("torch").attr("device"); + dict["device"] = torch_device(opts.device.str()); + + return dict; +} + #undef FROM_DICT #undef AS_DICT diff --git a/kaldifeat/python/csrc/utils.h b/kaldifeat/python/csrc/utils.h index 5b39402..b683e63 100644 --- a/kaldifeat/python/csrc/utils.h +++ b/kaldifeat/python/csrc/utils.h @@ -5,6 +5,7 @@ #ifndef KALDIFEAT_PYTHON_CSRC_UTILS_H_ #define KALDIFEAT_PYTHON_CSRC_UTILS_H_ +#include "kaldifeat/csrc/feature-fbank.h" #include "kaldifeat/csrc/feature-window.h" #include "kaldifeat/csrc/mel-computations.h" #include "kaldifeat/python/csrc/kaldifeat.h" @@ -29,6 +30,9 @@ py::dict AsDict(const FrameExtractionOptions &opts); MelBanksOptions MelBanksOptionsFromDict(py::dict dict); py::dict AsDict(const MelBanksOptions &opts); +FbankOptions FbankOptionsFromDict(py::dict dict); +py::dict AsDict(const FbankOptions &opts); + } // namespace kaldifeat #endif // KALDIFEAT_PYTHON_CSRC_UTILS_H_ diff --git a/kaldifeat/python/tests/CMakeLists.txt b/kaldifeat/python/tests/CMakeLists.txt index 93e81ae..02bad90 100644 --- a/kaldifeat/python/tests/CMakeLists.txt +++ b/kaldifeat/python/tests/CMakeLists.txt @@ -18,12 +18,13 @@ endfunction() # please sort the files in alphabetic order set(py_test_files test_fbank.py - test_mfcc.py - test_plp.py - test_spectrogram.py - test_options.py + test_fbank_options.py test_frame_extraction_options.py test_mel_bank_options.py + test_mfcc.py + test_options.py + test_plp.py + test_spectrogram.py ) foreach(source IN LISTS py_test_files) diff --git a/kaldifeat/python/tests/test_fbank_options.py b/kaldifeat/python/tests/test_fbank_options.py new file mode 100755 index 0000000..562bcac --- /dev/null +++ b/kaldifeat/python/tests/test_fbank_options.py @@ -0,0 +1,189 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + + +import torch + +import kaldifeat + + +def test_default(): + opts = kaldifeat.FbankOptions() + assert opts.frame_opts.samp_freq == 16000 + assert opts.frame_opts.frame_shift_ms == 10.0 + assert opts.frame_opts.frame_length_ms == 25.0 + assert opts.frame_opts.dither == 1.0 + assert abs(opts.frame_opts.preemph_coeff - 0.97) < 1e-6 + assert opts.frame_opts.remove_dc_offset is True + assert opts.frame_opts.window_type == "povey" + assert opts.frame_opts.round_to_power_of_two is True + assert abs(opts.frame_opts.blackman_coeff - 0.42) < 1e-6 + assert opts.frame_opts.snip_edges is True + + assert opts.mel_opts.num_bins == 23 + assert opts.mel_opts.low_freq == 20 + assert opts.mel_opts.high_freq == 0 + assert opts.mel_opts.vtln_low == 100 + assert opts.mel_opts.vtln_high == -500 + assert opts.mel_opts.debug_mel is False + assert opts.mel_opts.htk_mode is False + + assert opts.use_energy is False + assert opts.energy_floor == 0.0 + assert opts.raw_energy is True + assert opts.htk_compat is False + assert opts.use_log_fbank is True + assert opts.use_power is True + assert opts.device.type == "cpu" + + +def test_set_get(): + opts = kaldifeat.FbankOptions() + opts.use_energy = True + assert opts.use_energy is True + + opts.energy_floor = 1 + assert opts.energy_floor == 1 + + opts.raw_energy = False + assert opts.raw_energy is False + + opts.htk_compat = True + assert opts.htk_compat is True + + opts.use_log_fbank = False + assert opts.use_log_fbank is False + + opts.use_power = False + assert opts.use_power is False + + opts.device = torch.device("cuda", 1) + assert opts.device.type == "cuda" + assert opts.device.index == 1 + + +def test_set_get_frame_opts(): + opts = kaldifeat.FbankOptions() + + opts.frame_opts.samp_freq = 44100 + assert opts.frame_opts.samp_freq == 44100 + + opts.frame_opts.frame_shift_ms = 20.5 + assert opts.frame_opts.frame_shift_ms == 20.5 + + opts.frame_opts.frame_length_ms = 1 + assert opts.frame_opts.frame_length_ms == 1 + + opts.frame_opts.dither = 0.5 + assert opts.frame_opts.dither == 0.5 + + opts.frame_opts.preemph_coeff = 0.25 + assert opts.frame_opts.preemph_coeff == 0.25 + + opts.frame_opts.remove_dc_offset = False + assert opts.frame_opts.remove_dc_offset is False + + opts.frame_opts.window_type = "hanning" + assert opts.frame_opts.window_type == "hanning" + + opts.frame_opts.round_to_power_of_two = False + assert opts.frame_opts.round_to_power_of_two is False + + opts.frame_opts.blackman_coeff = 0.25 + assert opts.frame_opts.blackman_coeff == 0.25 + + opts.frame_opts.snip_edges = False + assert opts.frame_opts.snip_edges is False + + +def test_set_get_mel_opts(): + opts = kaldifeat.FbankOptions() + + opts.mel_opts.num_bins = 100 + assert opts.mel_opts.num_bins == 100 + + opts.mel_opts.low_freq = 22 + assert opts.mel_opts.low_freq == 22 + + opts.mel_opts.high_freq = 1 + assert opts.mel_opts.high_freq == 1 + + opts.mel_opts.vtln_low = 101 + assert opts.mel_opts.vtln_low == 101 + + opts.mel_opts.vtln_high = -100 + assert opts.mel_opts.vtln_high == -100 + + opts.mel_opts.debug_mel = True + assert opts.mel_opts.debug_mel is True + + opts.mel_opts.htk_mode = True + assert opts.mel_opts.htk_mode is True + + +def test_from_empty_dict(): + opts = kaldifeat.FbankOptions.from_dict({}) + opts2 = kaldifeat.FbankOptions() + + assert str(opts) == str(opts2) + + +def test_from_dict_partial(): + d = { + "energy_floor": 10.5, + "htk_compat": True, + "mel_opts": {"num_bins": 80, "vtln_low": 1}, + "frame_opts": {"window_type": "hanning"}, + } + opts = kaldifeat.FbankOptions.from_dict(d) + assert opts.energy_floor == 10.5 + assert opts.htk_compat is True + assert opts.mel_opts.num_bins == 80 + assert opts.mel_opts.vtln_low == 1 + assert opts.frame_opts.window_type == "hanning" + + opts = kaldifeat.MelBanksOptions.from_dict(d) + + +def test_from_dict_full_and_as_dict(): + opts = kaldifeat.FbankOptions() + opts.htk_compat = True + opts.mel_opts.num_bins = 80 + opts.frame_opts.samp_freq = 10 + + d = opts.as_dict() + assert d["htk_compat"] is True + assert d["mel_opts"]["num_bins"] == 80 + assert d["frame_opts"]["samp_freq"] == 10 + + mel_opts = kaldifeat.MelBanksOptions() + mel_opts.num_bins = 80 + assert d["mel_opts"] == mel_opts.as_dict() + + frame_opts = kaldifeat.FrameExtractionOptions() + frame_opts.samp_freq = 10 + assert d["frame_opts"] == frame_opts.as_dict() + + opts2 = kaldifeat.FbankOptions.from_dict(d) + assert str(opts2) == str(opts) + + d["htk_compat"] = False + d["device"] = torch.device("cuda", 2) + opts3 = kaldifeat.FbankOptions.from_dict(d) + assert opts3.htk_compat is False + assert opts3.device == torch.device("cuda", 2) + + +def main(): + test_default() + test_set_get() + test_set_get_frame_opts() + test_set_get_mel_opts() + test_from_empty_dict() + test_from_dict_partial() + test_from_dict_full_and_as_dict() + + +if __name__ == "__main__": + main() diff --git a/kaldifeat/python/tests/test_options.py b/kaldifeat/python/tests/test_options.py index 92375f5..2016f3f 100755 --- a/kaldifeat/python/tests/test_options.py +++ b/kaldifeat/python/tests/test_options.py @@ -16,40 +16,6 @@ sys.path.insert(0, f"{kaldi_feat_dir}/build/lib") import kaldifeat -def test_fbank_options(): - opts = kaldifeat.FbankOptions() - frame_opts = opts.frame_opts - mel_opts = opts.mel_opts - - opts.energy_floor = 0 - opts.htk_compat = False - opts.raw_energy = True - opts.use_energy = False - opts.use_log_fbank = True - opts.use_power = True - opts.device = torch.device("cuda", 0) - - frame_opts.blackman_coeff = 0.42 - frame_opts.dither = 1 - frame_opts.frame_length_ms = 25 - frame_opts.frame_shift_ms = 10 - frame_opts.preemph_coeff = 0.97 - frame_opts.remove_dc_offset = True - frame_opts.round_to_power_of_two = True - frame_opts.samp_freq = 16000 - frame_opts.snip_edges = True - frame_opts.window_type = "povey" - - mel_opts.debug_mel = True - mel_opts.high_freq = 0 - mel_opts.low_freq = 20 - mel_opts.num_bins = 23 - mel_opts.vtln_high = -500 - mel_opts.vtln_low = 100 - - print(opts) - - def test_mfcc_options(): opts = kaldifeat.MfccOptions() frame_opts = opts.frame_opts @@ -141,7 +107,6 @@ def test_plp_options(): def main(): - test_fbank_options() test_mfcc_options() test_spectogram_options() test_plp_options()