wrap FbankOptions to Python.

This commit is contained in:
Fangjun Kuang 2021-02-27 22:52:39 +08:00
parent 53504a705c
commit 9a5567e21b
19 changed files with 505 additions and 93 deletions

View File

@ -15,15 +15,14 @@ template <class F>
torch::Tensor OfflineFeatureTpl<F>::ComputeFeatures(const torch::Tensor &wave,
float vtln_warp) {
KALDIFEAT_ASSERT(wave.dim() == 1);
int32_t rows_out = NumFrames(wave.sizes()[0], computer_.GetFrameOptions());
int32_t cols_out = computer_.Dim();
const FrameExtractionOptions &frame_opts = computer_.GetFrameOptions();
torch::Tensor strided_input = GetStrided(wave, frame_opts);
if (frame_opts.dither != 0)
if (frame_opts.dither != 0.0f) {
strided_input = Dither(strided_input, frame_opts.dither);
}
if (frame_opts.remove_dc_offset) {
torch::Tensor row_means = strided_input.mean(1).unsqueeze(1);
@ -37,12 +36,14 @@ torch::Tensor OfflineFeatureTpl<F>::ComputeFeatures(const torch::Tensor &wave,
constexpr float kEps = 1.1920928955078125e-07f;
if (use_raw_log_energy) {
// it is true iff use_energy==true and raw_energy==true
log_energy_pre_window =
torch::clamp_min(strided_input.pow(2).sum(1), kEps).log();
}
if (frame_opts.preemph_coeff != 0.0f)
if (frame_opts.preemph_coeff != 0.0f) {
Preemphasize(frame_opts.preemph_coeff, &strided_input);
}
feature_window_function_.Apply(&strided_input);

View File

@ -13,6 +13,11 @@
namespace kaldifeat {
// Writes the text produced by FbankOptions::ToString() to `os` and hands the
// same stream back so that further insertions can be chained.
std::ostream &operator<<(std::ostream &os, const FbankOptions &opts) {
  return os << opts.ToString();
}
FbankComputer::FbankComputer(const FbankOptions &opts) : opts_(opts) {
if (opts.energy_floor > 0.0f) log_energy_floor_ = logf(opts.energy_floor);
@ -78,11 +83,6 @@ torch::Tensor FbankComputer::Compute(torch::Tensor signal_raw_log_energy,
// Use power instead of magnitude if requested.
if (opts_.use_power) spectrum.pow_(2);
#if 0
int32_t mel_offset = ((opts_.use_energy && !opts_.htk_compat) ? 1 : 0);
SubVector<float> mel_energies(*feature, mel_offset, opts_.mel_opts.num_bins);
#endif
// TODO(fangjun): remove the last column of spectrum
torch::Tensor mel_energies = mel_banks.Compute(spectrum);
if (opts_.use_log_fbank) {
@ -90,17 +90,26 @@ torch::Tensor FbankComputer::Compute(torch::Tensor signal_raw_log_energy,
mel_energies = torch::clamp_min(mel_energies, kEps).log();
}
// if use_energy is true, then we get an extra bin. That is,
// if num_mel_bins is 23, the feature will contain 24 bins.
//
// if htk_compat is false, then the 0th bin is the log energy
// if htk_compat is true, then the last bin is the log energy
// Copy energy as first value (or the last, if htk_compat == true).
if (opts_.use_energy) {
#if 0
if (opts_.energy_floor > 0.0 && signal_raw_log_energy < log_energy_floor_) {
signal_raw_log_energy = log_energy_floor_;
if (opts_.energy_floor > 0.0f) {
signal_raw_log_energy =
torch::clamp_min(signal_raw_log_energy, log_energy_floor_);
}
#endif
int32_t energy_index = opts_.htk_compat ? opts_.mel_opts.num_bins : 0;
energy_index = 0; // TODO(fangjun): fix it
mel_energies.index({"...", energy_index}) = signal_raw_log_energy;
signal_raw_log_energy.unsqueeze_(1);
if (opts_.htk_compat) {
mel_energies = torch::cat({mel_energies, signal_raw_log_energy}, 1);
} else {
mel_energies = torch::cat({signal_raw_log_energy, mel_energies}, 1);
}
}
return mel_energies;

View File

@ -20,12 +20,16 @@ struct FbankOptions {
MelBanksOptions mel_opts;
// append an extra dimension with energy to the filter banks
bool use_energy = false;
float energy_floor = 0.0f;
float energy_floor = 0.0f; // active iff use_energy==true
// If true, compute log_energy before preemphasis and windowing
// If false, compute log_energy after preemphasis and windowing
bool raw_energy = true; // active iff use_energy==true
// If true, compute energy before preemphasis and windowing
bool raw_energy = true;
// If true, put energy last (if using energy)
bool htk_compat = false;
// If false, put energy first
bool htk_compat = false; // active iff use_energy==true
// if true (default), produce log-filterbank, else linear
bool use_log_fbank = true;
@ -34,8 +38,28 @@ struct FbankOptions {
bool use_power = true;
FbankOptions() { mel_opts.num_bins = 23; }
// Builds a multi-line, human-readable dump of these options.
//
// The nested frame_opts and mel_opts members are streamed first, each under
// its own heading, followed by one "name: value" line per scalar field.
std::string ToString() const {
  std::ostringstream ss;
  ss << "frame_opts: \n"
     << frame_opts << "\n"
     << "\n"
     << "mel_opts: \n"
     << mel_opts << "\n"
     << "use_energy: " << use_energy << "\n"
     << "energy_floor: " << energy_floor << "\n"
     << "raw_energy: " << raw_energy << "\n"
     << "htk_compat: " << htk_compat << "\n"
     << "use_log_fbank: " << use_log_fbank << "\n"
     << "use_power: " << use_power << "\n";
  return ss.str();
}
};
std::ostream &operator<<(std::ostream &os, const FbankOptions &opts);
class FbankComputer {
public:
using Options = FbankOptions;
@ -51,12 +75,15 @@ class FbankComputer {
return opts_.mel_opts.num_bins + (opts_.use_energy ? 1 : 0);
}
// if true, compute log_energy_pre_window but after dithering and dc removal
bool NeedRawLogEnergy() const { return opts_.use_energy && opts_.raw_energy; }
const FrameExtractionOptions &GetFrameOptions() const {
return opts_.frame_opts;
}
// signal_raw_log_energy is log_energy_pre_window, which is not empty
// iff NeedRawLogEnergy() returns true.
torch::Tensor Compute(torch::Tensor signal_raw_log_energy, float vtln_warp,
const torch::Tensor &signal_frame);

View File

@ -16,6 +16,11 @@
namespace kaldifeat {
// Writes the text produced by FrameExtractionOptions::ToString() to `os` and
// hands the same stream back so that further insertions can be chained.
std::ostream &operator<<(std::ostream &os, const FrameExtractionOptions &opts) {
  return os << opts.ToString();
}
FeatureWindowFunction::FeatureWindowFunction(
const FrameExtractionOptions &opts) {
int32_t frame_length = opts.WindowSize();

View File

@ -39,9 +39,9 @@ struct FrameExtractionOptions {
bool round_to_power_of_two = true;
float blackman_coeff = 0.42f;
bool snip_edges = true;
bool allow_downsample = false;
bool allow_upsample = false;
int32_t max_feature_vectors = -1;
// bool allow_downsample = false;
// bool allow_upsample = false;
// int32_t max_feature_vectors = -1;
int32_t WindowShift() const {
return static_cast<int32_t>(samp_freq * 0.001f * frame_shift_ms);
@ -53,8 +53,29 @@ struct FrameExtractionOptions {
return (round_to_power_of_two ? RoundUpToNearestPowerOfTwo(WindowSize())
: WindowSize());
}
// Builds a human-readable dump of these options, one "name: value" line per
// field, in declaration order.
std::string ToString() const {
  std::ostringstream ss;
  ss << "samp_freq: " << samp_freq << "\n"
     << "frame_shift_ms: " << frame_shift_ms << "\n"
     << "frame_length_ms: " << frame_length_ms << "\n"
     << "dither: " << dither << "\n"
     << "preemph_coeff: " << preemph_coeff << "\n"
     << "remove_dc_offset: " << remove_dc_offset << "\n"
     << "window_type: " << window_type << "\n"
     << "round_to_power_of_two: " << round_to_power_of_two << "\n"
     << "blackman_coeff: " << blackman_coeff << "\n"
     << "snip_edges: " << snip_edges << "\n";
  // allow_downsample, allow_upsample and max_feature_vectors are commented
  // out in this struct, so they are intentionally not printed.
  return ss.str();
}
};
std::ostream &operator<<(std::ostream &os, const FrameExtractionOptions &opts);
class FeatureWindowFunction {
public:
FeatureWindowFunction() = default;

View File

@ -10,6 +10,11 @@
namespace kaldifeat {
// Writes the text produced by MelBanksOptions::ToString() to `os` and hands
// the same stream back so that further insertions can be chained.
std::ostream &operator<<(std::ostream &os, const MelBanksOptions &opts) {
  return os << opts.ToString();
}
float MelBanks::VtlnWarpFreq(
float vtln_low_cutoff, // upper+lower frequency cutoffs for VTLN.
float vtln_high_cutoff,

View File

@ -32,8 +32,22 @@ struct MelBanksOptions {
// Enables more exact compatibility with HTK, for testing purposes. Affects
// mel-energy flooring and reproduces a bug in HTK.
bool htk_mode = false;
// Builds a human-readable dump of these options, one "name: value" line per
// field.
std::string ToString() const {
  std::ostringstream ss;
  ss << "num_bins: " << num_bins << "\n"
     << "low_freq: " << low_freq << "\n"
     << "high_freq: " << high_freq << "\n"
     << "vtln_low: " << vtln_low << "\n"
     << "vtln_high: " << vtln_high << "\n"
     << "debug_mel: " << debug_mel << "\n"
     << "htk_mode: " << htk_mode << "\n";
  return ss.str();
}
};
std::ostream &operator<<(std::ostream &os, const MelBanksOptions &opts);
class MelBanks {
public:
static inline float InverseMelScale(float mel_freq) {

View File

@ -62,8 +62,18 @@ static void TestDither() {
std::cout << (a + b * 2) << "\n";
}
static void TestCat() {
torch::Tensor a = torch::arange(0, 6).reshape({2, 3}).to(torch::kFloat);
torch::Tensor b = torch::arange(0, 2).reshape({2, 1}).to(torch::kFloat) * 0.1;
torch::Tensor c = torch::cat({a, b}, 1);
torch::Tensor d = torch::cat({b, a}, 1);
std::cout << a << "\n";
std::cout << b << "\n";
std::cout << c << "\n";
std::cout << d << "\n";
}
// Entry point of the test program: runs the enabled demo routines in order
// and exits with status 0.
int main() {
// TestDither();  // currently disabled
TestGetStrided();
TestCat();
return 0;
}

View File

@ -1,4 +1,9 @@
add_definitions(-DTORCH_API_INCLUDE_EXTENSION_H)
pybind11_add_module(_kaldifeat kaldifeat.cc)
pybind11_add_module(_kaldifeat
feature-fbank.cc
feature-window.cc
kaldifeat.cc
mel-computations.cc
)
target_link_libraries(_kaldifeat PRIVATE kaldifeat_core)
target_link_libraries(_kaldifeat PRIVATE ${TORCH_DIR}/lib/libtorch_python.so)

View File

@ -0,0 +1,27 @@
// kaldifeat/python/csrc/feature-fbank.cc
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#include "kaldifeat/python/csrc/feature-fbank.h"
#include "kaldifeat/csrc/feature-fbank.h"
namespace kaldifeat {
// Registers FbankOptions in the Python module `m` under the name
// "FbankOptions": a default constructor, read/write access to every option
// field, and __str__ backed by FbankOptions::ToString().
void PybindFbankOptions(py::module &m) {
  using Opts = FbankOptions;

  py::class_<Opts> binder(m, "FbankOptions");
  binder.def(py::init<>());

  // Nested option groups.
  binder.def_readwrite("frame_opts", &Opts::frame_opts);
  binder.def_readwrite("mel_opts", &Opts::mel_opts);

  // Scalar options.
  binder.def_readwrite("use_energy", &Opts::use_energy);
  binder.def_readwrite("energy_floor", &Opts::energy_floor);
  binder.def_readwrite("raw_energy", &Opts::raw_energy);
  binder.def_readwrite("htk_compat", &Opts::htk_compat);
  binder.def_readwrite("use_log_fbank", &Opts::use_log_fbank);
  binder.def_readwrite("use_power", &Opts::use_power);

  binder.def("__str__",
             [](const Opts &self) -> std::string { return self.ToString(); });
}
} // namespace kaldifeat

View File

@ -0,0 +1,16 @@
// kaldifeat/python/csrc/feature-fbank.h
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#ifndef KALDIFEAT_PYTHON_CSRC_FEATURE_FBANK_H_
#define KALDIFEAT_PYTHON_CSRC_FEATURE_FBANK_H_
#include "kaldifeat/python/csrc/kaldifeat.h"
namespace kaldifeat {
void PybindFbankOptions(py::module &m);
} // namespace kaldifeat
#endif // KALDIFEAT_PYTHON_CSRC_FEATURE_FBANK_H_

View File

@ -0,0 +1,39 @@
// kaldifeat/python/csrc/feature-window.cc
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#include "kaldifeat/python/csrc/feature-window.h"
#include "kaldifeat/csrc/feature-window.h"
namespace kaldifeat {
// Registers FrameExtractionOptions in the Python module `m` under the name
// "FrameExtractionOptions": a default constructor, read/write access to the
// active option fields, and __str__ backed by
// FrameExtractionOptions::ToString().
void PybindFrameExtractionOptions(py::module &m) {
  using Opts = FrameExtractionOptions;

  py::class_<Opts> binder(m, "FrameExtractionOptions");
  binder.def(py::init<>());

  binder.def_readwrite("samp_freq", &Opts::samp_freq);
  binder.def_readwrite("frame_shift_ms", &Opts::frame_shift_ms);
  binder.def_readwrite("frame_length_ms", &Opts::frame_length_ms);
  binder.def_readwrite("dither", &Opts::dither);
  binder.def_readwrite("preemph_coeff", &Opts::preemph_coeff);
  binder.def_readwrite("remove_dc_offset", &Opts::remove_dc_offset);
  binder.def_readwrite("window_type", &Opts::window_type);
  binder.def_readwrite("round_to_power_of_two", &Opts::round_to_power_of_two);
  binder.def_readwrite("blackman_coeff", &Opts::blackman_coeff);
  binder.def_readwrite("snip_edges", &Opts::snip_edges);
#if 0
  // Disabled: these bindings are compiled out.
  binder.def_readwrite("allow_downsample", &Opts::allow_downsample);
  binder.def_readwrite("allow_upsample", &Opts::allow_upsample);
  binder.def_readwrite("max_feature_vectors", &Opts::max_feature_vectors);
#endif

  binder.def("__str__",
             [](const Opts &self) -> std::string { return self.ToString(); });
}
} // namespace kaldifeat

View File

@ -0,0 +1,16 @@
// kaldifeat/python/csrc/feature-window.h
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#ifndef KALDIFEAT_PYTHON_CSRC_FEATURE_WINDOW_H_
#define KALDIFEAT_PYTHON_CSRC_FEATURE_WINDOW_H_
#include "kaldifeat/python/csrc/kaldifeat.h"
namespace kaldifeat {
void PybindFrameExtractionOptions(py::module &m);
} // namespace kaldifeat
#endif // KALDIFEAT_PYTHON_CSRC_FEATURE_WINDOW_H_

View File

@ -7,36 +7,55 @@
#include <chrono>
#include "kaldifeat/csrc/feature-fbank.h"
#include "kaldifeat/python/csrc/feature-fbank.h"
#include "kaldifeat/python/csrc/feature-window.h"
#include "kaldifeat/python/csrc/mel-computations.h"
#include "torch/torch.h"
namespace kaldifeat {
// Computes features for `wave` with a freshly constructed Fbank configured by
// `fbank_opts`. The VTLN warp factor is fixed to 1.0.
//
// TODO(fangjun): wrap Fbank to Python
static torch::Tensor Compute(const torch::Tensor &wave,
                             const FbankOptions &fbank_opts) {
  constexpr float kVtlnWarp = 1.0f;
  Fbank extractor(fbank_opts);
  return extractor.ComputeFeatures(wave, kVtlnWarp);
}
PYBIND11_MODULE(_kaldifeat, m) {
m.doc() = "Python wrapper for kaldifeat";
PybindFrameExtractionOptions(m);
PybindMelBanksOptions(m);
PybindFbankOptions(m);
m.def("compute", &Compute, py::arg("wave"), py::arg("fbank_opts"));
// It verifies that the reimplementation produces the same output
// as kaldi using default paremters with dither disabled.
m.def("test_default_parameters",
[](const torch::Tensor &tensor) -> std::pair<torch::Tensor, double> {
FbankOptions fbank_opts;
fbank_opts.frame_opts.dither = 0.0f;
// as kaldi using default parameters with dither disabled.
m.def(
"_compute_with_elapsed_time", // for benchmark only
[](const torch::Tensor &wave,
const FbankOptions &fbank_opts) -> std::pair<torch::Tensor, double> {
std::chrono::steady_clock::time_point begin =
std::chrono::steady_clock::now();
Fbank fbank(fbank_opts);
float vtln_warp = 1.0f;
torch::Tensor ans = Compute(wave, fbank_opts);
std::chrono::steady_clock::time_point begin =
std::chrono::steady_clock::now();
std::chrono::steady_clock::time_point end =
std::chrono::steady_clock::now();
torch::Tensor ans = fbank.ComputeFeatures(tensor, vtln_warp);
std::chrono::steady_clock::time_point end =
std::chrono::steady_clock::now();
double elapsed_seconds =
std::chrono::duration_cast<std::chrono::microseconds>(end - begin)
.count() /
1000000.;
double elapsed_seconds =
std::chrono::duration_cast<std::chrono::microseconds>(end - begin)
.count() /
1000000.;
return std::make_pair(ans, elapsed_seconds);
});
return std::make_pair(ans, elapsed_seconds);
},
py::arg("wave"), py::arg("fbank_opts"));
}
} // namespace kaldifeat

View File

@ -0,0 +1,27 @@
// kaldifeat/python/csrc/mel-computations.cc
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#include "kaldifeat/csrc/mel-computations.h"
#include "kaldifeat/python/csrc/feature-window.h"
namespace kaldifeat {
// Registers MelBanksOptions in the Python module `m` under the name
// "MelBanksOptions": a default constructor, read/write access to every option
// field, and __str__ backed by MelBanksOptions::ToString().
void PybindMelBanksOptions(py::module &m) {
  using Opts = MelBanksOptions;

  py::class_<Opts> binder(m, "MelBanksOptions");
  binder.def(py::init<>());

  binder.def_readwrite("num_bins", &Opts::num_bins);
  binder.def_readwrite("low_freq", &Opts::low_freq);
  binder.def_readwrite("high_freq", &Opts::high_freq);
  binder.def_readwrite("vtln_low", &Opts::vtln_low);
  binder.def_readwrite("vtln_high", &Opts::vtln_high);
  binder.def_readwrite("debug_mel", &Opts::debug_mel);
  binder.def_readwrite("htk_mode", &Opts::htk_mode);

  binder.def("__str__",
             [](const Opts &self) -> std::string { return self.ToString(); });
}
} // namespace kaldifeat

View File

@ -0,0 +1,16 @@
// kaldifeat/python/csrc/mel-computations.h
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#ifndef KALDIFEAT_PYTHON_CSRC_MEL_COMPUTATIONS_H_
#define KALDIFEAT_PYTHON_CSRC_MEL_COMPUTATIONS_H_
#include "kaldifeat/python/csrc/kaldifeat.h"
namespace kaldifeat {
void PybindMelBanksOptions(py::module &m);
} // namespace kaldifeat
#endif // KALDIFEAT_PYTHON_CSRC_MEL_COMPUTATIONS_H_

View File

@ -1,49 +0,0 @@
#!/usr/bin/env python3
#
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
from pathlib import Path
cur_dir = Path(__file__).resolve().parent
kaldi_feat_dir = cur_dir.parent.parent.parent
import sys
sys.path.insert(0, f'{kaldi_feat_dir}/build/lib')
import numpy as np
import soundfile as sf
import torch
import _kaldifeat
def read_ark_txt() -> torch.Tensor:
    '''Load the reference features stored in ``test_data/abc.txt``.

    Lines containing ``[`` are skipped, a standalone ``]`` token is
    dropped, and every remaining whitespace-separated token is parsed as
    a float.

    Returns:
      A tensor with one row per line that contained at least one number.
    '''
    filename = cur_dir / 'test_data' / 'abc.txt'
    features = []
    with open(filename) as f:
        for line in f:
            if '[' in line:
                continue
            # split() with no argument already discards surrounding
            # whitespace, so no explicit strip is needed.
            data = [float(d) for d in line.split() if d != ']']
            if data:
                # Lines without numbers (e.g. blank lines) would otherwise
                # produce a ragged list that torch.tensor() rejects.
                features.append(data)
    return torch.tensor(features)
def main():
    '''Compare features computed for ``test_data/abc.wav`` with the
    reference in ``test_data/abc.txt`` and print the elapsed time.'''
    wave_path = cur_dir / 'test_data' / 'abc.wav'
    with sf.SoundFile(wave_path) as sf_desc:
        sampling_rate = sf_desc.samplerate
        assert sampling_rate == 16000
        samples = sf_desc.read(dtype=np.float32, always_2d=False)
        # NOTE(review): presumably rescales samples to the 16-bit integer
        # range the reference features were computed on -- confirm.
        samples *= 32768

    wave = torch.from_numpy(samples)
    ans, elapsed_seconds = _kaldifeat.test_default_parameters(wave)

    expected = read_ark_txt()
    assert torch.allclose(ans, expected, rtol=1e-3)
    print('elapsed seconds:', elapsed_seconds)


if __name__ == '__main__':
    main()

View File

@ -0,0 +1,121 @@
#!/usr/bin/env python3
#
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
from pathlib import Path
cur_dir = Path(__file__).resolve().parent
kaldi_feat_dir = cur_dir.parent.parent.parent
import sys
sys.path.insert(0, f'{kaldi_feat_dir}/build/lib')
import numpy as np
import soundfile as sf
import torch
import _kaldifeat
def read_ark_txt() -> torch.Tensor:
    '''Load the reference features stored in ``test_data/abc.txt``.

    Lines containing ``[`` are skipped, a standalone ``]`` token is
    dropped, and every remaining whitespace-separated token is parsed as
    a float.

    Returns:
      A tensor with one row per line that contained at least one number.
    '''
    filename = cur_dir / 'test_data' / 'abc.txt'
    features = []
    with open(filename) as f:
        for line in f:
            if '[' in line:
                continue
            # split() with no argument already discards surrounding
            # whitespace, so no explicit strip is needed.
            data = [float(d) for d in line.split() if d != ']']
            if data:
                # Lines without numbers (e.g. blank lines) would otherwise
                # produce a ragged list that torch.tensor() rejects.
                features.append(data)
    return torch.tensor(features)
def parse_str(s) -> torch.Tensor:
    '''Parse a block of text into a 2-D tensor.

    Args:
      s:
        It consists of several lines. Each line contains several numbers
        separated by spaces. Blank lines (leading, trailing or in between)
        are ignored.
    Returns:
      A tensor with one row per non-blank line of ``s``. If ``s`` contains
      no numbers at all, an empty 1-D tensor is returned.
    '''
    rows = []
    for line in s.split('\n'):
        fields = line.split()
        if not fields:
            # A blank line would add an empty row and make the list ragged,
            # which torch.tensor() rejects.
            continue
        rows.append([float(d) for d in fields])
    return torch.tensor(rows)
def read_wave() -> torch.Tensor:
    '''Read ``test_data/abc.wav`` (which must be sampled at 16 kHz) as
    float32 samples, multiply them by 32768 and return them as a tensor.'''
    wave_path = cur_dir / 'test_data' / 'abc.wav'
    with sf.SoundFile(wave_path) as sf_desc:
        sampling_rate = sf_desc.samplerate
        assert sampling_rate == 16000
        samples = sf_desc.read(dtype=np.float32, always_2d=False)
        # NOTE(review): presumably rescales samples to the 16-bit integer
        # range the reference features were computed on -- confirm.
        samples *= 32768
    return torch.from_numpy(samples)
def test_and_benchmark_default_parameters():
    '''With default options and dither disabled, the computed features must
    match the reference in ``test_data/abc.txt``; the time reported by the
    extension is printed.'''
    opts = _kaldifeat.FbankOptions()
    opts.frame_opts.dither = 0

    wave = read_wave()
    ans, elapsed_seconds = _kaldifeat._compute_with_elapsed_time(wave, opts)

    expected = read_ark_txt()
    assert torch.allclose(ans, expected, rtol=1e-3)
    print('elapsed seconds:', elapsed_seconds)
def test_use_energy_htk_compat_true():
    '''With use_energy and htk_compat both enabled, the first three rows
    must match the reference output, where the energy is the last column.'''
    opts = _kaldifeat.FbankOptions()
    opts.frame_opts.dither = 0
    opts.use_energy = True
    opts.htk_compat = True

    wave = read_wave()
    ans = _kaldifeat.compute(wave, opts)

    # ./compute-fbank-feats --dither=0 --use-energy=1 --htk-compat=1 scp:abc.scp ark,t:abc.txt
    # the first 3 rows are:
    expected_str = '''
15.576 21.93211 25.55334 24.08283 15.93041 12.47176 10.47909 9.024426 7.899537 6.935482 6.21563 6.035741 6.140291 5.94696 6.146772 6.860236 6.702379 7.087324 6.929666 7.66336 7.935287 8.405977 8.309303 25.38995
15.5755 21.93212 25.55334 24.08282 15.93044 12.47107 10.47753 9.026523 7.901362 6.939464 6.189109 5.926141 5.678882 5.553694 6.006057 6.066478 6.500169 7.277717 7.248817 7.699819 7.990362 8.033764 8.220113 25.38996
15.57543 21.93211 25.55334 24.08282 15.93052 12.47129 10.4782 9.028108 7.90429 6.946663 6.310408 5.903729 5.777827 6.027511 6.000434 6.190129 5.968217 6.455313 7.450428 7.993948 8.512851 8.341401 8.14073 25.38995
'''
    expected = parse_str(expected_str)
    head = ans[:3, :]
    assert torch.allclose(head, expected, rtol=1e-3)
def test_use_energy_htk_compat_false():
    '''With use_energy enabled and htk_compat disabled, the first three rows
    must match the reference output, where the energy is the first column.'''
    opts = _kaldifeat.FbankOptions()
    opts.frame_opts.dither = 0
    opts.use_energy = True
    opts.htk_compat = False

    wave = read_wave()
    ans = _kaldifeat.compute(wave, opts)

    # ./compute-fbank-feats --dither=0 --use-energy=1 --htk-compat=0 scp:abc.scp ark,t:abc.txt
    # the first 3 rows are:
    expected_str = '''
25.38995 15.576 21.93211 25.55334 24.08283 15.93041 12.47176 10.47909 9.024426 7.899537 6.935482 6.21563 6.035741 6.140291 5.94696 6.146772 6.860236 6.702379 7.087324 6.929666 7.66336 7.935287 8.405977 8.309303
25.38996 15.5755 21.93212 25.55334 24.08282 15.93044 12.47107 10.47753 9.026523 7.901362 6.939464 6.189109 5.926141 5.678882 5.553694 6.006057 6.066478 6.500169 7.277717 7.248817 7.699819 7.990362 8.033764 8.220113
25.38995 15.57543 21.93211 25.55334 24.08282 15.93052 12.47129 10.4782 9.028108 7.90429 6.946663 6.310408 5.903729 5.777827 6.027511 6.000434 6.190129 5.968217 6.455313 7.450428 7.993948 8.512851 8.341401 8.14073
'''
    expected = parse_str(expected_str)
    head = ans[:3, :]
    assert torch.allclose(head, expected, rtol=1e-3)
def main():
    '''Run every fbank test in this file, in order.'''
    tests = (
        test_and_benchmark_default_parameters,
        test_use_energy_htk_compat_true,
        test_use_energy_htk_compat_false,
    )
    for run_test in tests:
        run_test()


if __name__ == '__main__':
    main()

View File

@ -0,0 +1,83 @@
#!/usr/bin/env python3
#
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
from pathlib import Path
cur_dir = Path(__file__).resolve().parent
kaldi_feat_dir = cur_dir.parent.parent.parent
import sys
sys.path.insert(0, f'{kaldi_feat_dir}/build/lib')
import torch
import _kaldifeat
def test_frame_extraction_options():
    '''Assign every exposed field of FrameExtractionOptions and print the
    resulting object.'''
    opts = _kaldifeat.FrameExtractionOptions()
    new_values = {
        'samp_freq': 220500,
        'frame_shift_ms': 15,
        'frame_length_ms': 40,
        'dither': 0.1,
        'preemph_coeff': 0.98,
        'remove_dc_offset': False,
        'window_type': 'hanning',
        'round_to_power_of_two': False,
        'blackman_coeff': 0.422,
        'snip_edges': False,
    }
    for field, value in new_values.items():
        setattr(opts, field, value)
    print(opts)
def test_mel_banks_options():
    '''Assign every exposed field of MelBanksOptions and print the
    resulting object.'''
    opts = _kaldifeat.MelBanksOptions()
    new_values = {
        'num_bins': 23,
        'low_freq': 21,
        'high_freq': 8000,
        'vtln_low': 101,
        'vtln_high': -501,
        'debug_mel': True,
        'htk_mode': True,
    }
    for field, value in new_values.items():
        setattr(opts, field, value)
    print(opts)
def test_fbank_options():
    '''Assign every exposed field of FbankOptions, including the fields of
    its nested frame_opts and mel_opts, and print the resulting object.'''
    opts = _kaldifeat.FbankOptions()
    # NOTE(review): the assignments below are only visible in print(opts) if
    # these attributes are live views into opts rather than copies -- confirm
    # against the binding.
    frame_opts = opts.frame_opts
    mel_opts = opts.mel_opts

    assignments = (
        (opts, {
            'energy_floor': 0,
            'htk_compat': False,
            'raw_energy': True,
            'use_energy': False,
            'use_log_fbank': True,
            'use_power': True,
        }),
        (frame_opts, {
            'blackman_coeff': 0.42,
            'dither': 1,
            'frame_length_ms': 25,
            'frame_shift_ms': 10,
            'preemph_coeff': 0.97,
            'remove_dc_offset': True,
            'round_to_power_of_two': True,
            'samp_freq': 16000,
            'snip_edges': True,
            'window_type': 'povey',
        }),
        (mel_opts, {
            'debug_mel': True,
            'high_freq': 0,
            'low_freq': 20,
            'num_bins': 23,
            'vtln_high': -500,
            'vtln_low': 100,
        }),
    )
    for target, new_values in assignments:
        for field, value in new_values.items():
            setattr(target, field, value)

    print(opts)
def main():
    '''Run the option-printing tests that are currently enabled.'''
    # Disabled for now:
    #   test_frame_extraction_options()
    #   test_mel_banks_options()
    enabled_tests = (test_fbank_options,)
    for run_test in enabled_tests:
        run_test()


if __name__ == '__main__':
    main()