mirror of
https://github.com/csukuangfj/kaldifeat.git
synced 2025-08-09 18:12:17 +00:00
Wrap Fbank to Python.
This commit is contained in:
parent
9a5567e21b
commit
b2980cdffd
@ -42,14 +42,10 @@ class OfflineFeatureTpl {
|
||||
int32_t Dim() const { return computer_.Dim(); }
|
||||
|
||||
// Copy constructor.
|
||||
OfflineFeatureTpl(const OfflineFeatureTpl<F> &other)
|
||||
: computer_(other.computer_),
|
||||
feature_window_function_(other.feature_window_function_) {}
|
||||
OfflineFeatureTpl(const OfflineFeatureTpl<F> &) = delete;
|
||||
OfflineFeatureTpl<F> &operator=(const OfflineFeatureTpl<F> &) = delete;
|
||||
|
||||
private:
|
||||
// Disallow assignment.
|
||||
OfflineFeatureTpl<F> &operator=(const OfflineFeatureTpl<F> &other);
|
||||
|
||||
F computer_;
|
||||
FeatureWindowFunction feature_window_function_;
|
||||
};
|
||||
|
@ -26,14 +26,6 @@ FbankComputer::FbankComputer(const FbankOptions &opts) : opts_(opts) {
|
||||
GetMelBanks(1.0f);
|
||||
}
|
||||
|
||||
FbankComputer::FbankComputer(const FbankComputer &other)
|
||||
: opts_(other.opts_),
|
||||
log_energy_floor_(other.log_energy_floor_),
|
||||
mel_banks_(other.mel_banks_) {
|
||||
for (auto iter = mel_banks_.begin(); iter != mel_banks_.end(); ++iter)
|
||||
iter->second = new MelBanks(*(iter->second));
|
||||
}
|
||||
|
||||
FbankComputer::~FbankComputer() {
|
||||
for (auto iter = mel_banks_.begin(); iter != mel_banks_.end(); ++iter)
|
||||
delete iter->second;
|
||||
|
@ -68,8 +68,7 @@ class FbankComputer {
|
||||
~FbankComputer();
|
||||
|
||||
FbankComputer &operator=(const FbankComputer &) = delete;
|
||||
|
||||
FbankComputer(const FbankComputer &other);
|
||||
FbankComputer(const FbankComputer &) = delete;
|
||||
|
||||
int32_t Dim() const {
|
||||
return opts_.mel_opts.num_bins + (opts_.use_energy ? 1 : 0);
|
||||
|
@ -22,6 +22,12 @@ void PybindFbankOptions(py::module &m) {
|
||||
.def("__str__", [](const FbankOptions &self) -> std::string {
|
||||
return self.ToString();
|
||||
});
|
||||
|
||||
py::class_<Fbank>(m, "Fbank")
|
||||
.def(py::init<const FbankOptions &>(), py::arg("opts"))
|
||||
.def("dim", &Fbank::Dim)
|
||||
.def("compute_features", &Fbank::ComputeFeatures, py::arg("wave"),
|
||||
py::arg("vtln_warp"));
|
||||
}
|
||||
|
||||
} // namespace kaldifeat
|
||||
|
@ -14,14 +14,9 @@
|
||||
|
||||
namespace kaldifeat {
|
||||
|
||||
static torch::Tensor Compute(const torch::Tensor &wave,
|
||||
const FbankOptions &fbank_opts) {
|
||||
// TODO(fangjun): wrap Fbank to Python
|
||||
|
||||
Fbank fbank(fbank_opts);
|
||||
static torch::Tensor Compute(const torch::Tensor &wave, Fbank *fbank) {
|
||||
float vtln_warp = 1.0f;
|
||||
|
||||
torch::Tensor ans = fbank.ComputeFeatures(wave, vtln_warp);
|
||||
torch::Tensor ans = fbank->ComputeFeatures(wave, vtln_warp);
|
||||
return ans;
|
||||
}
|
||||
|
||||
@ -32,18 +27,18 @@ PYBIND11_MODULE(_kaldifeat, m) {
|
||||
PybindMelBanksOptions(m);
|
||||
PybindFbankOptions(m);
|
||||
|
||||
m.def("compute", &Compute, py::arg("wave"), py::arg("fbank_opts"));
|
||||
m.def("compute", &Compute, py::arg("wave"), py::arg("fbank"));
|
||||
|
||||
// It verifies that the reimplementation produces the same output
|
||||
// as kaldi using default parameters with dither disabled.
|
||||
m.def(
|
||||
"_compute_with_elapsed_time", // for benchmark only
|
||||
[](const torch::Tensor &wave,
|
||||
const FbankOptions &fbank_opts) -> std::pair<torch::Tensor, double> {
|
||||
Fbank *fbank) -> std::pair<torch::Tensor, double> {
|
||||
std::chrono::steady_clock::time_point begin =
|
||||
std::chrono::steady_clock::now();
|
||||
|
||||
torch::Tensor ans = Compute(wave, fbank_opts);
|
||||
torch::Tensor ans = Compute(wave, fbank);
|
||||
|
||||
std::chrono::steady_clock::time_point end =
|
||||
std::chrono::steady_clock::now();
|
||||
@ -55,7 +50,7 @@ PYBIND11_MODULE(_kaldifeat, m) {
|
||||
|
||||
return std::make_pair(ans, elapsed_seconds);
|
||||
},
|
||||
py::arg("wave"), py::arg("fbank_opts"));
|
||||
py::arg("wave"), py::arg("fbank"));
|
||||
}
|
||||
|
||||
} // namespace kaldifeat
|
||||
|
@ -58,11 +58,11 @@ def read_wave() -> torch.Tensor:
|
||||
def test_and_benchmark_default_parameters():
|
||||
fbank_opts = _kaldifeat.FbankOptions()
|
||||
fbank_opts.frame_opts.dither = 0
|
||||
fbank = _kaldifeat.Fbank(fbank_opts)
|
||||
|
||||
data = read_wave()
|
||||
|
||||
ans, elapsed_seconds = _kaldifeat._compute_with_elapsed_time(
|
||||
data, fbank_opts)
|
||||
ans, elapsed_seconds = _kaldifeat._compute_with_elapsed_time(data, fbank)
|
||||
|
||||
expected = read_ark_txt()
|
||||
assert torch.allclose(ans, expected, rtol=1e-3)
|
||||
@ -74,10 +74,11 @@ def test_use_energy_htk_compat_true():
|
||||
fbank_opts.frame_opts.dither = 0
|
||||
fbank_opts.use_energy = True
|
||||
fbank_opts.htk_compat = True
|
||||
fbank = _kaldifeat.Fbank(fbank_opts)
|
||||
|
||||
data = read_wave()
|
||||
|
||||
ans = _kaldifeat.compute(data, fbank_opts)
|
||||
ans = _kaldifeat.compute(data, fbank)
|
||||
|
||||
# ./compute-fbank-feats --dither=0 --use-energy=1 --htk-compat=1 scp:abc.scp ark,t:abc.txt
|
||||
# the first 3 rows are:
|
||||
@ -95,10 +96,11 @@ def test_use_energy_htk_compat_false():
|
||||
fbank_opts.frame_opts.dither = 0
|
||||
fbank_opts.use_energy = True
|
||||
fbank_opts.htk_compat = False
|
||||
fbank = _kaldifeat.Fbank(fbank_opts)
|
||||
|
||||
data = read_wave()
|
||||
|
||||
ans = _kaldifeat.compute(data, fbank_opts)
|
||||
ans = _kaldifeat.compute(data, fbank)
|
||||
|
||||
# ./compute-fbank-feats --dither=0 --use-energy=1 --htk-compat=0 scp:abc.scp ark,t:abc.txt
|
||||
# the first 3 rows are:
|
||||
|
Loading…
x
Reference in New Issue
Block a user