Wrap Fbank to Python.

This commit is contained in:
Fangjun Kuang 2021-02-27 23:09:36 +08:00
parent 9a5567e21b
commit b2980cdffd
6 changed files with 21 additions and 31 deletions

View File

@ -42,14 +42,10 @@ class OfflineFeatureTpl {
int32_t Dim() const { return computer_.Dim(); } int32_t Dim() const { return computer_.Dim(); }
// Copy constructor. // Copy constructor.
OfflineFeatureTpl(const OfflineFeatureTpl<F> &other) OfflineFeatureTpl(const OfflineFeatureTpl<F> &) = delete;
: computer_(other.computer_), OfflineFeatureTpl<F> &operator=(const OfflineFeatureTpl<F> &) = delete;
feature_window_function_(other.feature_window_function_) {}
private: private:
// Disallow assignment.
OfflineFeatureTpl<F> &operator=(const OfflineFeatureTpl<F> &other);
F computer_; F computer_;
FeatureWindowFunction feature_window_function_; FeatureWindowFunction feature_window_function_;
}; };

View File

@ -26,14 +26,6 @@ FbankComputer::FbankComputer(const FbankOptions &opts) : opts_(opts) {
GetMelBanks(1.0f); GetMelBanks(1.0f);
} }
FbankComputer::FbankComputer(const FbankComputer &other)
: opts_(other.opts_),
log_energy_floor_(other.log_energy_floor_),
mel_banks_(other.mel_banks_) {
for (auto iter = mel_banks_.begin(); iter != mel_banks_.end(); ++iter)
iter->second = new MelBanks(*(iter->second));
}
FbankComputer::~FbankComputer() { FbankComputer::~FbankComputer() {
for (auto iter = mel_banks_.begin(); iter != mel_banks_.end(); ++iter) for (auto iter = mel_banks_.begin(); iter != mel_banks_.end(); ++iter)
delete iter->second; delete iter->second;

View File

@ -68,8 +68,7 @@ class FbankComputer {
~FbankComputer(); ~FbankComputer();
FbankComputer &operator=(const FbankComputer &) = delete; FbankComputer &operator=(const FbankComputer &) = delete;
FbankComputer(const FbankComputer &) = delete;
FbankComputer(const FbankComputer &other);
int32_t Dim() const { int32_t Dim() const {
return opts_.mel_opts.num_bins + (opts_.use_energy ? 1 : 0); return opts_.mel_opts.num_bins + (opts_.use_energy ? 1 : 0);

View File

@ -22,6 +22,12 @@ void PybindFbankOptions(py::module &m) {
.def("__str__", [](const FbankOptions &self) -> std::string { .def("__str__", [](const FbankOptions &self) -> std::string {
return self.ToString(); return self.ToString();
}); });
py::class_<Fbank>(m, "Fbank")
.def(py::init<const FbankOptions &>(), py::arg("opts"))
.def("dim", &Fbank::Dim)
.def("compute_features", &Fbank::ComputeFeatures, py::arg("wave"),
py::arg("vtln_warp"));
} }
} // namespace kaldifeat } // namespace kaldifeat

View File

@ -14,14 +14,9 @@
namespace kaldifeat { namespace kaldifeat {
static torch::Tensor Compute(const torch::Tensor &wave, static torch::Tensor Compute(const torch::Tensor &wave, Fbank *fbank) {
const FbankOptions &fbank_opts) {
// TODO(fangjun): wrap Fbank to Python
Fbank fbank(fbank_opts);
float vtln_warp = 1.0f; float vtln_warp = 1.0f;
torch::Tensor ans = fbank->ComputeFeatures(wave, vtln_warp);
torch::Tensor ans = fbank.ComputeFeatures(wave, vtln_warp);
return ans; return ans;
} }
@ -32,18 +27,18 @@ PYBIND11_MODULE(_kaldifeat, m) {
PybindMelBanksOptions(m); PybindMelBanksOptions(m);
PybindFbankOptions(m); PybindFbankOptions(m);
m.def("compute", &Compute, py::arg("wave"), py::arg("fbank_opts")); m.def("compute", &Compute, py::arg("wave"), py::arg("fbank"));
// It verifies that the reimplementation produces the same output // It verifies that the reimplementation produces the same output
// as kaldi using default parameters with dither disabled. // as kaldi using default parameters with dither disabled.
m.def( m.def(
"_compute_with_elapsed_time", // for benchmark only "_compute_with_elapsed_time", // for benchmark only
[](const torch::Tensor &wave, [](const torch::Tensor &wave,
const FbankOptions &fbank_opts) -> std::pair<torch::Tensor, double> { Fbank *fbank) -> std::pair<torch::Tensor, double> {
std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::time_point begin =
std::chrono::steady_clock::now(); std::chrono::steady_clock::now();
torch::Tensor ans = Compute(wave, fbank_opts); torch::Tensor ans = Compute(wave, fbank);
std::chrono::steady_clock::time_point end = std::chrono::steady_clock::time_point end =
std::chrono::steady_clock::now(); std::chrono::steady_clock::now();
@ -55,7 +50,7 @@ PYBIND11_MODULE(_kaldifeat, m) {
return std::make_pair(ans, elapsed_seconds); return std::make_pair(ans, elapsed_seconds);
}, },
py::arg("wave"), py::arg("fbank_opts")); py::arg("wave"), py::arg("fbank"));
} }
} // namespace kaldifeat } // namespace kaldifeat

View File

@ -58,11 +58,11 @@ def read_wave() -> torch.Tensor:
def test_and_benchmark_default_parameters(): def test_and_benchmark_default_parameters():
fbank_opts = _kaldifeat.FbankOptions() fbank_opts = _kaldifeat.FbankOptions()
fbank_opts.frame_opts.dither = 0 fbank_opts.frame_opts.dither = 0
fbank = _kaldifeat.Fbank(fbank_opts)
data = read_wave() data = read_wave()
ans, elapsed_seconds = _kaldifeat._compute_with_elapsed_time( ans, elapsed_seconds = _kaldifeat._compute_with_elapsed_time(data, fbank)
data, fbank_opts)
expected = read_ark_txt() expected = read_ark_txt()
assert torch.allclose(ans, expected, rtol=1e-3) assert torch.allclose(ans, expected, rtol=1e-3)
@ -74,10 +74,11 @@ def test_use_energy_htk_compat_true():
fbank_opts.frame_opts.dither = 0 fbank_opts.frame_opts.dither = 0
fbank_opts.use_energy = True fbank_opts.use_energy = True
fbank_opts.htk_compat = True fbank_opts.htk_compat = True
fbank = _kaldifeat.Fbank(fbank_opts)
data = read_wave() data = read_wave()
ans = _kaldifeat.compute(data, fbank_opts) ans = _kaldifeat.compute(data, fbank)
# ./compute-fbank-feats --dither=0 --use-energy=1 --htk-compat=1 scp:abc.scp ark,t:abc.txt # ./compute-fbank-feats --dither=0 --use-energy=1 --htk-compat=1 scp:abc.scp ark,t:abc.txt
# the first 3 rows are: # the first 3 rows are:
@ -95,10 +96,11 @@ def test_use_energy_htk_compat_false():
fbank_opts.frame_opts.dither = 0 fbank_opts.frame_opts.dither = 0
fbank_opts.use_energy = True fbank_opts.use_energy = True
fbank_opts.htk_compat = False fbank_opts.htk_compat = False
fbank = _kaldifeat.Fbank(fbank_opts)
data = read_wave() data = read_wave()
ans = _kaldifeat.compute(data, fbank_opts) ans = _kaldifeat.compute(data, fbank)
# ./compute-fbank-feats --dither=0 --use-energy=1 --htk-compat=0 scp:abc.scp ark,t:abc.txt # ./compute-fbank-feats --dither=0 --use-energy=1 --htk-compat=0 scp:abc.scp ark,t:abc.txt
# the first 3 rows are: # the first 3 rows are: