mirror of
https://github.com/csukuangfj/kaldifeat.git
synced 2025-08-13 12:02:19 +00:00
First working version.
It produces the same output as kaldi's `compute-fbank-feats` using default parameters with `--dither=0`.
This commit is contained in:
parent
9bd6ee0c5f
commit
e930dc176f
@ -1 +1,2 @@
|
|||||||
add_subdirectory(csrc)
|
add_subdirectory(csrc)
|
||||||
|
add_subdirectory(python)
|
||||||
|
@ -18,10 +18,10 @@ torch::Tensor OfflineFeatureTpl<F>::ComputeFeatures(const torch::Tensor &wave,
|
|||||||
int32_t rows_out = NumFrames(wave.sizes()[0], computer_.GetFrameOptions());
|
int32_t rows_out = NumFrames(wave.sizes()[0], computer_.GetFrameOptions());
|
||||||
int32_t cols_out = computer_.Dim();
|
int32_t cols_out = computer_.Dim();
|
||||||
|
|
||||||
const FrameExtractionOptions &frame_opts =
|
const FrameExtractionOptions &frame_opts = computer_.GetFrameOptions();
|
||||||
computer_.GetFrameOptions().frame_opts;
|
|
||||||
|
|
||||||
torch::Tensor strided_input = GetStrided(wave, frame_opts);
|
// TODO(fangjun): avoid clone
|
||||||
|
torch::Tensor strided_input = GetStrided(wave, frame_opts).clone();
|
||||||
|
|
||||||
if (frame_opts.dither != 0)
|
if (frame_opts.dither != 0)
|
||||||
strided_input = Dither(strided_input, frame_opts.dither);
|
strided_input = Dither(strided_input, frame_opts.dither);
|
||||||
@ -62,24 +62,21 @@ torch::Tensor OfflineFeatureTpl<F>::ComputeFeatures(const torch::Tensor &wave,
|
|||||||
torch::indexing::None)}) =
|
torch::indexing::None)}) =
|
||||||
right - frame_opts.preemph_coeff * current;
|
right - frame_opts.preemph_coeff * current;
|
||||||
|
|
||||||
strided_input.index({"...", 0}) *= frame_opts.preemph_coeff;
|
strided_input.index({"...", 0}) *= 1 - frame_opts.preemph_coeff;
|
||||||
}
|
}
|
||||||
|
|
||||||
strided_input = feature_window_function_.Apply(strided_input);
|
strided_input = feature_window_function_.Apply(strided_input);
|
||||||
|
|
||||||
#if 0
|
int32_t padding = frame_opts.PaddedWindowSize() - strided_input.sizes()[1];
|
||||||
Vector<BaseFloat> window; // windowed waveform.
|
|
||||||
bool use_raw_log_energy = computer_.NeedRawLogEnergy();
|
|
||||||
for (int32 r = 0; r < rows_out; r++) { // r is frame index.
|
|
||||||
BaseFloat raw_log_energy = 0.0;
|
|
||||||
ExtractWindow(0, wave, r, computer_.GetFrameOptions(),
|
|
||||||
feature_window_function_, &window,
|
|
||||||
(use_raw_log_energy ? &raw_log_energy : NULL));
|
|
||||||
|
|
||||||
SubVector<BaseFloat> output_row(*output, r);
|
if (padding > 0) {
|
||||||
computer_.Compute(raw_log_energy, vtln_warp, &window, &output_row);
|
strided_input = torch::nn::functional::pad(
|
||||||
|
strided_input, torch::nn::functional::PadFuncOptions({0, padding})
|
||||||
|
.mode(torch::kConstant)
|
||||||
|
.value(0));
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
|
return computer_.Compute(log_energy_pre_window, vtln_warp, strided_input);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace kaldifeat
|
} // namespace kaldifeat
|
||||||
|
@ -79,6 +79,7 @@ torch::Tensor FbankComputer::Compute(torch::Tensor signal_raw_log_energy,
|
|||||||
SubVector<float> mel_energies(*feature, mel_offset, opts_.mel_opts.num_bins);
|
SubVector<float> mel_energies(*feature, mel_offset, opts_.mel_opts.num_bins);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
// TODO(fangjun): remove the last column of spectrum
|
||||||
torch::Tensor mel_energies = mel_banks.Compute(spectrum);
|
torch::Tensor mel_energies = mel_banks.Compute(spectrum);
|
||||||
if (opts_.use_log_fbank) {
|
if (opts_.use_log_fbank) {
|
||||||
// Avoid log of zero (which should be prevented anyway by dithering).
|
// Avoid log of zero (which should be prevented anyway by dithering).
|
||||||
|
@ -131,7 +131,9 @@ MelBanks::MelBanks(const MelBanksOptions &opts,
|
|||||||
<< " and vtln-high " << vtln_high << ", versus "
|
<< " and vtln-high " << vtln_high << ", versus "
|
||||||
<< "low-freq " << low_freq << " and high-freq " << high_freq;
|
<< "low-freq " << low_freq << " and high-freq " << high_freq;
|
||||||
|
|
||||||
bins_mat_ = torch::zeros({num_bins, num_fft_bins}, torch::kFloat);
|
// TODO(fangjun): remove the last column of the power spectrum
|
||||||
|
// and set the number of columns to num_fft_bins instead of num_fft_bins + 1
|
||||||
|
bins_mat_ = torch::zeros({num_bins, num_fft_bins + 1}, torch::kFloat);
|
||||||
int32_t stride = bins_mat_.strides()[0];
|
int32_t stride = bins_mat_.strides()[0];
|
||||||
|
|
||||||
for (int32_t bin = 0; bin < num_bins; ++bin) {
|
for (int32_t bin = 0; bin < num_bins; ++bin) {
|
||||||
|
@ -63,6 +63,9 @@ class MelBanks {
|
|||||||
|
|
||||||
torch::Tensor Compute(const torch::Tensor &spectrum) const;
|
torch::Tensor Compute(const torch::Tensor &spectrum) const;
|
||||||
|
|
||||||
|
// for debug only
|
||||||
|
const torch::Tensor &GetBinsMat() const { return bins_mat_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// A 2-D matrix of shape [num_bins, num_fft_bins]
|
// A 2-D matrix of shape [num_bins, num_fft_bins]
|
||||||
torch::Tensor bins_mat_;
|
torch::Tensor bins_mat_;
|
||||||
|
@ -26,12 +26,17 @@ static void TestPreemph() {
|
|||||||
std::cout << d << "\n";
|
std::cout << d << "\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
int main() {
|
static void TestPad() {
|
||||||
torch::Tensor a = torch::arange(0, 6).reshape({2, 3}).to(torch::kFloat);
|
torch::Tensor a = torch::arange(0, 6).reshape({2, 3}).to(torch::kFloat);
|
||||||
torch::Tensor b = torch::arange(1, 4).to(torch::kFloat).unsqueeze(0);
|
torch::Tensor b = torch::nn::functional::pad(
|
||||||
|
a, torch::nn::functional::PadFuncOptions({0, 3})
|
||||||
|
.mode(torch::kConstant)
|
||||||
|
.value(0));
|
||||||
std::cout << a << "\n";
|
std::cout << a << "\n";
|
||||||
std::cout << b << "\n";
|
std::cout << b << "\n";
|
||||||
std::cout << a * b << "\n";
|
}
|
||||||
|
|
||||||
|
int main() {
|
||||||
|
TestPad();
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
1
kaldifeat/python/CMakeLists.txt
Normal file
1
kaldifeat/python/CMakeLists.txt
Normal file
@ -0,0 +1 @@
|
|||||||
|
add_subdirectory(csrc)
|
4
kaldifeat/python/csrc/CMakeLists.txt
Normal file
4
kaldifeat/python/csrc/CMakeLists.txt
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
add_definitions(-DTORCH_API_INCLUDE_EXTENSION_H)
|
||||||
|
pybind11_add_module(_kaldifeat kaldifeat.cc)
|
||||||
|
target_link_libraries(_kaldifeat PRIVATE kaldifeat_core)
|
||||||
|
target_link_libraries(_kaldifeat PRIVATE ${TORCH_DIR}/lib/libtorch_python.so)
|
28
kaldifeat/python/csrc/kaldifeat.cc
Normal file
28
kaldifeat/python/csrc/kaldifeat.cc
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
// kaldifeat/python/csrc/kaldifeat.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
#include "kaldifeat/python/csrc/kaldifeat.h"
|
||||||
|
|
||||||
|
#include "kaldifeat/csrc/feature-fbank.h"
|
||||||
|
#include "torch/torch.h"
|
||||||
|
|
||||||
|
namespace kaldifeat {
|
||||||
|
|
||||||
|
PYBIND11_MODULE(_kaldifeat, m) {
|
||||||
|
m.doc() = "Python wrapper for kaldifeat";
|
||||||
|
|
||||||
|
m.def("test", [](const torch::Tensor &tensor) -> torch::Tensor {
|
||||||
|
std::cout << "size: " << tensor.sizes() << "\n";
|
||||||
|
FbankOptions fbank_opts;
|
||||||
|
fbank_opts.frame_opts.dither = 0.0f;
|
||||||
|
|
||||||
|
Fbank fbank(fbank_opts);
|
||||||
|
float vtln_warp = 1.0f;
|
||||||
|
|
||||||
|
torch::Tensor ans = fbank.ComputeFeatures(tensor, vtln_warp);
|
||||||
|
return ans;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace kaldifeat
|
11
kaldifeat/python/csrc/kaldifeat.h
Normal file
11
kaldifeat/python/csrc/kaldifeat.h
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
// kaldifeat/python/csrc/kaldifeat.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
#ifndef KALDIFEAT_PYTHON_CSRC_KALDIFEAT_H_
|
||||||
|
#define KALDIFEAT_PYTHON_CSRC_KALDIFEAT_H_
|
||||||
|
|
||||||
|
#include "pybind11/pybind11.h"
|
||||||
|
namespace py = pybind11;
|
||||||
|
|
||||||
|
#endif // KALDIFEAT_PYTHON_CSRC_KALDIFEAT_H_
|
31
kaldifeat/python/tests/test.py
Executable file
31
kaldifeat/python/tests/test.py
Executable file
@ -0,0 +1,31 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
#
|
||||||
|
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.path.insert(0, '/root/fangjun/open-source/kaldifeat/build/lib')
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import soundfile as sf
|
||||||
|
import _kaldifeat
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# sox -n -r 16000 -b 16 abc.wav synth 1 sine 100
|
||||||
|
filename = '/root/fangjun/open-source/kaldi/src/featbin/abc.wav'
|
||||||
|
with sf.SoundFile(filename) as sf_desc:
|
||||||
|
sampling_rate = sf_desc.samplerate
|
||||||
|
assert sampling_rate == 16000
|
||||||
|
a = sf_desc.read(dtype=np.float32, always_2d=False)
|
||||||
|
a *= 32768
|
||||||
|
tensor = torch.from_numpy(a)
|
||||||
|
ans = _kaldifeat.test(tensor)
|
||||||
|
torch.set_printoptions(profile="full")
|
||||||
|
print(ans.shape)
|
||||||
|
print(ans)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
Loading…
x
Reference in New Issue
Block a user