Support batch processing.

This commit is contained in:
Fangjun Kuang 2021-02-28 00:07:58 +08:00
parent f909f839ab
commit efcd8ab92c
7 changed files with 53 additions and 3 deletions

View File

@ -14,11 +14,16 @@ namespace kaldifeat {
template <class F>
torch::Tensor OfflineFeatureTpl<F>::ComputeFeatures(const torch::Tensor &wave,
float vtln_warp) {
KALDIFEAT_ASSERT(wave.dim() == 1);
const FrameExtractionOptions &frame_opts = computer_.GetFrameOptions();
torch::Tensor strided_input = GetStrided(wave, frame_opts);
torch::Tensor strided_input;
if (wave.dim() == 1) {
strided_input = GetStrided(wave, frame_opts);
} else {
KALDIFEAT_ASSERT(wave.dim() == 2);
KALDIFEAT_ASSERT(wave.sizes()[1] == frame_opts.WindowSize());
strided_input = wave;
}
if (frame_opts.dither != 0.0f) {
strided_input = Dither(strided_input, frame_opts.dither);

View File

@ -40,6 +40,7 @@ class OfflineFeatureTpl {
torch::Tensor ComputeFeatures(const torch::Tensor &wave, float vtln_warp);
int32_t Dim() const { return computer_.Dim(); }
const Options &GetOptions() const { return computer_.GetOptions(); }
// Copy constructor.
OfflineFeatureTpl(const OfflineFeatureTpl<F> &) = delete;

View File

@ -85,6 +85,8 @@ class FbankComputer {
return opts_.frame_opts;
}
const FbankOptions &GetOptions() const { return opts_; }
// signal_raw_log_energy is log_energy_pre_window, which is not empty
// iff NeedRawLogEnergy() returns true.
torch::Tensor Compute(torch::Tensor signal_raw_log_energy, float vtln_warp,

View File

@ -67,10 +67,12 @@ static void TestCat() {
torch::Tensor b = torch::arange(0, 2).reshape({2, 1}).to(torch::kFloat) * 0.1;
torch::Tensor c = torch::cat({a, b}, 1);
torch::Tensor d = torch::cat({b, a}, 1);
torch::Tensor e = torch::cat({a, a}, 0);
std::cout << a << "\n";
std::cout << b << "\n";
std::cout << c << "\n";
std::cout << d << "\n";
std::cout << e << "\n";
}
int main() {

View File

@ -42,6 +42,8 @@ void PybindFbankOptions(py::module &m) {
py::class_<Fbank>(m, "Fbank")
.def(py::init<const FbankOptions &>(), py::arg("opts"))
.def("dim", &Fbank::Dim)
.def("options", &Fbank::GetOptions,
py::return_value_policy::reference_internal)
.def("compute_features", &Fbank::ComputeFeatures, py::arg("wave"),
py::arg("vtln_warp"));
}

View File

@ -34,6 +34,11 @@ void PybindFrameExtractionOptions(py::module &m) {
.def("__str__", [](const FrameExtractionOptions &self) -> std::string {
return self.ToString();
});
m.def("num_frames", &NumFrames, py::arg("num_samples"), py::arg("opts"),
py::arg("flush") = true);
m.def("get_strided", &GetStrided, py::arg("ave"), py::arg("opts"));
}
} // namespace kaldifeat

View File

@ -6,6 +6,8 @@ from pathlib import Path
cur_dir = Path(__file__).resolve().parent
kaldi_feat_dir = cur_dir.parent.parent.parent
from typing import List
import sys
sys.path.insert(0, f'{kaldi_feat_dir}/build/lib')
@ -132,11 +134,42 @@ def test_use_energy_htk_compat_false():
assert torch.allclose(ans[:3, :].cpu(), expected, rtol=1e-3)
def test_compute_batch():
data1 = read_wave()
data2 = read_wave()
data = [data1, data2]
fbank_opts = _kaldifeat.FbankOptions()
fbank_opts.frame_opts.dither = 0
fbank = _kaldifeat.Fbank(fbank_opts)
def impl(waves: List[torch.Tensor]) -> List[torch.Tensor]:
num_frames = [
_kaldifeat.num_frames(w.numel(), fbank_opts.frame_opts)
for w in waves
]
strided = [
_kaldifeat.get_strided(w, fbank_opts.frame_opts) for w in waves
]
strided = torch.cat(strided, dim=0)
features = _kaldifeat.compute(strided, fbank)
feature1 = features[:num_frames[0]]
feature2 = features[num_frames[0]:]
return [feature1, feature2]
feature1, feature2 = impl([data1, data2])
assert torch.allclose(feature1, feature2)
def main():
test_and_benchmark_default_parameters()
test_use_energy_htk_compat_true()
test_use_energy_htk_compat_false()
test_compute_batch()
if __name__ == '__main__':
main()