mirror of
https://github.com/csukuangfj/kaldifeat.git
synced 2025-08-10 02:22:16 +00:00
Wrap Fbank to Python.
This commit is contained in:
parent
9a5567e21b
commit
b2980cdffd
@ -42,14 +42,10 @@ class OfflineFeatureTpl {
|
|||||||
int32_t Dim() const { return computer_.Dim(); }
|
int32_t Dim() const { return computer_.Dim(); }
|
||||||
|
|
||||||
// Copy constructor.
|
// Copy constructor.
|
||||||
OfflineFeatureTpl(const OfflineFeatureTpl<F> &other)
|
OfflineFeatureTpl(const OfflineFeatureTpl<F> &) = delete;
|
||||||
: computer_(other.computer_),
|
OfflineFeatureTpl<F> &operator=(const OfflineFeatureTpl<F> &) = delete;
|
||||||
feature_window_function_(other.feature_window_function_) {}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Disallow assignment.
|
|
||||||
OfflineFeatureTpl<F> &operator=(const OfflineFeatureTpl<F> &other);
|
|
||||||
|
|
||||||
F computer_;
|
F computer_;
|
||||||
FeatureWindowFunction feature_window_function_;
|
FeatureWindowFunction feature_window_function_;
|
||||||
};
|
};
|
||||||
|
@ -26,14 +26,6 @@ FbankComputer::FbankComputer(const FbankOptions &opts) : opts_(opts) {
|
|||||||
GetMelBanks(1.0f);
|
GetMelBanks(1.0f);
|
||||||
}
|
}
|
||||||
|
|
||||||
FbankComputer::FbankComputer(const FbankComputer &other)
|
|
||||||
: opts_(other.opts_),
|
|
||||||
log_energy_floor_(other.log_energy_floor_),
|
|
||||||
mel_banks_(other.mel_banks_) {
|
|
||||||
for (auto iter = mel_banks_.begin(); iter != mel_banks_.end(); ++iter)
|
|
||||||
iter->second = new MelBanks(*(iter->second));
|
|
||||||
}
|
|
||||||
|
|
||||||
FbankComputer::~FbankComputer() {
|
FbankComputer::~FbankComputer() {
|
||||||
for (auto iter = mel_banks_.begin(); iter != mel_banks_.end(); ++iter)
|
for (auto iter = mel_banks_.begin(); iter != mel_banks_.end(); ++iter)
|
||||||
delete iter->second;
|
delete iter->second;
|
||||||
|
@ -68,8 +68,7 @@ class FbankComputer {
|
|||||||
~FbankComputer();
|
~FbankComputer();
|
||||||
|
|
||||||
FbankComputer &operator=(const FbankComputer &) = delete;
|
FbankComputer &operator=(const FbankComputer &) = delete;
|
||||||
|
FbankComputer(const FbankComputer &) = delete;
|
||||||
FbankComputer(const FbankComputer &other);
|
|
||||||
|
|
||||||
int32_t Dim() const {
|
int32_t Dim() const {
|
||||||
return opts_.mel_opts.num_bins + (opts_.use_energy ? 1 : 0);
|
return opts_.mel_opts.num_bins + (opts_.use_energy ? 1 : 0);
|
||||||
|
@ -22,6 +22,12 @@ void PybindFbankOptions(py::module &m) {
|
|||||||
.def("__str__", [](const FbankOptions &self) -> std::string {
|
.def("__str__", [](const FbankOptions &self) -> std::string {
|
||||||
return self.ToString();
|
return self.ToString();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
py::class_<Fbank>(m, "Fbank")
|
||||||
|
.def(py::init<const FbankOptions &>(), py::arg("opts"))
|
||||||
|
.def("dim", &Fbank::Dim)
|
||||||
|
.def("compute_features", &Fbank::ComputeFeatures, py::arg("wave"),
|
||||||
|
py::arg("vtln_warp"));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace kaldifeat
|
} // namespace kaldifeat
|
||||||
|
@ -14,14 +14,9 @@
|
|||||||
|
|
||||||
namespace kaldifeat {
|
namespace kaldifeat {
|
||||||
|
|
||||||
static torch::Tensor Compute(const torch::Tensor &wave,
|
static torch::Tensor Compute(const torch::Tensor &wave, Fbank *fbank) {
|
||||||
const FbankOptions &fbank_opts) {
|
|
||||||
// TODO(fangjun): wrap Fbank to Python
|
|
||||||
|
|
||||||
Fbank fbank(fbank_opts);
|
|
||||||
float vtln_warp = 1.0f;
|
float vtln_warp = 1.0f;
|
||||||
|
torch::Tensor ans = fbank->ComputeFeatures(wave, vtln_warp);
|
||||||
torch::Tensor ans = fbank.ComputeFeatures(wave, vtln_warp);
|
|
||||||
return ans;
|
return ans;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -32,18 +27,18 @@ PYBIND11_MODULE(_kaldifeat, m) {
|
|||||||
PybindMelBanksOptions(m);
|
PybindMelBanksOptions(m);
|
||||||
PybindFbankOptions(m);
|
PybindFbankOptions(m);
|
||||||
|
|
||||||
m.def("compute", &Compute, py::arg("wave"), py::arg("fbank_opts"));
|
m.def("compute", &Compute, py::arg("wave"), py::arg("fbank"));
|
||||||
|
|
||||||
// It verifies that the reimplementation produces the same output
|
// It verifies that the reimplementation produces the same output
|
||||||
// as kaldi using default parameters with dither disabled.
|
// as kaldi using default parameters with dither disabled.
|
||||||
m.def(
|
m.def(
|
||||||
"_compute_with_elapsed_time", // for benchmark only
|
"_compute_with_elapsed_time", // for benchmark only
|
||||||
[](const torch::Tensor &wave,
|
[](const torch::Tensor &wave,
|
||||||
const FbankOptions &fbank_opts) -> std::pair<torch::Tensor, double> {
|
Fbank *fbank) -> std::pair<torch::Tensor, double> {
|
||||||
std::chrono::steady_clock::time_point begin =
|
std::chrono::steady_clock::time_point begin =
|
||||||
std::chrono::steady_clock::now();
|
std::chrono::steady_clock::now();
|
||||||
|
|
||||||
torch::Tensor ans = Compute(wave, fbank_opts);
|
torch::Tensor ans = Compute(wave, fbank);
|
||||||
|
|
||||||
std::chrono::steady_clock::time_point end =
|
std::chrono::steady_clock::time_point end =
|
||||||
std::chrono::steady_clock::now();
|
std::chrono::steady_clock::now();
|
||||||
@ -55,7 +50,7 @@ PYBIND11_MODULE(_kaldifeat, m) {
|
|||||||
|
|
||||||
return std::make_pair(ans, elapsed_seconds);
|
return std::make_pair(ans, elapsed_seconds);
|
||||||
},
|
},
|
||||||
py::arg("wave"), py::arg("fbank_opts"));
|
py::arg("wave"), py::arg("fbank"));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace kaldifeat
|
} // namespace kaldifeat
|
||||||
|
@ -58,11 +58,11 @@ def read_wave() -> torch.Tensor:
|
|||||||
def test_and_benchmark_default_parameters():
|
def test_and_benchmark_default_parameters():
|
||||||
fbank_opts = _kaldifeat.FbankOptions()
|
fbank_opts = _kaldifeat.FbankOptions()
|
||||||
fbank_opts.frame_opts.dither = 0
|
fbank_opts.frame_opts.dither = 0
|
||||||
|
fbank = _kaldifeat.Fbank(fbank_opts)
|
||||||
|
|
||||||
data = read_wave()
|
data = read_wave()
|
||||||
|
|
||||||
ans, elapsed_seconds = _kaldifeat._compute_with_elapsed_time(
|
ans, elapsed_seconds = _kaldifeat._compute_with_elapsed_time(data, fbank)
|
||||||
data, fbank_opts)
|
|
||||||
|
|
||||||
expected = read_ark_txt()
|
expected = read_ark_txt()
|
||||||
assert torch.allclose(ans, expected, rtol=1e-3)
|
assert torch.allclose(ans, expected, rtol=1e-3)
|
||||||
@ -74,10 +74,11 @@ def test_use_energy_htk_compat_true():
|
|||||||
fbank_opts.frame_opts.dither = 0
|
fbank_opts.frame_opts.dither = 0
|
||||||
fbank_opts.use_energy = True
|
fbank_opts.use_energy = True
|
||||||
fbank_opts.htk_compat = True
|
fbank_opts.htk_compat = True
|
||||||
|
fbank = _kaldifeat.Fbank(fbank_opts)
|
||||||
|
|
||||||
data = read_wave()
|
data = read_wave()
|
||||||
|
|
||||||
ans = _kaldifeat.compute(data, fbank_opts)
|
ans = _kaldifeat.compute(data, fbank)
|
||||||
|
|
||||||
# ./compute-fbank-feats --dither=0 --use-energy=1 --htk-compat=1 scp:abc.scp ark,t:abc.txt
|
# ./compute-fbank-feats --dither=0 --use-energy=1 --htk-compat=1 scp:abc.scp ark,t:abc.txt
|
||||||
# the first 3 rows are:
|
# the first 3 rows are:
|
||||||
@ -95,10 +96,11 @@ def test_use_energy_htk_compat_false():
|
|||||||
fbank_opts.frame_opts.dither = 0
|
fbank_opts.frame_opts.dither = 0
|
||||||
fbank_opts.use_energy = True
|
fbank_opts.use_energy = True
|
||||||
fbank_opts.htk_compat = False
|
fbank_opts.htk_compat = False
|
||||||
|
fbank = _kaldifeat.Fbank(fbank_opts)
|
||||||
|
|
||||||
data = read_wave()
|
data = read_wave()
|
||||||
|
|
||||||
ans = _kaldifeat.compute(data, fbank_opts)
|
ans = _kaldifeat.compute(data, fbank)
|
||||||
|
|
||||||
# ./compute-fbank-feats --dither=0 --use-energy=1 --htk-compat=0 scp:abc.scp ark,t:abc.txt
|
# ./compute-fbank-feats --dither=0 --use-energy=1 --htk-compat=0 scp:abc.scp ark,t:abc.txt
|
||||||
# the first 3 rows are:
|
# the first 3 rows are:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user