minor fixes.

This commit is contained in:
Fangjun Kuang 2021-02-28 00:20:02 +08:00
parent efcd8ab92c
commit 0dacee71be
2 changed files with 27 additions and 22 deletions

View File

@ -38,7 +38,7 @@ void PybindFrameExtractionOptions(py::module &m) {
m.def("num_frames", &NumFrames, py::arg("num_samples"), py::arg("opts"), m.def("num_frames", &NumFrames, py::arg("num_samples"), py::arg("opts"),
py::arg("flush") = true); py::arg("flush") = true);
m.def("get_strided", &GetStrided, py::arg("ave"), py::arg("opts")); m.def("get_strided", &GetStrided, py::arg("wave"), py::arg("opts"));
} }
} // namespace kaldifeat } // namespace kaldifeat

View File

@ -135,12 +135,17 @@ def test_use_energy_htk_compat_false():
def test_compute_batch(): def test_compute_batch():
data1 = read_wave() devices = [torch.device('cpu')]
data2 = read_wave() if torch.cuda.is_available():
devices.append(torch.device('cuda', 0))
for device in devices:
data1 = read_wave().to(device)
data2 = read_wave().to(device)
data = [data1, data2]
fbank_opts = _kaldifeat.FbankOptions() fbank_opts = _kaldifeat.FbankOptions()
fbank_opts.frame_opts.dither = 0 fbank_opts.frame_opts.dither = 0
fbank_opts.set_device(device)
fbank = _kaldifeat.Fbank(fbank_opts) fbank = _kaldifeat.Fbank(fbank_opts)
def impl(waves: List[torch.Tensor]) -> List[torch.Tensor]: def impl(waves: List[torch.Tensor]) -> List[torch.Tensor]: