Change the way how SpectrogramOptions is displayed

This commit is contained in:
Fangjun Kuang 2022-12-03 12:40:32 +08:00
parent 5b97eeadb5
commit 32948f9556
3 changed files with 55 additions and 9 deletions

View File

@ -36,13 +36,12 @@ struct SpectrogramOptions {
std::string ToString() const {
std::ostringstream os;
os << "frame_opts: \n";
os << frame_opts << "\n";
os << "energy_floor: " << energy_floor << "\n";
os << "raw_energy: " << raw_energy << "\n";
// os << "return_raw_fft: " << return_raw_fft << "\n";
os << "device: " << device << "\n";
os << "SpectrogramOptions(";
os << "frame_opts=" << frame_opts.ToString() << ", ";
os << "energy_floor=" << energy_floor << ", ";
os << "raw_energy=" << (raw_energy ? "True" : "False") << ", ";
os << "return_raw_fft=" << (return_raw_fft ? "True" : "False") << ", ";
os << "device=\"" << device << "\")";
return os.str();
}
};

View File

@ -15,7 +15,27 @@ namespace kaldifeat {
static void PybindSpectrogramOptions(py::module &m) {
using PyClass = SpectrogramOptions;
py::class_<PyClass>(m, "SpectrogramOptions")
.def(py::init<>())
.def(py::init([](const FrameExtractionOptions &frame_opts =
FrameExtractionOptions(),
float energy_floor = 0.0, bool raw_energy = true,
bool return_raw_fft = false,
py::object device = py::str(
"cpu")) -> std::unique_ptr<SpectrogramOptions> {
auto opts = std::make_unique<SpectrogramOptions>();
opts->frame_opts = frame_opts;
opts->energy_floor = energy_floor;
opts->raw_energy = raw_energy;
opts->return_raw_fft = return_raw_fft;
std::string s = static_cast<py::str>(device);
opts->device = torch::Device(s);
return opts;
}),
py::arg("frame_opts") = FrameExtractionOptions(),
py::arg("energy_floor") = 0.0, py::arg("raw_energy") = true,
py::arg("return_raw_fft") = false,
py::arg("device") = py::str("cpu"))
.def_readwrite("frame_opts", &PyClass::frame_opts)
.def_readwrite("energy_floor", &PyClass::energy_floor)
.def_readwrite("raw_energy", &PyClass::raw_energy)

View File

@ -12,6 +12,7 @@ import kaldifeat
def test_default():
opts = kaldifeat.SpectrogramOptions()
print(opts)
assert opts.frame_opts.samp_freq == 16000
assert opts.frame_opts.frame_shift_ms == 10.0
@ -30,7 +31,8 @@ def test_default():
def test_set_get():
opts = kaldifeat.SpectrogramOptions()
opts = kaldifeat.SpectrogramOptions(energy_floor=10)
assert opts.energy_floor == 10
opts.energy_floor = 1
assert opts.energy_floor == 1
@ -138,6 +140,30 @@ def test_pickle():
assert str(opts) == str(opts2)
def test_device():
opts = kaldifeat.SpectrogramOptions(device="cpu")
assert opts.device == torch.device("cpu")
opts = kaldifeat.SpectrogramOptions(device="cuda")
assert opts.device == torch.device("cuda")
opts = kaldifeat.SpectrogramOptions(device="cuda:1")
assert opts.device == torch.device("cuda:1")
print(opts)
opts = kaldifeat.SpectrogramOptions(device=torch.device("cpu"))
assert opts.device == torch.device("cpu")
opts = kaldifeat.SpectrogramOptions(device=torch.device("cuda"))
assert opts.device == torch.device("cuda")
opts = kaldifeat.SpectrogramOptions(device=torch.device("cuda:3"))
assert opts.device == torch.device("cuda:3")
opts = kaldifeat.SpectrogramOptions(device=torch.device("cuda", 2))
assert opts.device == torch.device("cuda", 2)
def main():
test_default()
test_set_get()
@ -146,6 +172,7 @@ def main():
test_from_dict_partial()
test_from_dict_full_and_as_dict()
test_pickle()
test_device()
if __name__ == "__main__":