diff --git a/kaldifeat/CMakeLists.txt b/kaldifeat/CMakeLists.txt index 86735ca..c70d00c 100644 --- a/kaldifeat/CMakeLists.txt +++ b/kaldifeat/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(csrc) +add_subdirectory(python) diff --git a/kaldifeat/csrc/feature-common-inl.h b/kaldifeat/csrc/feature-common-inl.h index 5416cf3..e5fa801 100644 --- a/kaldifeat/csrc/feature-common-inl.h +++ b/kaldifeat/csrc/feature-common-inl.h @@ -18,10 +18,10 @@ torch::Tensor OfflineFeatureTpl::ComputeFeatures(const torch::Tensor &wave, int32_t rows_out = NumFrames(wave.sizes()[0], computer_.GetFrameOptions()); int32_t cols_out = computer_.Dim(); - const FrameExtractionOptions &frame_opts = - computer_.GetFrameOptions().frame_opts; + const FrameExtractionOptions &frame_opts = computer_.GetFrameOptions(); - torch::Tensor strided_input = GetStrided(wave, frame_opts); + // TODO(fangjun): avoid clone + torch::Tensor strided_input = GetStrided(wave, frame_opts).clone(); if (frame_opts.dither != 0) strided_input = Dither(strided_input, frame_opts.dither); @@ -62,24 +62,21 @@ torch::Tensor OfflineFeatureTpl::ComputeFeatures(const torch::Tensor &wave, torch::indexing::None)}) = right - frame_opts.preemph_coeff * current; - strided_input.index({"...", 0}) *= frame_opts.preemph_coeff; + strided_input.index({"...", 0}) *= 1 - frame_opts.preemph_coeff; } strided_input = feature_window_function_.Apply(strided_input); -#if 0 - Vector window; // windowed waveform. - bool use_raw_log_energy = computer_.NeedRawLogEnergy(); - for (int32 r = 0; r < rows_out; r++) { // r is frame index. - BaseFloat raw_log_energy = 0.0; - ExtractWindow(0, wave, r, computer_.GetFrameOptions(), - feature_window_function_, &window, - (use_raw_log_energy ? &raw_log_energy : NULL)); + int32_t padding = frame_opts.PaddedWindowSize() - strided_input.sizes()[1]; - SubVector output_row(*output, r); - computer_.Compute(raw_log_energy, vtln_warp, &window, &output_row); + if (padding > 0) { + strided_input = torch::nn::functional::pad( + strided_input, torch::nn::functional::PadFuncOptions({0, padding}) + .mode(torch::kConstant) + .value(0)); } -#endif + + return computer_.Compute(log_energy_pre_window, vtln_warp, strided_input); } } // namespace kaldifeat diff --git a/kaldifeat/csrc/feature-fbank.cc b/kaldifeat/csrc/feature-fbank.cc index 5b1c623..6569eba 100644 --- a/kaldifeat/csrc/feature-fbank.cc +++ b/kaldifeat/csrc/feature-fbank.cc @@ -79,6 +79,7 @@ torch::Tensor FbankComputer::Compute(torch::Tensor signal_raw_log_energy, SubVector mel_energies(*feature, mel_offset, opts_.mel_opts.num_bins); #endif + // TODO(fangjun): remove the last column of spectrum torch::Tensor mel_energies = mel_banks.Compute(spectrum); if (opts_.use_log_fbank) { // Avoid log of zero (which should be prevented anyway by dithering). diff --git a/kaldifeat/csrc/mel-computations.cc b/kaldifeat/csrc/mel-computations.cc index 0d4ca0b..557811e 100644 --- a/kaldifeat/csrc/mel-computations.cc +++ b/kaldifeat/csrc/mel-computations.cc @@ -131,7 +131,9 @@ MelBanks::MelBanks(const MelBanksOptions &opts, << " and vtln-high " << vtln_high << ", versus " << "low-freq " << low_freq << " and high-freq " << high_freq; - bins_mat_ = torch::zeros({num_bins, num_fft_bins}, torch::kFloat); + // TODO(fangjun): remove the last column of the power spectrum + // and set the number of columns to num_fft_bins instead of num_fft_bins + 1 + bins_mat_ = torch::zeros({num_bins, num_fft_bins + 1}, torch::kFloat); int32_t stride = bins_mat_.strides()[0]; for (int32_t bin = 0; bin < num_bins; ++bin) { diff --git a/kaldifeat/csrc/mel-computations.h b/kaldifeat/csrc/mel-computations.h index b64c27c..b89d74b 100644 --- a/kaldifeat/csrc/mel-computations.h +++ b/kaldifeat/csrc/mel-computations.h @@ -63,6 +63,9 @@ class MelBanks { torch::Tensor Compute(const torch::Tensor &spectrum) const; + // for debug only + const torch::Tensor &GetBinsMat() const { return bins_mat_; } + private: // A 2-D matrix of shape [num_bins, num_fft_bins] torch::Tensor bins_mat_; diff --git a/kaldifeat/csrc/test_kaldifeat.cc b/kaldifeat/csrc/test_kaldifeat.cc index a9bc12c..1513300 100644 --- a/kaldifeat/csrc/test_kaldifeat.cc +++ b/kaldifeat/csrc/test_kaldifeat.cc @@ -26,12 +26,17 @@ static void TestPreemph() { std::cout << d << "\n"; } -int main() { +static void TestPad() { torch::Tensor a = torch::arange(0, 6).reshape({2, 3}).to(torch::kFloat); - torch::Tensor b = torch::arange(1, 4).to(torch::kFloat).unsqueeze(0); + torch::Tensor b = torch::nn::functional::pad( + a, torch::nn::functional::PadFuncOptions({0, 3}) + .mode(torch::kConstant) + .value(0)); std::cout << a << "\n"; std::cout << b << "\n"; - std::cout << a * b << "\n"; +} +int main() { + TestPad(); return 0; } diff --git a/kaldifeat/python/CMakeLists.txt b/kaldifeat/python/CMakeLists.txt new file mode 100644 index 0000000..86735ca --- /dev/null +++ b/kaldifeat/python/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(csrc) diff --git a/kaldifeat/python/csrc/CMakeLists.txt b/kaldifeat/python/csrc/CMakeLists.txt new file mode 100644 index 0000000..8c27ab7 --- /dev/null +++ b/kaldifeat/python/csrc/CMakeLists.txt @@ -0,0 +1,4 @@ +add_definitions(-DTORCH_API_INCLUDE_EXTENSION_H) +pybind11_add_module(_kaldifeat kaldifeat.cc) +target_link_libraries(_kaldifeat PRIVATE kaldifeat_core) +target_link_libraries(_kaldifeat PRIVATE ${TORCH_DIR}/lib/libtorch_python.so) diff --git a/kaldifeat/python/csrc/kaldifeat.cc b/kaldifeat/python/csrc/kaldifeat.cc new file mode 100644 index 0000000..2ebc982 --- /dev/null +++ b/kaldifeat/python/csrc/kaldifeat.cc @@ -0,0 +1,28 @@ +// kaldifeat/python/csrc/kaldifeat.cc +// +// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +#include "kaldifeat/python/csrc/kaldifeat.h" + +#include "kaldifeat/csrc/feature-fbank.h" +#include "torch/torch.h" + +namespace kaldifeat { + +PYBIND11_MODULE(_kaldifeat, m) { + m.doc() = "Python wrapper for kaldifeat"; + + m.def("test", [](const torch::Tensor &tensor) -> torch::Tensor { + std::cout << "size: " << tensor.sizes() << "\n"; + FbankOptions fbank_opts; + fbank_opts.frame_opts.dither = 0.0f; + + Fbank fbank(fbank_opts); + float vtln_warp = 1.0f; + + torch::Tensor ans = fbank.ComputeFeatures(tensor, vtln_warp); + return ans; + }); +} + +} // namespace kaldifeat diff --git a/kaldifeat/python/csrc/kaldifeat.h b/kaldifeat/python/csrc/kaldifeat.h new file mode 100644 index 0000000..11865e5 --- /dev/null +++ b/kaldifeat/python/csrc/kaldifeat.h @@ -0,0 +1,11 @@ +// kaldifeat/python/csrc/kaldifeat.h +// +// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +#ifndef KALDIFEAT_PYTHON_CSRC_KALDIFEAT_H_ +#define KALDIFEAT_PYTHON_CSRC_KALDIFEAT_H_ + +#include "pybind11/pybind11.h" +namespace py = pybind11; + +#endif // KALDIFEAT_PYTHON_CSRC_KALDIFEAT_H_ diff --git a/kaldifeat/python/tests/test.py b/kaldifeat/python/tests/test.py new file mode 100755 index 0000000..f8953d8 --- /dev/null +++ b/kaldifeat/python/tests/test.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +import sys + +sys.path.insert(0, '/root/fangjun/open-source/kaldifeat/build/lib') + +import torch +import numpy as np +import soundfile as sf +import _kaldifeat + + +def main(): + # sox -n -r 16000 -b 16 abc.wav synth 1 sine 100 + filename = '/root/fangjun/open-source/kaldi/src/featbin/abc.wav' + with sf.SoundFile(filename) as sf_desc: + sampling_rate = sf_desc.samplerate + assert sampling_rate == 16000 + a = sf_desc.read(dtype=np.float32, always_2d=False) + a *= 32768 + tensor = torch.from_numpy(a) + ans = _kaldifeat.test(tensor) + torch.set_printoptions(profile="full") + print(ans.shape) + print(ans) + + +if __name__ == '__main__': + main()