diff --git a/kaldifeat/csrc/feature-mfcc.h b/kaldifeat/csrc/feature-mfcc.h index 2c00b7f..2d45b2c 100644 --- a/kaldifeat/csrc/feature-mfcc.h +++ b/kaldifeat/csrc/feature-mfcc.h @@ -53,20 +53,18 @@ struct MfccOptions { std::string ToString() const { std::ostringstream os; - os << "frame_opts: \n"; - os << frame_opts << "\n"; - os << "\n"; + os << "MfccOptions("; + os << "frame_opts=" << frame_opts.ToString() << ", "; + os << "mel_opts=" << mel_opts.ToString() << ", "; - os << "mel_opts: \n"; - os << mel_opts << "\n"; + os << "num_ceps=" << num_ceps << ", "; + os << "use_energy=" << (use_energy ? "True" : "False") << ", "; + os << "energy_floor=" << energy_floor << ", "; + os << "raw_energy=" << (raw_energy ? "True" : "False") << ", "; + os << "cepstral_lifter=" << cepstral_lifter << ", "; + os << "htk_compat=" << (htk_compat ? "True" : "False") << ", "; + os << "device=\"" << device << "\")"; - os << "num_ceps: " << num_ceps << "\n"; - os << "use_energy: " << use_energy << "\n"; - os << "energy_floor: " << energy_floor << "\n"; - os << "raw_energy: " << raw_energy << "\n"; - os << "cepstral_lifter: " << cepstral_lifter << "\n"; - os << "htk_compat: " << htk_compat << "\n"; - os << "device: " << device << "\n"; return os.str(); } }; diff --git a/kaldifeat/python/csrc/feature-fbank.cc b/kaldifeat/python/csrc/feature-fbank.cc index 07dbed8..a7ed09a 100644 --- a/kaldifeat/python/csrc/feature-fbank.cc +++ b/kaldifeat/python/csrc/feature-fbank.cc @@ -16,14 +16,14 @@ static void PybindFbankOptions(py::module &m) { using PyClass = FbankOptions; py::class_(m, "FbankOptions") .def(py::init<>()) - .def(py::init([](const FrameExtractionOptions &frame_opts = + .def(py::init([](const MelBanksOptions &mel_opts, + const FrameExtractionOptions &frame_opts = FrameExtractionOptions(), bool use_energy = false, float energy_floor = 0.0f, bool raw_energy = true, bool htk_compat = false, bool use_log_fbank = true, bool use_power = true, - py::object device = py::str("cpu"), - const MelBanksOptions &mel_opts) - -> std::unique_ptr { + py::object device = + py::str("cpu")) -> std::unique_ptr { auto opts = std::make_unique(); opts->frame_opts = frame_opts; opts->mel_opts = mel_opts; @@ -39,11 +39,12 @@ static void PybindFbankOptions(py::module &m) { return opts; }), + py::arg("mel_opts"), py::arg("frame_opts") = FrameExtractionOptions(), py::arg("use_energy") = false, py::arg("energy_floor") = 0.0f, py::arg("raw_energy") = true, py::arg("htk_compat") = false, py::arg("use_log_fbank") = true, py::arg("use_power") = true, - py::arg("device") = py::str("cpu"), py::arg("mel_opts")) + py::arg("device") = py::str("cpu")) .def_readwrite("frame_opts", &PyClass::frame_opts) .def_readwrite("mel_opts", &PyClass::mel_opts) .def_readwrite("use_energy", &PyClass::use_energy) diff --git a/kaldifeat/python/csrc/feature-mfcc.cc b/kaldifeat/python/csrc/feature-mfcc.cc index fe893cb..44c200d 100644 --- a/kaldifeat/python/csrc/feature-mfcc.cc +++ b/kaldifeat/python/csrc/feature-mfcc.cc @@ -16,6 +16,35 @@ void PybindMfccOptions(py::module &m) { using PyClass = MfccOptions; py::class_(m, "MfccOptions") .def(py::init<>()) + .def(py::init([](const MelBanksOptions &mel_opts, + const FrameExtractionOptions &frame_opts = + FrameExtractionOptions(), + int32_t num_ceps = 13, bool use_energy = true, + float energy_floor = 0.0, bool raw_energy = true, + float cepstral_lifter = 22.0, bool htk_compat = false, + py::object device = + py::str("cpu")) -> std::unique_ptr { + auto opts = std::make_unique(); + opts->frame_opts = frame_opts; + opts->mel_opts = mel_opts; + opts->num_ceps = num_ceps; + opts->use_energy = use_energy; + opts->energy_floor = energy_floor; + opts->raw_energy = raw_energy; + opts->cepstral_lifter = cepstral_lifter; + opts->htk_compat = htk_compat; + + std::string s = static_cast(device); + opts->device = torch::Device(s); + + return opts; + }), + py::arg("mel_opts"), + py::arg("frame_opts") = FrameExtractionOptions(), + py::arg("num_ceps") = 13, py::arg("use_energy") = true, + py::arg("energy_floor") = 0.0f, py::arg("raw_energy") = true, + py::arg("cepstral_lifter") = 22.0, py::arg("htk_compat") = false, + py::arg("device") = py::str("cpu")) .def_readwrite("frame_opts", &PyClass::frame_opts) .def_readwrite("mel_opts", &PyClass::mel_opts) .def_readwrite("num_ceps", &PyClass::num_ceps) diff --git a/kaldifeat/python/tests/test_mfcc_options.py b/kaldifeat/python/tests/test_mfcc_options.py index cef03ab..c650f46 100755 --- a/kaldifeat/python/tests/test_mfcc_options.py +++ b/kaldifeat/python/tests/test_mfcc_options.py @@ -12,6 +12,7 @@ import kaldifeat def test_default(): opts = kaldifeat.MfccOptions() + print(opts) assert opts.frame_opts.samp_freq == 16000 assert opts.frame_opts.frame_shift_ms == 10.0