diff --git a/kaldifeat/csrc/feature-common.h b/kaldifeat/csrc/feature-common.h index 485ea31..ad72d4e 100644 --- a/kaldifeat/csrc/feature-common.h +++ b/kaldifeat/csrc/feature-common.h @@ -42,14 +42,10 @@ class OfflineFeatureTpl { int32_t Dim() const { return computer_.Dim(); } // Copy constructor. - OfflineFeatureTpl(const OfflineFeatureTpl &other) - : computer_(other.computer_), - feature_window_function_(other.feature_window_function_) {} + OfflineFeatureTpl(const OfflineFeatureTpl &) = delete; + OfflineFeatureTpl &operator=(const OfflineFeatureTpl &) = delete; private: - // Disallow assignment. - OfflineFeatureTpl &operator=(const OfflineFeatureTpl &other); - F computer_; FeatureWindowFunction feature_window_function_; }; diff --git a/kaldifeat/csrc/feature-fbank.cc b/kaldifeat/csrc/feature-fbank.cc index 488e441..b5ad5a0 100644 --- a/kaldifeat/csrc/feature-fbank.cc +++ b/kaldifeat/csrc/feature-fbank.cc @@ -26,14 +26,6 @@ FbankComputer::FbankComputer(const FbankOptions &opts) : opts_(opts) { GetMelBanks(1.0f); } -FbankComputer::FbankComputer(const FbankComputer &other) - : opts_(other.opts_), - log_energy_floor_(other.log_energy_floor_), - mel_banks_(other.mel_banks_) { - for (auto iter = mel_banks_.begin(); iter != mel_banks_.end(); ++iter) - iter->second = new MelBanks(*(iter->second)); -} - FbankComputer::~FbankComputer() { for (auto iter = mel_banks_.begin(); iter != mel_banks_.end(); ++iter) delete iter->second; diff --git a/kaldifeat/csrc/feature-fbank.h b/kaldifeat/csrc/feature-fbank.h index f937d05..d3d686a 100644 --- a/kaldifeat/csrc/feature-fbank.h +++ b/kaldifeat/csrc/feature-fbank.h @@ -68,8 +68,7 @@ class FbankComputer { ~FbankComputer(); FbankComputer &operator=(const FbankComputer &) = delete; - - FbankComputer(const FbankComputer &other); + FbankComputer(const FbankComputer &) = delete; int32_t Dim() const { return opts_.mel_opts.num_bins + (opts_.use_energy ? 1 : 0); diff --git a/kaldifeat/python/csrc/feature-fbank.cc b/kaldifeat/python/csrc/feature-fbank.cc index bd068e2..f8c0a78 100644 --- a/kaldifeat/python/csrc/feature-fbank.cc +++ b/kaldifeat/python/csrc/feature-fbank.cc @@ -22,6 +22,12 @@ void PybindFbankOptions(py::module &m) { .def("__str__", [](const FbankOptions &self) -> std::string { return self.ToString(); }); + + py::class_(m, "Fbank") + .def(py::init(), py::arg("opts")) + .def("dim", &Fbank::Dim) + .def("compute_features", &Fbank::ComputeFeatures, py::arg("wave"), + py::arg("vtln_warp")); } } // namespace kaldifeat diff --git a/kaldifeat/python/csrc/kaldifeat.cc b/kaldifeat/python/csrc/kaldifeat.cc index 80bc08e..3398f8d 100644 --- a/kaldifeat/python/csrc/kaldifeat.cc +++ b/kaldifeat/python/csrc/kaldifeat.cc @@ -14,14 +14,9 @@ namespace kaldifeat { -static torch::Tensor Compute(const torch::Tensor &wave, - const FbankOptions &fbank_opts) { - // TODO(fangjun): wrap Fbank to Python - - Fbank fbank(fbank_opts); +static torch::Tensor Compute(const torch::Tensor &wave, Fbank *fbank) { float vtln_warp = 1.0f; - - torch::Tensor ans = fbank.ComputeFeatures(wave, vtln_warp); + torch::Tensor ans = fbank->ComputeFeatures(wave, vtln_warp); return ans; } @@ -32,18 +27,18 @@ PYBIND11_MODULE(_kaldifeat, m) { PybindMelBanksOptions(m); PybindFbankOptions(m); - m.def("compute", &Compute, py::arg("wave"), py::arg("fbank_opts")); + m.def("compute", &Compute, py::arg("wave"), py::arg("fbank")); // It verifies that the reimplementation produces the same output // as kaldi using default parameters with dither disabled. m.def( "_compute_with_elapsed_time", // for benchmark only [](const torch::Tensor &wave, - const FbankOptions &fbank_opts) -> std::pair { + Fbank *fbank) -> std::pair { std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now(); - torch::Tensor ans = Compute(wave, fbank_opts); + torch::Tensor ans = Compute(wave, fbank); std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now(); @@ -55,7 +50,7 @@ PYBIND11_MODULE(_kaldifeat, m) { return std::make_pair(ans, elapsed_seconds); }, - py::arg("wave"), py::arg("fbank_opts")); + py::arg("wave"), py::arg("fbank")); } } // namespace kaldifeat diff --git a/kaldifeat/python/tests/test_kaldifeat.py b/kaldifeat/python/tests/test_kaldifeat.py index 7e8c07d..3e0c24e 100755 --- a/kaldifeat/python/tests/test_kaldifeat.py +++ b/kaldifeat/python/tests/test_kaldifeat.py @@ -58,11 +58,11 @@ def read_wave() -> torch.Tensor: def test_and_benchmark_default_parameters(): fbank_opts = _kaldifeat.FbankOptions() fbank_opts.frame_opts.dither = 0 + fbank = _kaldifeat.Fbank(fbank_opts) data = read_wave() - ans, elapsed_seconds = _kaldifeat._compute_with_elapsed_time( - data, fbank_opts) + ans, elapsed_seconds = _kaldifeat._compute_with_elapsed_time(data, fbank) expected = read_ark_txt() assert torch.allclose(ans, expected, rtol=1e-3) @@ -74,10 +74,11 @@ def test_use_energy_htk_compat_true(): fbank_opts.frame_opts.dither = 0 fbank_opts.use_energy = True fbank_opts.htk_compat = True + fbank = _kaldifeat.Fbank(fbank_opts) data = read_wave() - ans = _kaldifeat.compute(data, fbank_opts) + ans = _kaldifeat.compute(data, fbank) # ./compute-fbank-feats --dither=0 --use-energy=1 --htk-compat=1 scp:abc.scp ark,t:abc.txt # the first 3 rows are: @@ -95,10 +96,11 @@ def test_use_energy_htk_compat_false(): fbank_opts.frame_opts.dither = 0 fbank_opts.use_energy = True fbank_opts.htk_compat = False + fbank = _kaldifeat.Fbank(fbank_opts) data = read_wave() - ans = _kaldifeat.compute(data, fbank_opts) + ans = _kaldifeat.compute(data, fbank) # ./compute-fbank-feats --dither=0 --use-energy=1 --htk-compat=0 scp:abc.scp ark,t:abc.txt # the first 3 rows are: