From 0dacee71be418ddc477ad314ba2e3337f4600b67 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sun, 28 Feb 2021 00:20:02 +0800 Subject: [PATCH] minor fixes. --- kaldifeat/python/csrc/feature-window.cc | 2 +- kaldifeat/python/tests/test_kaldifeat.py | 47 +++++++++++++----------- 2 files changed, 27 insertions(+), 22 deletions(-) diff --git a/kaldifeat/python/csrc/feature-window.cc b/kaldifeat/python/csrc/feature-window.cc index a35de15..92dedde 100644 --- a/kaldifeat/python/csrc/feature-window.cc +++ b/kaldifeat/python/csrc/feature-window.cc @@ -38,7 +38,7 @@ void PybindFrameExtractionOptions(py::module &m) { m.def("num_frames", &NumFrames, py::arg("num_samples"), py::arg("opts"), py::arg("flush") = true); - m.def("get_strided", &GetStrided, py::arg("ave"), py::arg("opts")); + m.def("get_strided", &GetStrided, py::arg("wave"), py::arg("opts")); } } // namespace kaldifeat diff --git a/kaldifeat/python/tests/test_kaldifeat.py b/kaldifeat/python/tests/test_kaldifeat.py index b031593..84491ce 100755 --- a/kaldifeat/python/tests/test_kaldifeat.py +++ b/kaldifeat/python/tests/test_kaldifeat.py @@ -135,32 +135,37 @@ def test_use_energy_htk_compat_false(): def test_compute_batch(): - data1 = read_wave() - data2 = read_wave() + devices = [torch.device('cpu')] + if torch.cuda.is_available(): + devices.append(torch.device('cuda', 0)) - data = [data1, data2] - fbank_opts = _kaldifeat.FbankOptions() - fbank_opts.frame_opts.dither = 0 - fbank = _kaldifeat.Fbank(fbank_opts) + for device in devices: + data1 = read_wave().to(device) + data2 = read_wave().to(device) - def impl(waves: List[torch.Tensor]) -> List[torch.Tensor]: - num_frames = [ - _kaldifeat.num_frames(w.numel(), fbank_opts.frame_opts) - for w in waves - ] + fbank_opts = _kaldifeat.FbankOptions() + fbank_opts.frame_opts.dither = 0 + fbank_opts.set_device(device) + fbank = _kaldifeat.Fbank(fbank_opts) - strided = [ - _kaldifeat.get_strided(w, fbank_opts.frame_opts) for w in waves - ] - strided = torch.cat(strided, dim=0) + def impl(waves: List[torch.Tensor]) -> List[torch.Tensor]: + num_frames = [ + _kaldifeat.num_frames(w.numel(), fbank_opts.frame_opts) + for w in waves + ] - features = _kaldifeat.compute(strided, fbank) - feature1 = features[:num_frames[0]] - feature2 = features[num_frames[0]:] - return [feature1, feature2] + strided = [ + _kaldifeat.get_strided(w, fbank_opts.frame_opts) for w in waves + ] + strided = torch.cat(strided, dim=0) - feature1, feature2 = impl([data1, data2]) - assert torch.allclose(feature1, feature2) + features = _kaldifeat.compute(strided, fbank) + feature1 = features[:num_frames[0]] + feature2 = features[num_frames[0]:] + return [feature1, feature2] + + feature1, feature2 = impl([data1, data2]) + assert torch.allclose(feature1, feature2) def main():