Add from_dict and as_dict for FrameExtractionOptions.

This commit is contained in:
Fangjun Kuang 2021-10-14 22:02:35 +08:00
parent 3ed1686424
commit 2ff0142455
5 changed files with 88 additions and 1 deletions

View File

@ -7,6 +7,7 @@ pybind11_add_module(_kaldifeat
feature-window.cc
kaldifeat.cc
mel-computations.cc
utils.cc
)
target_link_libraries(_kaldifeat PRIVATE kaldifeat_core)
if(UNIX AND NOT APPLE)

View File

@ -7,6 +7,7 @@
#include <string>
#include "kaldifeat/csrc/feature-window.h"
#include "kaldifeat/python/csrc/utils.h"
namespace kaldifeat {
@ -26,6 +27,14 @@ static void PybindFrameExtractionOptions(py::module &m) {
&FrameExtractionOptions::round_to_power_of_two)
.def_readwrite("blackman_coeff", &FrameExtractionOptions::blackman_coeff)
.def_readwrite("snip_edges", &FrameExtractionOptions::snip_edges)
.def("as_dict",
[](const FrameExtractionOptions &self) -> py::dict {
return AsDict(self);
})
.def_static("from_dict",
[](py::dict dict) -> FrameExtractionOptions {
return FrameExtractionOptionsFromDict(dict);
})
#if 0
.def_readwrite("allow_downsample",
&FrameExtractionOptions::allow_downsample)

View File

@ -0,0 +1,47 @@
// kaldifeat/python/csrc/utils.cc
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#include "kaldifeat/python/csrc/utils.h"
#include "kaldifeat/csrc/feature-window.h"
#define FROM_DICT(type, key) opts.key = py::type(dict[#key])
#define AS_DICT(key) dict[#key] = opts.key
namespace kaldifeat {
FrameExtractionOptions FrameExtractionOptionsFromDict(py::dict dict) {
FrameExtractionOptions opts;
FROM_DICT(float_, samp_freq);
FROM_DICT(float_, frame_shift_ms);
FROM_DICT(float_, frame_length_ms);
FROM_DICT(float_, dither);
FROM_DICT(float_, preemph_coeff);
FROM_DICT(bool_, remove_dc_offset);
FROM_DICT(str, window_type);
FROM_DICT(bool_, round_to_power_of_two);
FROM_DICT(float_, blackman_coeff);
FROM_DICT(bool_, snip_edges);
return opts;
}
py::dict AsDict(const FrameExtractionOptions &opts) {
py::dict dict;
AS_DICT(samp_freq);
AS_DICT(frame_shift_ms);
AS_DICT(frame_length_ms);
AS_DICT(dither);
AS_DICT(preemph_coeff);
AS_DICT(remove_dc_offset);
AS_DICT(window_type);
AS_DICT(round_to_power_of_two);
AS_DICT(blackman_coeff);
AS_DICT(snip_edges);
return dict;
}
#undef FROM_DICT
#undef AS_DICT
} // namespace kaldifeat

View File

@ -0,0 +1,18 @@
// kaldifeat/python/csrc/utils.h
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#ifndef KALDIFEAT_PYTHON_CSRC_UTILS_H_
#define KALDIFEAT_PYTHON_CSRC_UTILS_H_
#include "kaldifeat/csrc/feature-window.h"
#include "kaldifeat/python/csrc/kaldifeat.h"
namespace kaldifeat {
FrameExtractionOptions FrameExtractionOptionsFromDict(py::dict dict);
py::dict AsDict(const FrameExtractionOptions &opts);
} // namespace kaldifeat
#endif // KALDIFEAT_PYTHON_CSRC_UTILS_H_

View File

@ -28,7 +28,19 @@ def test_frame_extraction_options():
opts.round_to_power_of_two = False
opts.blackman_coeff = 0.422
opts.snip_edges = False
print(opts)
opts_dict = opts.as_dict()
for key, value in opts_dict.items():
assert value == getattr(opts, key)
opts2 = kaldifeat.FrameExtractionOptions.from_dict(opts_dict)
for key, value in opts_dict.items():
assert value == getattr(opts2, key)
assert str(opts) == str(opts2)
assert opts_dict == opts2.as_dict()
def test_mel_banks_options():