mirror of
https://github.com/csukuangfj/kaldifeat.git
synced 2025-08-09 01:52:39 +00:00
Support batch processing.
This commit is contained in:
parent
f909f839ab
commit
efcd8ab92c
@ -14,11 +14,16 @@ namespace kaldifeat {
|
||||
template <class F>
|
||||
torch::Tensor OfflineFeatureTpl<F>::ComputeFeatures(const torch::Tensor &wave,
|
||||
float vtln_warp) {
|
||||
KALDIFEAT_ASSERT(wave.dim() == 1);
|
||||
|
||||
const FrameExtractionOptions &frame_opts = computer_.GetFrameOptions();
|
||||
|
||||
torch::Tensor strided_input = GetStrided(wave, frame_opts);
|
||||
torch::Tensor strided_input;
|
||||
if (wave.dim() == 1) {
|
||||
strided_input = GetStrided(wave, frame_opts);
|
||||
} else {
|
||||
KALDIFEAT_ASSERT(wave.dim() == 2);
|
||||
KALDIFEAT_ASSERT(wave.sizes()[1] == frame_opts.WindowSize());
|
||||
strided_input = wave;
|
||||
}
|
||||
|
||||
if (frame_opts.dither != 0.0f) {
|
||||
strided_input = Dither(strided_input, frame_opts.dither);
|
||||
|
@ -40,6 +40,7 @@ class OfflineFeatureTpl {
|
||||
torch::Tensor ComputeFeatures(const torch::Tensor &wave, float vtln_warp);
|
||||
|
||||
int32_t Dim() const { return computer_.Dim(); }
|
||||
const Options &GetOptions() const { return computer_.GetOptions(); }
|
||||
|
||||
// Copy constructor.
|
||||
OfflineFeatureTpl(const OfflineFeatureTpl<F> &) = delete;
|
||||
|
@ -85,6 +85,8 @@ class FbankComputer {
|
||||
return opts_.frame_opts;
|
||||
}
|
||||
|
||||
const FbankOptions &GetOptions() const { return opts_; }
|
||||
|
||||
// signal_raw_log_energy is log_energy_pre_window, which is not empty
|
||||
// iff NeedRawLogEnergy() returns true.
|
||||
torch::Tensor Compute(torch::Tensor signal_raw_log_energy, float vtln_warp,
|
||||
|
@ -67,10 +67,12 @@ static void TestCat() {
|
||||
torch::Tensor b = torch::arange(0, 2).reshape({2, 1}).to(torch::kFloat) * 0.1;
|
||||
torch::Tensor c = torch::cat({a, b}, 1);
|
||||
torch::Tensor d = torch::cat({b, a}, 1);
|
||||
torch::Tensor e = torch::cat({a, a}, 0);
|
||||
std::cout << a << "\n";
|
||||
std::cout << b << "\n";
|
||||
std::cout << c << "\n";
|
||||
std::cout << d << "\n";
|
||||
std::cout << e << "\n";
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
@ -42,6 +42,8 @@ void PybindFbankOptions(py::module &m) {
|
||||
py::class_<Fbank>(m, "Fbank")
|
||||
.def(py::init<const FbankOptions &>(), py::arg("opts"))
|
||||
.def("dim", &Fbank::Dim)
|
||||
.def("options", &Fbank::GetOptions,
|
||||
py::return_value_policy::reference_internal)
|
||||
.def("compute_features", &Fbank::ComputeFeatures, py::arg("wave"),
|
||||
py::arg("vtln_warp"));
|
||||
}
|
||||
|
@ -34,6 +34,11 @@ void PybindFrameExtractionOptions(py::module &m) {
|
||||
.def("__str__", [](const FrameExtractionOptions &self) -> std::string {
|
||||
return self.ToString();
|
||||
});
|
||||
|
||||
m.def("num_frames", &NumFrames, py::arg("num_samples"), py::arg("opts"),
|
||||
py::arg("flush") = true);
|
||||
|
||||
m.def("get_strided", &GetStrided, py::arg("ave"), py::arg("opts"));
|
||||
}
|
||||
|
||||
} // namespace kaldifeat
|
||||
|
@ -6,6 +6,8 @@ from pathlib import Path
|
||||
cur_dir = Path(__file__).resolve().parent
|
||||
kaldi_feat_dir = cur_dir.parent.parent.parent
|
||||
|
||||
from typing import List
|
||||
|
||||
import sys
|
||||
sys.path.insert(0, f'{kaldi_feat_dir}/build/lib')
|
||||
|
||||
@ -132,11 +134,42 @@ def test_use_energy_htk_compat_false():
|
||||
assert torch.allclose(ans[:3, :].cpu(), expected, rtol=1e-3)
|
||||
|
||||
|
||||
def test_compute_batch():
|
||||
data1 = read_wave()
|
||||
data2 = read_wave()
|
||||
|
||||
data = [data1, data2]
|
||||
fbank_opts = _kaldifeat.FbankOptions()
|
||||
fbank_opts.frame_opts.dither = 0
|
||||
fbank = _kaldifeat.Fbank(fbank_opts)
|
||||
|
||||
def impl(waves: List[torch.Tensor]) -> List[torch.Tensor]:
|
||||
num_frames = [
|
||||
_kaldifeat.num_frames(w.numel(), fbank_opts.frame_opts)
|
||||
for w in waves
|
||||
]
|
||||
|
||||
strided = [
|
||||
_kaldifeat.get_strided(w, fbank_opts.frame_opts) for w in waves
|
||||
]
|
||||
strided = torch.cat(strided, dim=0)
|
||||
|
||||
features = _kaldifeat.compute(strided, fbank)
|
||||
feature1 = features[:num_frames[0]]
|
||||
feature2 = features[num_frames[0]:]
|
||||
return [feature1, feature2]
|
||||
|
||||
feature1, feature2 = impl([data1, data2])
|
||||
assert torch.allclose(feature1, feature2)
|
||||
|
||||
|
||||
def main():
|
||||
test_and_benchmark_default_parameters()
|
||||
test_use_energy_htk_compat_true()
|
||||
test_use_energy_htk_compat_false()
|
||||
|
||||
test_compute_batch()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user