From e59d05a45a49c5ead288cd346f7610f4ed2003af Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Sat, 2 Apr 2022 20:03:42 +0800
Subject: [PATCH] Add OnlineFbank python APIs.

---
 kaldifeat/csrc/feature-window.cc             |   1 +
 kaldifeat/csrc/online-feature-itf.h          |   6 +-
 kaldifeat/csrc/online-feature.cc             |   8 +-
 kaldifeat/python/csrc/CMakeLists.txt         |   1 +
 kaldifeat/python/csrc/kaldifeat.cc           |   2 +
 kaldifeat/python/csrc/online-feature.cc      |  36 +++++
 kaldifeat/python/csrc/online-feature.h       |  16 ++
 kaldifeat/python/kaldifeat/__init__.py       |   2 +-
 kaldifeat/python/kaldifeat/fbank.py          |  11 ++
 kaldifeat/python/kaldifeat/online_feature.py |  95 ++++++++++++
 kaldifeat/python/tests/test_fbank.py         | 153 +++++++++++++++----
 11 files changed, 296 insertions(+), 35 deletions(-)
 create mode 100644 kaldifeat/python/csrc/online-feature.cc
 create mode 100644 kaldifeat/python/csrc/online-feature.h
 create mode 100644 kaldifeat/python/kaldifeat/online_feature.py

diff --git a/kaldifeat/csrc/feature-window.cc b/kaldifeat/csrc/feature-window.cc
index 5d25720..6880f7e 100644
--- a/kaldifeat/csrc/feature-window.cc
+++ b/kaldifeat/csrc/feature-window.cc
@@ -240,6 +240,7 @@ torch::Tensor ExtractWindow(int64_t sample_offset, const torch::Tensor &wave,
       p_window[s] = p_wave[s_in_wave];
     }
 
+    return window;
   }
 }
diff --git a/kaldifeat/csrc/online-feature-itf.h b/kaldifeat/csrc/online-feature-itf.h
index 4dc779a..60af265 100644
--- a/kaldifeat/csrc/online-feature-itf.h
+++ b/kaldifeat/csrc/online-feature-itf.h
@@ -56,7 +56,8 @@ class OnlineFeatureInterface {
   /// it's more efficient to do things in a batch).
   ///
   /// The returned tensor has shape (frames.size(), Dim()).
-  virtual torch::Tensor GetFrames(const std::vector<int32_t> &frames) {
+  virtual std::vector<torch::Tensor> GetFrames(
+      const std::vector<int32_t> &frames) {
     std::vector<torch::Tensor> features;
     features.reserve(frames.size());
 
@@ -64,8 +65,9 @@ class OnlineFeatureInterface {
       torch::Tensor f = GetFrame(i);
       features.push_back(std::move(f));
     }
+    return features;
 
-    return torch::cat(features, /*dim*/ 0);
+    // return torch::cat(features, /*dim*/ 0);
   }
 
   /// This would be called from the application, when you get more wave data.
diff --git a/kaldifeat/csrc/online-feature.cc b/kaldifeat/csrc/online-feature.cc
index 42855f4..43fc1b1 100644
--- a/kaldifeat/csrc/online-feature.cc
+++ b/kaldifeat/csrc/online-feature.cc
@@ -48,12 +48,7 @@ OnlineGenericBaseFeature<C>::OnlineGenericBaseFeature(
       window_function_(opts.frame_opts, opts.device),
       features_(opts.frame_opts.max_feature_vectors),
       input_finished_(false),
-      waveform_offset_(0) {
-  // Casting to uint32_t, an unsigned type, means that -1 would be treated
-  // as `very large`.
-  KALDIFEAT_ASSERT(static_cast<uint32_t>(opts.frame_opts.max_feature_vectors) >
-                   200);
-}
+      waveform_offset_(0) {}
 
 template <class C>
 void OnlineGenericBaseFeature<C>::AcceptWaveform(
@@ -61,6 +56,7 @@ void OnlineGenericBaseFeature<C>::AcceptWaveform(
   if (original_waveform.numel() == 0) return;  // Nothing to do.
 
   KALDIFEAT_ASSERT(original_waveform.dim() == 1);
+  KALDIFEAT_ASSERT(sampling_rate == computer_.GetFrameOptions().samp_freq);
 
   if (input_finished_)
     KALDIFEAT_ERR << "AcceptWaveform called after InputFinished() was called.";
diff --git a/kaldifeat/python/csrc/CMakeLists.txt b/kaldifeat/python/csrc/CMakeLists.txt
index affb69c..956263f 100644
--- a/kaldifeat/python/csrc/CMakeLists.txt
+++ b/kaldifeat/python/csrc/CMakeLists.txt
@@ -7,6 +7,7 @@ pybind11_add_module(_kaldifeat
   feature-window.cc
   kaldifeat.cc
   mel-computations.cc
+  online-feature.cc
   utils.cc
 )
 target_link_libraries(_kaldifeat PRIVATE kaldifeat_core)
diff --git a/kaldifeat/python/csrc/kaldifeat.cc b/kaldifeat/python/csrc/kaldifeat.cc
index 93e66ac..0a4b8c2 100644
--- a/kaldifeat/python/csrc/kaldifeat.cc
+++ b/kaldifeat/python/csrc/kaldifeat.cc
@@ -11,6 +11,7 @@
 #include "kaldifeat/python/csrc/feature-spectrogram.h"
 #include "kaldifeat/python/csrc/feature-window.h"
 #include "kaldifeat/python/csrc/mel-computations.h"
+#include "kaldifeat/python/csrc/online-feature.h"
 #include "torch/torch.h"
 
 namespace kaldifeat {
@@ -24,6 +25,7 @@ PYBIND11_MODULE(_kaldifeat, m) {
   PybindFeatureMfcc(m);
   PybindFeaturePlp(m);
   PybindFeatureSpectrogram(m);
+  PybindOnlineFeature(m);
 }
 
 }  // namespace kaldifeat
diff --git a/kaldifeat/python/csrc/online-feature.cc b/kaldifeat/python/csrc/online-feature.cc
new file mode 100644
index 0000000..9592a96
--- /dev/null
+++ b/kaldifeat/python/csrc/online-feature.cc
@@ -0,0 +1,36 @@
+// kaldifeat/python/csrc/online-feature.cc
+//
+// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
+
+#include "kaldifeat/python/csrc/online-feature.h"
+
+#include "kaldifeat/csrc/online-feature.h"
+
+namespace kaldifeat {
+
+template <typename C>
+void PybindOnlineFeatureTpl(py::module &m, const std::string &class_name,
+                            const std::string &class_help_doc = "") {
+  using PyClass = OnlineGenericBaseFeature<C>;
+  using Options = typename C::Options;
+  py::class_<PyClass>(m, class_name.c_str(), class_help_doc.c_str())
+      .def(py::init<const Options &>(), py::arg("opts"))
+      .def_property_readonly("dim", &PyClass::Dim)
+      .def_property_readonly("frame_shift_in_seconds",
+                             &PyClass::FrameShiftInSeconds)
+      .def_property_readonly("num_frames_ready", &PyClass::NumFramesReady)
+      .def("is_last_frame", &PyClass::IsLastFrame, py::arg("frame"))
+      .def("get_frame", &PyClass::GetFrame, py::arg("frame"))
+      .def("get_frames", &PyClass::GetFrames, py::arg("frames"))
+      .def("accept_waveform", &PyClass::AcceptWaveform,
+           py::arg("sampling_rate"), py::arg("waveform"))
+      .def("input_finished", &PyClass::InputFinished);
+}
+
+void PybindOnlineFeature(py::module &m) {
+  PybindOnlineFeatureTpl<MfccComputer>(m, "OnlineMfcc");
+  PybindOnlineFeatureTpl<FbankComputer>(m, "OnlineFbank");
+  PybindOnlineFeatureTpl<PlpComputer>(m, "OnlinePlp");
+}
+
+}  // namespace kaldifeat
diff --git a/kaldifeat/python/csrc/online-feature.h b/kaldifeat/python/csrc/online-feature.h
new file mode 100644
index 0000000..c363f42
--- /dev/null
+++ b/kaldifeat/python/csrc/online-feature.h
@@ -0,0 +1,16 @@
+// kaldifeat/python/csrc/online-feature.h
+//
+// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
+
+#ifndef KALDIFEAT_PYTHON_CSRC_ONLINE_FEATURE_H_
+#define KALDIFEAT_PYTHON_CSRC_ONLINE_FEATURE_H_
+
+#include "kaldifeat/python/csrc/kaldifeat.h"
+
+namespace kaldifeat {
+
+void PybindOnlineFeature(py::module &m);
+
+}  // namespace kaldifeat
+
+#endif  // KALDIFEAT_PYTHON_CSRC_ONLINE_FEATURE_H_
diff --git a/kaldifeat/python/kaldifeat/__init__.py b/kaldifeat/python/kaldifeat/__init__.py
index 6b2f088..57004b7 100644
--- a/kaldifeat/python/kaldifeat/__init__.py
+++ b/kaldifeat/python/kaldifeat/__init__.py
@@ -8,7 +8,7 @@ from _kaldifeat import (
     SpectrogramOptions,
 )
 
-from .fbank import Fbank
+from .fbank import Fbank, OnlineFbank
 from .mfcc import Mfcc
 from .plp import Plp
 from .spectrogram import Spectrogram
diff --git a/kaldifeat/python/kaldifeat/fbank.py b/kaldifeat/python/kaldifeat/fbank.py
index 8f73911..275d1cf 100644
--- a/kaldifeat/python/kaldifeat/fbank.py
+++ b/kaldifeat/python/kaldifeat/fbank.py
@@ -4,9 +4,20 @@ import _kaldifeat
 
 from .offline_feature import OfflineFeature
+from .online_feature import OnlineFeature
 
 
 class Fbank(OfflineFeature):
     def __init__(self, opts: _kaldifeat.FbankOptions):
         super().__init__(opts)
         self.computer = _kaldifeat.Fbank(opts)
+
+
+class OnlineFbank(OnlineFeature):
+    def __init__(self, opts: _kaldifeat.FbankOptions):
+        super().__init__(opts)
+        self.computer = _kaldifeat.OnlineFbank(opts)
+
+    def __setstate__(self, state):
+        self.opts = _kaldifeat.FbankOptions.from_dict(state)
+        self.computer = _kaldifeat.OnlineFbank(self.opts)
diff --git a/kaldifeat/python/kaldifeat/online_feature.py b/kaldifeat/python/kaldifeat/online_feature.py
new file mode 100644
index 0000000..cf687cf
--- /dev/null
+++ b/kaldifeat/python/kaldifeat/online_feature.py
@@ -0,0 +1,95 @@
+# Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
+
+from typing import List
+
+import torch
+
+
+class OnlineFeature(object):
+    """Online feature is a base class of other online feature computers,
+    e.g., OnlineFbank, OnlineMfcc.
+
+    This class has two fields:
+
+        (1) opts. It contains the options for the feature computer.
+        (2) computer. The actual feature computer. It should be
+            instantiated by subclasses.
+
+    Caution:
+      It supports only CPU at present.
+    """
+
+    def __init__(self, opts):
+        assert opts.device.type == "cpu"
+
+        self.opts = opts
+
+        # self.computer is expected to be set by subclasses
+        self.computer = None
+
+    @property
+    def num_frames_ready(self) -> int:
+        """Return the number of ready frames.
+
+        It can be updated by ``accept_waveform``.
+
+        Note:
+          If you set ``opts.frame_opts.max_feature_vectors``, then
+          the valid frame indexes are in the range
+          ``[num_frames_ready - max_feature_vectors, num_frames_ready)``.
+
+          If you leave ``opts.frame_opts.max_feature_vectors`` to its default
+          value, then the range is ``[0, num_frames_ready)``.
+        """
+        return self.computer.num_frames_ready
+
+    def is_last_frame(self, frame: int) -> bool:
+        """Return True if the given frame is the last frame."""
+        return self.computer.is_last_frame(frame)
+
+    def get_frame(self, frame: int) -> torch.Tensor:
+        """Get the frame by its index.
+        Args:
+          frame:
+            The frame index. If ``opts.frame_opts.max_feature_vectors`` is
+            -1, then its valid values are in the range
+            ``[0, num_frames_ready)``. Otherwise, the range is
+            ``[num_frames_ready - max_feature_vectors, num_frames_ready)``.
+        Returns:
+          Return a 2-D tensor with shape ``(1, feature_dim)``.
+        """
+        return self.computer.get_frame(frame)
+
+    def get_frames(self, frames: List[int]) -> List[torch.Tensor]:
+        """Get frames at the given frame indexes.
+        Args:
+          frames:
+            Frames whose indexes are in this list are returned.
+        Returns:
+          Return a list of feature frames at the given indexes.
+        """
+        return self.computer.get_frames(frames)
+
+    def accept_waveform(
+        self, sampling_rate: float, waveform: torch.Tensor
+    ) -> None:
+        """Send audio samples to the extractor.
+        Args:
+          sampling_rate:
+            The sampling rate of the given audio samples.
+            It has to be equal to ``opts.frame_opts.samp_freq``.
+          waveform:
+            A 1-D tensor of shape (num_samples,). Its dtype is torch.float32
+            and it has to be on CPU.
+        """
+        self.computer.accept_waveform(sampling_rate, waveform)
+
+    def input_finished(self) -> None:
+        """Tell the extractor that no more audio samples will be available.
+        After calling this function, you cannot invoke ``accept_waveform``
+        again.
+        """
+        self.computer.input_finished()
+
+    def __getstate__(self):
+        return self.opts.as_dict()
diff --git a/kaldifeat/python/tests/test_fbank.py b/kaldifeat/python/tests/test_fbank.py
index 57092b6..1c06438 100755
--- a/kaldifeat/python/tests/test_fbank.py
+++ b/kaldifeat/python/tests/test_fbank.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 
-# Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang)
+# Copyright 2021-2022 Xiaomi Corporation (authors: Fangjun Kuang)
 
 import pickle
 from pathlib import Path
@@ -13,30 +13,88 @@ import kaldifeat
 cur_dir = Path(__file__).resolve().parent
 
 
+def test_online_fbank(
+    opts: kaldifeat.FbankOptions,
+    wave: torch.Tensor,
+    cpu_features: torch.Tensor,
+):
+    """
+    Args:
+      opts:
+        The options to create the online fbank extractor.
+      wave:
+        The input 1-D waveform.
+      cpu_features:
+        The ground truth features that are computed offline.
+    """
+    online_fbank = kaldifeat.OnlineFbank(opts)
+
+    num_processed_frames = 0
+    i = 0  # current sample index to feed
+    while not online_fbank.is_last_frame(num_processed_frames - 1):
+        while num_processed_frames < online_fbank.num_frames_ready:
+            # There are new frames to be processed
+            frame = online_fbank.get_frame(num_processed_frames)
+            assert torch.allclose(
+                frame.squeeze(0), cpu_features[num_processed_frames]
+            )
+            num_processed_frames += 1
+
+        # Simulate streaming. Send a random number of audio samples
+        # to the extractor
+        num_samples = torch.randint(300, 1000, (1,)).item()
+
+        samples = wave[i : (i + num_samples)]  # noqa
+        i += num_samples
+        if len(samples) == 0:
+            online_fbank.input_finished()
+            continue
+
+        online_fbank.accept_waveform(16000, samples)
+
+    assert num_processed_frames == online_fbank.num_frames_ready
+    assert num_processed_frames == cpu_features.size(0)
+
+
 def test_fbank_default():
     print("=====test_fbank_default=====")
+    filename = cur_dir / "test_data/test.wav"
+    wave = read_wave(filename)
+    gt = read_ark_txt(cur_dir / "test_data/test.txt")
+
+    cpu_features = None
     for device in get_devices():
         print("device", device)
         opts = kaldifeat.FbankOptions()
         opts.device = device
         opts.frame_opts.dither = 0
         fbank = kaldifeat.Fbank(opts)
-        filename = cur_dir / "test_data/test.wav"
-        wave = read_wave(filename)
         features = fbank(wave)
         assert features.device.type == "cpu"
-        gt = read_ark_txt(cur_dir / "test_data/test.txt")
         assert torch.allclose(features, gt, rtol=1e-1)
+        if cpu_features is None:
+            cpu_features = features
 
-        wave = wave.to(device)
-        features = fbank(wave)
+        features = fbank(wave.to(device))
         assert features.device == device
         assert torch.allclose(features.cpu(), gt, rtol=1e-1)
 
+    # Now for online fbank
+    opts = kaldifeat.FbankOptions()
+    opts.frame_opts.dither = 0
+    opts.frame_opts.max_feature_vectors = 100
+
+    test_online_fbank(opts, wave, cpu_features)
+
 
 def test_fbank_htk():
     print("=====test_fbank_htk=====")
+    filename = cur_dir / "test_data/test.wav"
+    wave = read_wave(filename)
+    gt = read_ark_txt(cur_dir / "test_data/test-htk.txt")
+
+    cpu_features = None
     for device in get_devices():
         print("device", device)
         opts = kaldifeat.FbankOptions()
@@ -46,22 +104,32 @@ def test_fbank_htk():
         opts.htk_compat = True
         fbank = kaldifeat.Fbank(opts)
-        filename = cur_dir / "test_data/test.wav"
-        wave = read_wave(filename)
         features = fbank(wave)
         assert features.device.type == "cpu"
-        gt = read_ark_txt(cur_dir / "test_data/test-htk.txt")
         assert torch.allclose(features, gt, rtol=1e-1)
+        if cpu_features is None:
+            cpu_features = features
 
-        wave = wave.to(device)
-        features = fbank(wave)
+        features = fbank(wave.to(device))
         assert features.device == device
         assert torch.allclose(features.cpu(), gt, rtol=1e-1)
 
+    opts = kaldifeat.FbankOptions()
+    opts.frame_opts.dither = 0
+    opts.use_energy = True
+    opts.htk_compat = True
+
+    test_online_fbank(opts, wave, cpu_features)
+
 
 def test_fbank_with_energy():
     print("=====test_fbank_with_energy=====")
+    filename = cur_dir / "test_data/test.wav"
+    wave = read_wave(filename)
+    gt = read_ark_txt(cur_dir / "test_data/test-with-energy.txt")
+
+    cpu_features = None
     for device in get_devices():
         print("device", device)
         opts = kaldifeat.FbankOptions()
@@ -70,22 +138,31 @@ def test_fbank_with_energy():
         opts.use_energy = True
         fbank = kaldifeat.Fbank(opts)
-        filename = cur_dir / "test_data/test.wav"
-        wave = read_wave(filename)
         features = fbank(wave)
-        gt = read_ark_txt(cur_dir / "test_data/test-with-energy.txt")
         assert torch.allclose(features, gt, rtol=1e-1)
         assert features.device.type == "cpu"
+        if cpu_features is None:
+            cpu_features = features
 
-        wave = wave.to(device)
-        features = fbank(wave)
+        features = fbank(wave.to(device))
         assert features.device == device
         assert torch.allclose(features.cpu(), gt, rtol=1e-1)
 
+    opts = kaldifeat.FbankOptions()
+    opts.frame_opts.dither = 0
+    opts.use_energy = True
+
+    test_online_fbank(opts, wave, cpu_features)
+
 
 def test_fbank_40_bins():
     print("=====test_fbank_40_bins=====")
+    filename = cur_dir / "test_data/test.wav"
+    wave = read_wave(filename)
+    gt = read_ark_txt(cur_dir / "test_data/test-40.txt")
+
+    cpu_features = None
     for device in get_devices():
         print("device", device)
         opts = kaldifeat.FbankOptions()
@@ -94,22 +171,31 @@ def test_fbank_40_bins():
         opts.mel_opts.num_bins = 40
         fbank = kaldifeat.Fbank(opts)
-        filename = cur_dir / "test_data/test.wav"
-        wave = read_wave(filename)
         features = fbank(wave)
         assert features.device.type == "cpu"
-        gt = read_ark_txt(cur_dir / "test_data/test-40.txt")
         assert torch.allclose(features, gt, rtol=1e-1)
+        if cpu_features is None:
+            cpu_features = features
 
-        wave = wave.to(device)
-        features = fbank(wave)
+        features = fbank(wave.to(device))
         assert features.device == device
         assert torch.allclose(features.cpu(), gt, rtol=1e-1)
 
+    opts = kaldifeat.FbankOptions()
+    opts.frame_opts.dither = 0
+    opts.mel_opts.num_bins = 40
+
+    test_online_fbank(opts, wave, cpu_features)
+
 
 def test_fbank_40_bins_no_snip_edges():
     print("=====test_fbank_40_bins_no_snip_edges=====")
+    filename = cur_dir / "test_data/test.wav"
+    wave = read_wave(filename)
+    gt = read_ark_txt(cur_dir / "test_data/test-40-no-snip-edges.txt")
+
+    cpu_features = None
     for device in get_devices():
         print("device", device)
         opts = kaldifeat.FbankOptions()
@@ -119,19 +205,24 @@ def test_fbank_40_bins_no_snip_edges():
         opts.frame_opts.snip_edges = False
         fbank = kaldifeat.Fbank(opts)
-        filename = cur_dir / "test_data/test.wav"
-        wave = read_wave(filename)
         features = fbank(wave)
         assert features.device.type == "cpu"
-        gt = read_ark_txt(cur_dir / "test_data/test-40-no-snip-edges.txt")
         assert torch.allclose(features, gt, rtol=1e-1)
+        if cpu_features is None:
+            cpu_features = features
 
-        wave = wave.to(device)
-        features = fbank(wave)
+        features = fbank(wave.to(device))
         assert features.device == device
         assert torch.allclose(features.cpu(), gt, rtol=1e-1)
 
+    opts = kaldifeat.FbankOptions()
+    opts.frame_opts.dither = 0
+    opts.mel_opts.num_bins = 40
+    opts.frame_opts.snip_edges = False
+
+    test_online_fbank(opts, wave, cpu_features)
+
 
 def test_fbank_chunk():
     print("=====test_fbank_chunk=====")
@@ -223,6 +314,16 @@ def test_pickle():
 
     assert str(fbank.opts) == str(fbank2.opts)
 
+    opts = kaldifeat.FbankOptions()
+    opts.use_energy = True
+    opts.use_power = False
+
+    fbank = kaldifeat.OnlineFbank(opts)
+    data = pickle.dumps(fbank)
+    fbank2 = pickle.loads(data)
+
+    assert str(fbank.opts) == str(fbank2.opts)
+
 
 if __name__ == "__main__":
     test_fbank_default()
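
The following is a minimal usage sketch of the OnlineFbank Python API added by this patch, modeled on test_online_fbank above. The random waveform, the 16 kHz sampling rate, and the 0.1-second chunk size are illustrative placeholders; real code would feed samples read from an actual audio stream and set the options to match it.

import torch
import kaldifeat

opts = kaldifeat.FbankOptions()
opts.frame_opts.dither = 0
opts.frame_opts.samp_freq = 16000

online_fbank = kaldifeat.OnlineFbank(opts)

# Illustrative input: 2 seconds of random samples standing in for real audio.
wave = torch.rand(32000).float()

num_processed = 0
chunk_size = 1600  # 0.1 second per call, to simulate streaming
for start in range(0, wave.numel(), chunk_size):
    online_fbank.accept_waveform(16000, wave[start : start + chunk_size])

    # Consume every frame that has become ready so far.
    while num_processed < online_fbank.num_frames_ready:
        frame = online_fbank.get_frame(num_processed)  # shape: (1, feature_dim)
        num_processed += 1

online_fbank.input_finished()  # no more samples; flush the buffered tail

while num_processed < online_fbank.num_frames_ready:
    frame = online_fbank.get_frame(num_processed)
    num_processed += 1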
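
Since get_frames() returns a list of (1, feature_dim) tensors rather than one concatenated tensor (see the GetFrames() change in online-feature-itf.h above), callers that want a single 2-D feature matrix can stack the frames themselves. A sketch, continuing from the example above and assuming opts.frame_opts.max_feature_vectors was left at its default so every index in [0, num_frames_ready) is still cached:

ready = online_fbank.num_frames_ready
frames = online_fbank.get_frames(list(range(ready)))  # list of (1, feature_dim) tensors
features = torch.cat(frames, dim=0)                   # shape: (ready, feature_dim)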