diff --git a/kaldifeat/csrc/feature-spectrogram.h b/kaldifeat/csrc/feature-spectrogram.h index c4febe6..3665f4a 100644 --- a/kaldifeat/csrc/feature-spectrogram.h +++ b/kaldifeat/csrc/feature-spectrogram.h @@ -36,13 +36,12 @@ struct SpectrogramOptions { std::string ToString() const { std::ostringstream os; - os << "frame_opts: \n"; - os << frame_opts << "\n"; - - os << "energy_floor: " << energy_floor << "\n"; - os << "raw_energy: " << raw_energy << "\n"; - // os << "return_raw_fft: " << return_raw_fft << "\n"; - os << "device: " << device << "\n"; + os << "SpectrogramOptions("; + os << "frame_opts=" << frame_opts.ToString() << ", "; + os << "energy_floor=" << energy_floor << ", "; + os << "raw_energy=" << (raw_energy ? "True" : "False") << ", "; + os << "return_raw_fft=" << (return_raw_fft ? "True" : "False") << ", "; + os << "device=\"" << device << "\")"; return os.str(); } }; diff --git a/kaldifeat/python/csrc/feature-spectrogram.cc b/kaldifeat/python/csrc/feature-spectrogram.cc index 24b156b..aaf3d78 100644 --- a/kaldifeat/python/csrc/feature-spectrogram.cc +++ b/kaldifeat/python/csrc/feature-spectrogram.cc @@ -15,7 +15,27 @@ namespace kaldifeat { static void PybindSpectrogramOptions(py::module &m) { using PyClass = SpectrogramOptions; py::class_(m, "SpectrogramOptions") - .def(py::init<>()) + .def(py::init([](const FrameExtractionOptions &frame_opts = + FrameExtractionOptions(), + float energy_floor = 0.0, bool raw_energy = true, + bool return_raw_fft = false, + py::object device = py::str( + "cpu")) -> std::unique_ptr { + auto opts = std::make_unique(); + opts->frame_opts = frame_opts; + opts->energy_floor = energy_floor; + opts->raw_energy = raw_energy; + opts->return_raw_fft = return_raw_fft; + + std::string s = static_cast(device); + opts->device = torch::Device(s); + + return opts; + }), + py::arg("frame_opts") = FrameExtractionOptions(), + py::arg("energy_floor") = 0.0, py::arg("raw_energy") = true, + py::arg("return_raw_fft") = false, + py::arg("device") = py::str("cpu")) .def_readwrite("frame_opts", &PyClass::frame_opts) .def_readwrite("energy_floor", &PyClass::energy_floor) .def_readwrite("raw_energy", &PyClass::raw_energy) diff --git a/kaldifeat/python/tests/test_spectrogram_options.py b/kaldifeat/python/tests/test_spectrogram_options.py index 34c8849..7a4fd1a 100755 --- a/kaldifeat/python/tests/test_spectrogram_options.py +++ b/kaldifeat/python/tests/test_spectrogram_options.py @@ -12,6 +12,7 @@ import kaldifeat def test_default(): opts = kaldifeat.SpectrogramOptions() + print(opts) assert opts.frame_opts.samp_freq == 16000 assert opts.frame_opts.frame_shift_ms == 10.0 @@ -30,7 +31,8 @@ def test_default(): def test_set_get(): - opts = kaldifeat.SpectrogramOptions() + opts = kaldifeat.SpectrogramOptions(energy_floor=10) + assert opts.energy_floor == 10 opts.energy_floor = 1 assert opts.energy_floor == 1 @@ -138,6 +140,30 @@ def test_pickle(): assert str(opts) == str(opts2) +def test_device(): + opts = kaldifeat.SpectrogramOptions(device="cpu") + assert opts.device == torch.device("cpu") + + opts = kaldifeat.SpectrogramOptions(device="cuda") + assert opts.device == torch.device("cuda") + + opts = kaldifeat.SpectrogramOptions(device="cuda:1") + assert opts.device == torch.device("cuda:1") + print(opts) + + opts = kaldifeat.SpectrogramOptions(device=torch.device("cpu")) + assert opts.device == torch.device("cpu") + + opts = kaldifeat.SpectrogramOptions(device=torch.device("cuda")) + assert opts.device == torch.device("cuda") + + opts = kaldifeat.SpectrogramOptions(device=torch.device("cuda:3")) + assert opts.device == torch.device("cuda:3") + + opts = kaldifeat.SpectrogramOptions(device=torch.device("cuda", 2)) + assert opts.device == torch.device("cuda", 2) + + def main(): test_default() test_set_get() @@ -146,6 +172,7 @@ def main(): test_from_dict_partial() test_from_dict_full_and_as_dict() test_pickle() + test_device() if __name__ == "__main__":