First working version.

It produces the same output as Kaldi's `compute-fbank-feats`
when run with default parameters and `--dither=0`.
This commit is contained in:
Fangjun Kuang 2021-02-26 23:23:23 +08:00
parent 9bd6ee0c5f
commit e930dc176f
11 changed files with 103 additions and 19 deletions

View File

@ -1 +1,2 @@
add_subdirectory(csrc)
add_subdirectory(python)

View File

@ -18,10 +18,10 @@ torch::Tensor OfflineFeatureTpl<F>::ComputeFeatures(const torch::Tensor &wave,
int32_t rows_out = NumFrames(wave.sizes()[0], computer_.GetFrameOptions());
int32_t cols_out = computer_.Dim();
const FrameExtractionOptions &frame_opts =
computer_.GetFrameOptions().frame_opts;
const FrameExtractionOptions &frame_opts = computer_.GetFrameOptions();
torch::Tensor strided_input = GetStrided(wave, frame_opts);
// TODO(fangjun): avoid clone
torch::Tensor strided_input = GetStrided(wave, frame_opts).clone();
if (frame_opts.dither != 0)
strided_input = Dither(strided_input, frame_opts.dither);
@ -62,24 +62,21 @@ torch::Tensor OfflineFeatureTpl<F>::ComputeFeatures(const torch::Tensor &wave,
torch::indexing::None)}) =
right - frame_opts.preemph_coeff * current;
strided_input.index({"...", 0}) *= frame_opts.preemph_coeff;
strided_input.index({"...", 0}) *= 1 - frame_opts.preemph_coeff;
}
strided_input = feature_window_function_.Apply(strided_input);
#if 0
Vector<BaseFloat> window; // windowed waveform.
bool use_raw_log_energy = computer_.NeedRawLogEnergy();
for (int32 r = 0; r < rows_out; r++) { // r is frame index.
BaseFloat raw_log_energy = 0.0;
ExtractWindow(0, wave, r, computer_.GetFrameOptions(),
feature_window_function_, &window,
(use_raw_log_energy ? &raw_log_energy : NULL));
int32_t padding = frame_opts.PaddedWindowSize() - strided_input.sizes()[1];
SubVector<BaseFloat> output_row(*output, r);
computer_.Compute(raw_log_energy, vtln_warp, &window, &output_row);
if (padding > 0) {
strided_input = torch::nn::functional::pad(
strided_input, torch::nn::functional::PadFuncOptions({0, padding})
.mode(torch::kConstant)
.value(0));
}
#endif
return computer_.Compute(log_energy_pre_window, vtln_warp, strided_input);
}
} // namespace kaldifeat

View File

@ -79,6 +79,7 @@ torch::Tensor FbankComputer::Compute(torch::Tensor signal_raw_log_energy,
SubVector<float> mel_energies(*feature, mel_offset, opts_.mel_opts.num_bins);
#endif
// TODO(fangjun): remove the last column of spectrum
torch::Tensor mel_energies = mel_banks.Compute(spectrum);
if (opts_.use_log_fbank) {
// Avoid log of zero (which should be prevented anyway by dithering).

View File

@ -131,7 +131,9 @@ MelBanks::MelBanks(const MelBanksOptions &opts,
<< " and vtln-high " << vtln_high << ", versus "
<< "low-freq " << low_freq << " and high-freq " << high_freq;
bins_mat_ = torch::zeros({num_bins, num_fft_bins}, torch::kFloat);
// TODO(fangjun): remove the last column of the power spectrum
// and set the number of columns to num_fft_bins instead of num_fft_bins + 1
bins_mat_ = torch::zeros({num_bins, num_fft_bins + 1}, torch::kFloat);
int32_t stride = bins_mat_.strides()[0];
for (int32_t bin = 0; bin < num_bins; ++bin) {

View File

@ -63,6 +63,9 @@ class MelBanks {
torch::Tensor Compute(const torch::Tensor &spectrum) const;
// for debug only
const torch::Tensor &GetBinsMat() const { return bins_mat_; }
private:
// A 2-D matrix of shape [num_bins, num_fft_bins + 1]
// (the extra column goes away once the constructor's TODO is resolved)
torch::Tensor bins_mat_;

View File

@ -26,12 +26,17 @@ static void TestPreemph() {
std::cout << d << "\n";
}
int main() {
static void TestPad() {
torch::Tensor a = torch::arange(0, 6).reshape({2, 3}).to(torch::kFloat);
torch::Tensor b = torch::arange(1, 4).to(torch::kFloat).unsqueeze(0);
torch::Tensor b = torch::nn::functional::pad(
a, torch::nn::functional::PadFuncOptions({0, 3})
.mode(torch::kConstant)
.value(0));
std::cout << a << "\n";
std::cout << b << "\n";
std::cout << a * b << "\n";
}
int main() {
TestPad();
return 0;
}

View File

@ -0,0 +1 @@
add_subdirectory(csrc)

View File

@ -0,0 +1,4 @@
# Build script for the `_kaldifeat` Python extension module.

# Define TORCH_API_INCLUDE_EXTENSION_H for every source compiled in this directory.
# NOTE(review): presumably this makes the torch headers pull in their Python
# extension API — confirm against the installed torch version.
add_definitions(-DTORCH_API_INCLUDE_EXTENSION_H)
# Create the pybind11 module target `_kaldifeat` from kaldifeat.cc.
pybind11_add_module(_kaldifeat kaldifeat.cc)
# Link the module against the kaldifeat_core target.
target_link_libraries(_kaldifeat PRIVATE kaldifeat_core)
# Link against libtorch_python.so found under ${TORCH_DIR}/lib.
# NOTE(review): the hard-coded `.so` file name looks Linux-specific — confirm
# that other platforms are out of scope.
target_link_libraries(_kaldifeat PRIVATE ${TORCH_DIR}/lib/libtorch_python.so)

View File

@ -0,0 +1,28 @@
// kaldifeat/python/csrc/kaldifeat.cc
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#include "kaldifeat/python/csrc/kaldifeat.h"
#include "kaldifeat/csrc/feature-fbank.h"
#include "torch/torch.h"
namespace kaldifeat {
// Entry point of the `_kaldifeat` Python extension module.
PYBIND11_MODULE(_kaldifeat, m) {
  m.doc() = "Python wrapper for kaldifeat";

  // Exposed to Python as `test`: prints the size of the input tensor and
  // returns the result of Fbank::ComputeFeatures on it, using default
  // FbankOptions with dithering disabled and a VTLN warp factor of 1.0.
  auto compute_fbank = [](const torch::Tensor &tensor) -> torch::Tensor {
    std::cout << "size: " << tensor.sizes() << "\n";

    FbankOptions opts;
    opts.frame_opts.dither = 0.0f;
    Fbank extractor(opts);

    const float vtln_warp = 1.0f;
    return extractor.ComputeFeatures(tensor, vtln_warp);
  };

  m.def("test", compute_fbank);
}
} // namespace kaldifeat

View File

@ -0,0 +1,11 @@
// kaldifeat/python/csrc/kaldifeat.h
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#ifndef KALDIFEAT_PYTHON_CSRC_KALDIFEAT_H_
#define KALDIFEAT_PYTHON_CSRC_KALDIFEAT_H_
#include "pybind11/pybind11.h"
namespace py = pybind11;
#endif // KALDIFEAT_PYTHON_CSRC_KALDIFEAT_H_

31
kaldifeat/python/tests/test.py Executable file
View File

@ -0,0 +1,31 @@
#!/usr/bin/env python3
#
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
import sys
sys.path.insert(0, '/root/fangjun/open-source/kaldifeat/build/lib')
import torch
import numpy as np
import soundfile as sf
import _kaldifeat
def main():
    """Run `_kaldifeat.test` on a fixed wave file and print the result.

    The wave file can be generated with:

        sox -n -r 16000 -b 16 abc.wav synth 1 sine 100
    """
    wave_path = '/root/fangjun/open-source/kaldi/src/featbin/abc.wav'
    with sf.SoundFile(wave_path) as reader:
        # Only 16 kHz input is accepted here.
        assert reader.samplerate == 16000
        samples = reader.read(dtype=np.float32, always_2d=False)

    # Scale by 2**15; presumably this maps the float samples onto the 16-bit
    # integer range that Kaldi works with — TODO confirm.
    samples *= 32768

    features = _kaldifeat.test(torch.from_numpy(samples))

    # Print every element instead of an abbreviated summary.
    torch.set_printoptions(profile="full")
    print(features.shape)
    print(features)


if __name__ == '__main__':
    main()