From efcd8ab92c6854408174719a4d5a24e9c0dd0dbd Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sun, 28 Feb 2021 00:07:58 +0800 Subject: [PATCH] Support batch processing. --- kaldifeat/csrc/feature-common-inl.h | 11 +++++--- kaldifeat/csrc/feature-common.h | 1 + kaldifeat/csrc/feature-fbank.h | 2 ++ kaldifeat/csrc/test_kaldifeat.cc | 2 ++ kaldifeat/python/csrc/feature-fbank.cc | 2 ++ kaldifeat/python/csrc/feature-window.cc | 5 ++++ kaldifeat/python/tests/test_kaldifeat.py | 33 ++++++++++++++++++++++++ 7 files changed, 53 insertions(+), 3 deletions(-) diff --git a/kaldifeat/csrc/feature-common-inl.h b/kaldifeat/csrc/feature-common-inl.h index 0c9e3c4..ef73c53 100644 --- a/kaldifeat/csrc/feature-common-inl.h +++ b/kaldifeat/csrc/feature-common-inl.h @@ -14,11 +14,16 @@ namespace kaldifeat { template torch::Tensor OfflineFeatureTpl::ComputeFeatures(const torch::Tensor &wave, float vtln_warp) { - KALDIFEAT_ASSERT(wave.dim() == 1); - const FrameExtractionOptions &frame_opts = computer_.GetFrameOptions(); - torch::Tensor strided_input = GetStrided(wave, frame_opts); + torch::Tensor strided_input; + if (wave.dim() == 1) { + strided_input = GetStrided(wave, frame_opts); + } else { + KALDIFEAT_ASSERT(wave.dim() == 2); + KALDIFEAT_ASSERT(wave.sizes()[1] == frame_opts.WindowSize()); + strided_input = wave; + } if (frame_opts.dither != 0.0f) { strided_input = Dither(strided_input, frame_opts.dither); diff --git a/kaldifeat/csrc/feature-common.h b/kaldifeat/csrc/feature-common.h index 1949eed..f2c1799 100644 --- a/kaldifeat/csrc/feature-common.h +++ b/kaldifeat/csrc/feature-common.h @@ -40,6 +40,7 @@ class OfflineFeatureTpl { torch::Tensor ComputeFeatures(const torch::Tensor &wave, float vtln_warp); int32_t Dim() const { return computer_.Dim(); } + const Options &GetOptions() const { return computer_.GetOptions(); } // Copy constructor. OfflineFeatureTpl(const OfflineFeatureTpl &) = delete; diff --git a/kaldifeat/csrc/feature-fbank.h b/kaldifeat/csrc/feature-fbank.h index 01db278..fc40359 100644 --- a/kaldifeat/csrc/feature-fbank.h +++ b/kaldifeat/csrc/feature-fbank.h @@ -85,6 +85,8 @@ class FbankComputer { return opts_.frame_opts; } + const FbankOptions &GetOptions() const { return opts_; } + // signal_raw_log_energy is log_energy_pre_window, which is not empty // iff NeedRawLogEnergy() returns true. torch::Tensor Compute(torch::Tensor signal_raw_log_energy, float vtln_warp, diff --git a/kaldifeat/csrc/test_kaldifeat.cc b/kaldifeat/csrc/test_kaldifeat.cc index 528a7cf..bd2d43e 100644 --- a/kaldifeat/csrc/test_kaldifeat.cc +++ b/kaldifeat/csrc/test_kaldifeat.cc @@ -67,10 +67,12 @@ static void TestCat() { torch::Tensor b = torch::arange(0, 2).reshape({2, 1}).to(torch::kFloat) * 0.1; torch::Tensor c = torch::cat({a, b}, 1); torch::Tensor d = torch::cat({b, a}, 1); + torch::Tensor e = torch::cat({a, a}, 0); std::cout << a << "\n"; std::cout << b << "\n"; std::cout << c << "\n"; std::cout << d << "\n"; + std::cout << e << "\n"; } int main() { diff --git a/kaldifeat/python/csrc/feature-fbank.cc b/kaldifeat/python/csrc/feature-fbank.cc index 29b7f78..5122605 100644 --- a/kaldifeat/python/csrc/feature-fbank.cc +++ b/kaldifeat/python/csrc/feature-fbank.cc @@ -42,6 +42,8 @@ void PybindFbankOptions(py::module &m) { py::class_(m, "Fbank") .def(py::init(), py::arg("opts")) .def("dim", &Fbank::Dim) + .def("options", &Fbank::GetOptions, + py::return_value_policy::reference_internal) .def("compute_features", &Fbank::ComputeFeatures, py::arg("wave"), py::arg("vtln_warp")); } diff --git a/kaldifeat/python/csrc/feature-window.cc b/kaldifeat/python/csrc/feature-window.cc index bc3267f..a35de15 100644 --- a/kaldifeat/python/csrc/feature-window.cc +++ b/kaldifeat/python/csrc/feature-window.cc @@ -34,6 +34,11 @@ void PybindFrameExtractionOptions(py::module &m) { .def("__str__", [](const FrameExtractionOptions &self) -> std::string { return self.ToString(); }); + + m.def("num_frames", &NumFrames, py::arg("num_samples"), py::arg("opts"), + py::arg("flush") = true); + + m.def("get_strided", &GetStrided, py::arg("ave"), py::arg("opts")); } } // namespace kaldifeat diff --git a/kaldifeat/python/tests/test_kaldifeat.py b/kaldifeat/python/tests/test_kaldifeat.py index 0c970cc..b031593 100755 --- a/kaldifeat/python/tests/test_kaldifeat.py +++ b/kaldifeat/python/tests/test_kaldifeat.py @@ -6,6 +6,8 @@ from pathlib import Path cur_dir = Path(__file__).resolve().parent kaldi_feat_dir = cur_dir.parent.parent.parent +from typing import List + import sys sys.path.insert(0, f'{kaldi_feat_dir}/build/lib') @@ -132,11 +134,42 @@ def test_use_energy_htk_compat_false(): assert torch.allclose(ans[:3, :].cpu(), expected, rtol=1e-3) +def test_compute_batch(): + data1 = read_wave() + data2 = read_wave() + + data = [data1, data2] + fbank_opts = _kaldifeat.FbankOptions() + fbank_opts.frame_opts.dither = 0 + fbank = _kaldifeat.Fbank(fbank_opts) + + def impl(waves: List[torch.Tensor]) -> List[torch.Tensor]: + num_frames = [ + _kaldifeat.num_frames(w.numel(), fbank_opts.frame_opts) + for w in waves + ] + + strided = [ + _kaldifeat.get_strided(w, fbank_opts.frame_opts) for w in waves + ] + strided = torch.cat(strided, dim=0) + + features = _kaldifeat.compute(strided, fbank) + feature1 = features[:num_frames[0]] + feature2 = features[num_frames[0]:] + return [feature1, feature2] + + feature1, feature2 = impl([data1, data2]) + assert torch.allclose(feature1, feature2) + + def main(): test_and_benchmark_default_parameters() test_use_energy_htk_compat_true() test_use_energy_htk_compat_false() + test_compute_batch() + if __name__ == '__main__': main()