diff --git a/kaldifeat/python/csrc/utils.cc b/kaldifeat/python/csrc/utils.cc index 0476491..c50a4a0 100644 --- a/kaldifeat/python/csrc/utils.cc +++ b/kaldifeat/python/csrc/utils.cc @@ -6,7 +6,11 @@ #include "kaldifeat/csrc/feature-window.h" -#define FROM_DICT(type, key) opts.key = py::type(dict[#key]) +#define FROM_DICT(type, key) \ + if (dict.contains(#key)) { \ + opts.key = py::type(dict[#key]); \ + } + #define AS_DICT(key) dict[#key] = opts.key namespace kaldifeat { diff --git a/kaldifeat/python/tests/CMakeLists.txt b/kaldifeat/python/tests/CMakeLists.txt index 9961b14..0a0cb08 100644 --- a/kaldifeat/python/tests/CMakeLists.txt +++ b/kaldifeat/python/tests/CMakeLists.txt @@ -22,6 +22,7 @@ set(py_test_files test_plp.py test_spectrogram.py test_options.py + test_frame_extraction_options.py ) foreach(source IN LISTS py_test_files) diff --git a/kaldifeat/python/tests/test_frame_extraction_options.py b/kaldifeat/python/tests/test_frame_extraction_options.py new file mode 100755 index 0000000..4fa90d9 --- /dev/null +++ b/kaldifeat/python/tests/test_frame_extraction_options.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +import kaldifeat + + +def test_default(): + opts = kaldifeat.FrameExtractionOptions() + assert opts.samp_freq == 16000 + assert opts.frame_shift_ms == 10.0 + assert opts.frame_length_ms == 25.0 + assert opts.dither == 1.0 + assert abs(opts.preemph_coeff - 0.97) < 1e-6 + assert opts.remove_dc_offset is True + assert opts.window_type == "povey" + assert opts.round_to_power_of_two is True + assert abs(opts.blackman_coeff - 0.42) < 1e-6 + assert opts.snip_edges is True + + +def test_set_get(): + opts = kaldifeat.FrameExtractionOptions() + opts.samp_freq = 44100 + assert opts.samp_freq == 44100 + + opts.frame_shift_ms = 20.5 + assert opts.frame_shift_ms == 20.5 + + opts.frame_length_ms = 1 + assert opts.frame_length_ms == 1 + + opts.dither = 0.5 + assert opts.dither == 0.5 + + opts.preemph_coeff = 0.25 + assert opts.preemph_coeff == 0.25 + + opts.remove_dc_offset = False + assert opts.remove_dc_offset is False + + opts.window_type = "hanning" + assert opts.window_type == "hanning" + + opts.round_to_power_of_two = False + assert opts.round_to_power_of_two is False + + opts.blackman_coeff = 0.25 + assert opts.blackman_coeff == 0.25 + + opts.snip_edges = False + assert opts.snip_edges is False + + +def test_from_empty_dict(): + opts = kaldifeat.FrameExtractionOptions.from_dict({}) + opts2 = kaldifeat.FrameExtractionOptions() + + assert str(opts) == str(opts2) + + +def test_from_dict_partial(): + d = {"samp_freq": 10, "frame_shift_ms": 2} + + opts = kaldifeat.FrameExtractionOptions.from_dict(d) + + opts2 = kaldifeat.FrameExtractionOptions() + assert str(opts) != str(opts2) + + opts2.samp_freq = 10 + assert str(opts) != str(opts2) + + opts2.frame_shift_ms = 2 + assert str(opts) == str(opts2) + + opts2.frame_shift_ms = 3 + assert str(opts) != str(opts2) + + +def test_from_dict_full_and_as_dict(): + opts = kaldifeat.FrameExtractionOptions() + opts.samp_freq = 20 + opts.frame_length_ms = 100 + + d = opts.as_dict() + for key, value in d.items(): + assert value == getattr(opts, key) + + opts2 = kaldifeat.FrameExtractionOptions.from_dict(d) + assert str(opts2) == str(opts) + + d["window_type"] = "hanning" + opts3 = kaldifeat.FrameExtractionOptions.from_dict(d) + assert opts3.window_type == "hanning" + + +def main(): + test_default() + test_set_get() + test_from_empty_dict() + test_from_dict_partial() + test_from_dict_full_and_as_dict() + + +if __name__ == "__main__": + main() diff --git a/kaldifeat/python/tests/test_options.py b/kaldifeat/python/tests/test_options.py index 96e2f63..2373779 100755 --- a/kaldifeat/python/tests/test_options.py +++ b/kaldifeat/python/tests/test_options.py @@ -16,33 +16,6 @@ sys.path.insert(0, f"{kaldi_feat_dir}/build/lib") import kaldifeat -def test_frame_extraction_options(): - opts = kaldifeat.FrameExtractionOptions() - opts.samp_freq = 220500 - opts.frame_shift_ms = 15 - opts.frame_length_ms = 40 - opts.dither = 0.1 - opts.preemph_coeff = 0.98 - opts.remove_dc_offset = False - opts.window_type = "hanning" - opts.round_to_power_of_two = False - opts.blackman_coeff = 0.422 - opts.snip_edges = False - - opts_dict = opts.as_dict() - for key, value in opts_dict.items(): - assert value == getattr(opts, key) - - opts2 = kaldifeat.FrameExtractionOptions.from_dict(opts_dict) - - for key, value in opts_dict.items(): - assert value == getattr(opts2, key) - - assert str(opts) == str(opts2) - - assert opts_dict == opts2.as_dict() - - def test_mel_banks_options(): opts = kaldifeat.MelBanksOptions() opts.num_bins = 23 @@ -180,7 +153,6 @@ def test_plp_options(): def main(): - test_frame_extraction_options() test_mel_banks_options() test_fbank_options() test_mfcc_options()