mirror of https://github.com/csukuangfj/kaldifeat.git
synced 2025-08-09 01:52:39 +00:00

Support batch processing.

This commit is contained in:
parent f909f839ab
commit efcd8ab92c
@@ -14,11 +14,16 @@ namespace kaldifeat {
 template <class F>
 torch::Tensor OfflineFeatureTpl<F>::ComputeFeatures(const torch::Tensor &wave,
                                                     float vtln_warp) {
-  KALDIFEAT_ASSERT(wave.dim() == 1);
-
   const FrameExtractionOptions &frame_opts = computer_.GetFrameOptions();
 
-  torch::Tensor strided_input = GetStrided(wave, frame_opts);
+  torch::Tensor strided_input;
+  if (wave.dim() == 1) {
+    strided_input = GetStrided(wave, frame_opts);
+  } else {
+    KALDIFEAT_ASSERT(wave.dim() == 2);
+    KALDIFEAT_ASSERT(wave.sizes()[1] == frame_opts.WindowSize());
+    strided_input = wave;
+  }
 
   if (frame_opts.dither != 0.0f) {
     strided_input = Dither(strided_input, frame_opts.dither);
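This hunk is the core of the change: ComputeFeatures now accepts either a 1-D waveform, which is framed internally with GetStrided, or a 2-D matrix whose rows are already-extracted frames and whose second dimension equals the window size. The 2-D path is what enables batching: frames from several waveforms can be concatenated along dim 0 and processed in one call. Below is a minimal PyTorch sketch of that input shape (not part of the commit); it approximates GetStrided with Tensor.unfold and assumes the usual 25 ms window / 10 ms shift at 16 kHz, i.e. 400/160 samples.

    import torch

    frame_length, frame_shift = 400, 160  # assumed defaults: 25 ms window, 10 ms shift at 16 kHz

    def frames(wave: torch.Tensor) -> torch.Tensor:
        # rough stand-in for GetStrided: one row per window
        return wave.unfold(0, frame_length, frame_shift)

    wave1 = torch.rand(16000)
    wave2 = torch.rand(8000)

    batch = torch.cat([frames(wave1), frames(wave2)], dim=0)
    # this is the shape the new KALDIFEAT_ASSERTs check for 2-D input
    assert batch.dim() == 2 and batch.size(1) == frame_length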
@@ -40,6 +40,7 @@ class OfflineFeatureTpl {
   torch::Tensor ComputeFeatures(const torch::Tensor &wave, float vtln_warp);
 
   int32_t Dim() const { return computer_.Dim(); }
+  const Options &GetOptions() const { return computer_.GetOptions(); }
 
   // Copy constructor.
   OfflineFeatureTpl(const OfflineFeatureTpl<F> &) = delete;
@@ -85,6 +85,8 @@ class FbankComputer {
     return opts_.frame_opts;
   }
 
+  const FbankOptions &GetOptions() const { return opts_; }
+
   // signal_raw_log_energy is log_energy_pre_window, which is not empty
   // iff NeedRawLogEnergy() returns true.
   torch::Tensor Compute(torch::Tensor signal_raw_log_energy, float vtln_warp,
@@ -67,10 +67,12 @@ static void TestCat() {
   torch::Tensor b = torch::arange(0, 2).reshape({2, 1}).to(torch::kFloat) * 0.1;
   torch::Tensor c = torch::cat({a, b}, 1);
   torch::Tensor d = torch::cat({b, a}, 1);
+  torch::Tensor e = torch::cat({a, a}, 0);
   std::cout << a << "\n";
   std::cout << b << "\n";
   std::cout << c << "\n";
   std::cout << d << "\n";
+  std::cout << e << "\n";
 }
 
 int main() {
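The new lines only exercise concatenation along dim 0 (the tensor a is defined earlier in the test and is not shown in this hunk); stacking rows this way is exactly the operation the batch path relies on. A tiny standalone illustration with a made-up tensor:

    import torch

    a = torch.zeros(2, 3)
    e = torch.cat([a, a], dim=0)  # rows stacked: shape (4, 3)
    print(e.shape)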
@@ -42,6 +42,8 @@ void PybindFbankOptions(py::module &m) {
   py::class_<Fbank>(m, "Fbank")
       .def(py::init<const FbankOptions &>(), py::arg("opts"))
       .def("dim", &Fbank::Dim)
+      .def("options", &Fbank::GetOptions,
+           py::return_value_policy::reference_internal)
       .def("compute_features", &Fbank::ComputeFeatures, py::arg("wave"),
            py::arg("vtln_warp"));
 }
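Together with the two GetOptions accessors added above, this exposes the options held by an Fbank instance to Python. A small usage sketch (assuming the _kaldifeat extension module has been built and is importable, e.g. from build/lib as in the test below):

    import _kaldifeat

    fbank_opts = _kaldifeat.FbankOptions()
    fbank = _kaldifeat.Fbank(fbank_opts)
    opts = fbank.options()           # the FbankOptions held by this Fbank
    print(opts.frame_opts.dither)    # frame_opts fields are reachable through it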
@@ -34,6 +34,11 @@ void PybindFrameExtractionOptions(py::module &m) {
       .def("__str__", [](const FrameExtractionOptions &self) -> std::string {
         return self.ToString();
       });
+
+  m.def("num_frames", &NumFrames, py::arg("num_samples"), py::arg("opts"),
+        py::arg("flush") = true);
+
+  m.def("get_strided", &GetStrided, py::arg("ave"), py::arg("opts"));
 }
 
 }  // namespace kaldifeat
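These two bindings are what the batch test below builds on: num_frames reports how many rows each waveform contributes, and get_strided produces those rows. A short sketch of calling them (a placeholder random tensor stands in for a real waveform; the extension module is assumed to be on sys.path):

    import torch
    import _kaldifeat

    frame_opts = _kaldifeat.FbankOptions().frame_opts
    wave = torch.rand(16000)                             # placeholder: 1 s of audio at 16 kHz
    n = _kaldifeat.num_frames(wave.numel(), frame_opts)  # flush defaults to True
    rows = _kaldifeat.get_strided(wave, frame_opts)
    assert rows.shape[0] == n                            # one row per frame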
@@ -6,6 +6,8 @@ from pathlib import Path
 cur_dir = Path(__file__).resolve().parent
 kaldi_feat_dir = cur_dir.parent.parent.parent
 
+from typing import List
+
 import sys
 sys.path.insert(0, f'{kaldi_feat_dir}/build/lib')
 
@@ -132,11 +134,42 @@ def test_use_energy_htk_compat_false():
     assert torch.allclose(ans[:3, :].cpu(), expected, rtol=1e-3)
 
 
+def test_compute_batch():
+    data1 = read_wave()
+    data2 = read_wave()
+
+    data = [data1, data2]
+    fbank_opts = _kaldifeat.FbankOptions()
+    fbank_opts.frame_opts.dither = 0
+    fbank = _kaldifeat.Fbank(fbank_opts)
+
+    def impl(waves: List[torch.Tensor]) -> List[torch.Tensor]:
+        num_frames = [
+            _kaldifeat.num_frames(w.numel(), fbank_opts.frame_opts)
+            for w in waves
+        ]
+
+        strided = [
+            _kaldifeat.get_strided(w, fbank_opts.frame_opts) for w in waves
+        ]
+        strided = torch.cat(strided, dim=0)
+
+        features = _kaldifeat.compute(strided, fbank)
+        feature1 = features[:num_frames[0]]
+        feature2 = features[num_frames[0]:]
+        return [feature1, feature2]
+
+    feature1, feature2 = impl([data1, data2])
+    assert torch.allclose(feature1, feature2)
+
+
 def main():
     test_and_benchmark_default_parameters()
     test_use_energy_htk_compat_true()
     test_use_energy_htk_compat_false()
+
+    test_compute_batch()
 
 
 if __name__ == '__main__':
     main()
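The test pushes two copies of the same waveform through the batched path and checks that the two halves of the concatenated output agree. A hypothetical generalization of impl to any number of waveforms of different lengths, splitting the output back per utterance with torch.split (same assumed module layout as the test above):

    from typing import List
    import torch
    import _kaldifeat

    def compute_batch(waves: List[torch.Tensor],
                      fbank, frame_opts) -> List[torch.Tensor]:
        num_frames = [_kaldifeat.num_frames(w.numel(), frame_opts) for w in waves]
        strided = torch.cat(
            [_kaldifeat.get_strided(w, frame_opts) for w in waves], dim=0)
        features = _kaldifeat.compute(strided, fbank)
        # give each utterance back its own rows
        return list(torch.split(features, num_frames, dim=0))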