Support MelBankOptions.

This commit is contained in:
Fangjun Kuang 2021-10-15 13:07:29 +08:00
parent 876db53746
commit d2e093d2be
6 changed files with 148 additions and 14 deletions

View File

@ -7,6 +7,7 @@
#include <string>
#include "kaldifeat/csrc/mel-computations.h"
#include "kaldifeat/python/csrc/utils.h"
namespace kaldifeat {
@ -22,7 +23,12 @@ static void PybindMelBanksOptions(py::module &m) {
.def_readwrite("debug_mel", &PyClass::debug_mel)
.def_readwrite("htk_mode", &PyClass::htk_mode)
.def("__str__",
[](const PyClass &self) -> std::string { return self.ToString(); });
[](const PyClass &self) -> std::string { return self.ToString(); })
.def("as_dict",
[](const PyClass &self) -> py::dict { return AsDict(self); })
.def_static("from_dict", [](py::dict dict) -> PyClass {
return MelBanksOptionsFromDict(dict);
});
}
void PybindMelComputations(py::module &m) { PybindMelBanksOptions(m); }

View File

@ -17,6 +17,7 @@ namespace kaldifeat {
FrameExtractionOptions FrameExtractionOptionsFromDict(py::dict dict) {
FrameExtractionOptions opts;
FROM_DICT(float_, samp_freq);
FROM_DICT(float_, frame_shift_ms);
FROM_DICT(float_, frame_length_ms);
@ -27,11 +28,13 @@ FrameExtractionOptions FrameExtractionOptionsFromDict(py::dict dict) {
FROM_DICT(bool_, round_to_power_of_two);
FROM_DICT(float_, blackman_coeff);
FROM_DICT(bool_, snip_edges);
return opts;
}
py::dict AsDict(const FrameExtractionOptions &opts) {
py::dict dict;
AS_DICT(samp_freq);
AS_DICT(frame_shift_ms);
AS_DICT(frame_length_ms);
@ -45,6 +48,33 @@ py::dict AsDict(const FrameExtractionOptions &opts) {
return dict;
}
MelBanksOptions MelBanksOptionsFromDict(py::dict dict) {
MelBanksOptions opts;
FROM_DICT(int_, num_bins);
FROM_DICT(float_, low_freq);
FROM_DICT(float_, high_freq);
FROM_DICT(float_, vtln_low);
FROM_DICT(float_, vtln_high);
FROM_DICT(bool_, debug_mel);
FROM_DICT(bool_, htk_mode);
return opts;
}
py::dict AsDict(const MelBanksOptions &opts) {
py::dict dict;
AS_DICT(num_bins);
AS_DICT(low_freq);
AS_DICT(high_freq);
AS_DICT(vtln_low);
AS_DICT(vtln_high);
AS_DICT(debug_mel);
AS_DICT(htk_mode);
return dict;
}
#undef FROM_DICT
#undef AS_DICT

View File

@ -6,13 +6,29 @@
#define KALDIFEAT_PYTHON_CSRC_UTILS_H_
#include "kaldifeat/csrc/feature-window.h"
#include "kaldifeat/csrc/mel-computations.h"
#include "kaldifeat/python/csrc/kaldifeat.h"
/*
* This file contains code about `from_dict` and
* `to_dict` for various options in kaldifeat.
*
* Regarding `from_dict`, users don't need to provide
* all the fields in the options. If some fields
* are not provided, it just uses the default one.
*
* If the provided dict in `from_dict` is empty,
* all fields use their default values.
*/
namespace kaldifeat {
FrameExtractionOptions FrameExtractionOptionsFromDict(py::dict dict);
py::dict AsDict(const FrameExtractionOptions &opts);
MelBanksOptions MelBanksOptionsFromDict(py::dict dict);
py::dict AsDict(const MelBanksOptions &opts);
} // namespace kaldifeat
#endif // KALDIFEAT_PYTHON_CSRC_UTILS_H_

View File

@ -23,6 +23,7 @@ set(py_test_files
test_spectrogram.py
test_options.py
test_frame_extraction_options.py
test_mel_bank_options.py
)
foreach(source IN LISTS py_test_files)

View File

@ -0,0 +1,94 @@
#!/usr/bin/env python3
#
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
import kaldifeat
def test_default():
opts = kaldifeat.MelBanksOptions()
assert opts.num_bins == 25
assert opts.low_freq == 20
assert opts.high_freq == 0
assert opts.vtln_low == 100
assert opts.vtln_high == -500
assert opts.debug_mel is False
assert opts.htk_mode is False
def test_set_get():
opts = kaldifeat.MelBanksOptions()
opts.num_bins = 100
assert opts.num_bins == 100
opts.low_freq = 22
assert opts.low_freq == 22
opts.high_freq = 1
assert opts.high_freq == 1
opts.vtln_low = 101
assert opts.vtln_low == 101
opts.vtln_high = -100
assert opts.vtln_high == -100
opts.debug_mel = True
assert opts.debug_mel is True
opts.htk_mode = True
assert opts.htk_mode is True
def test_from_empty_dict():
opts = kaldifeat.MelBanksOptions.from_dict({})
opts2 = kaldifeat.MelBanksOptions()
assert str(opts) == str(opts2)
def test_from_dict_partial():
d = {"num_bins": 10, "debug_mel": True}
opts = kaldifeat.MelBanksOptions.from_dict(d)
opts2 = kaldifeat.MelBanksOptions()
assert str(opts) != str(opts2)
opts2.num_bins = 10
assert str(opts) != str(opts2)
opts2.debug_mel = True
assert str(opts) == str(opts2)
opts2.debug_mel = False
assert str(opts) != str(opts2)
def test_from_dict_full_and_as_dict():
opts = kaldifeat.MelBanksOptions()
opts.num_bins = 80
opts.vtln_high = 2
d = opts.as_dict()
for key, value in d.items():
assert value == getattr(opts, key)
opts2 = kaldifeat.MelBanksOptions.from_dict(d)
assert str(opts2) == str(opts)
d["htk_mode"] = True
opts3 = kaldifeat.MelBanksOptions.from_dict(d)
assert opts3.htk_mode is True
def main():
test_default()
test_set_get()
test_from_empty_dict()
test_from_dict_partial()
test_from_dict_full_and_as_dict()
if __name__ == "__main__":
main()

View File

@ -16,18 +16,6 @@ sys.path.insert(0, f"{kaldi_feat_dir}/build/lib")
import kaldifeat
def test_mel_banks_options():
opts = kaldifeat.MelBanksOptions()
opts.num_bins = 23
opts.low_freq = 21
opts.high_freq = 8000
opts.vtln_low = 101
opts.vtln_high = -501
opts.debug_mel = True
opts.htk_mode = True
print(opts)
def test_fbank_options():
opts = kaldifeat.FbankOptions()
frame_opts = opts.frame_opts
@ -153,7 +141,6 @@ def test_plp_options():
def main():
test_mel_banks_options()
test_fbank_options()
test_mfcc_options()
test_spectogram_options()