Add OnlineFbank Python APIs.

Fangjun Kuang 2022-04-02 20:03:42 +08:00
parent 039e27dd32
commit e59d05a45a
11 changed files with 296 additions and 35 deletions
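For orientation, here is a minimal usage sketch of the Python API this commit adds, distilled from the new test_online_fbank test further below. The random waveform and the two-chunk split are illustrative stand-ins for real 16 kHz audio; everything else uses names introduced in this commit.

import torch
import kaldifeat

opts = kaldifeat.FbankOptions()
opts.frame_opts.dither = 0

online_fbank = kaldifeat.OnlineFbank(opts)

# Stream samples in as they arrive; the sampling rate must match
# opts.frame_opts.samp_freq (16000 by default).
wave = torch.rand(16000)  # stand-in for one second of 16 kHz audio
online_fbank.accept_waveform(16000, wave[:8000])
online_fbank.accept_waveform(16000, wave[8000:])
online_fbank.input_finished()  # no more samples will follow

# Each ready frame is a 2-D tensor of shape (1, feature_dim).
frames = [
    online_fbank.get_frame(i) for i in range(online_fbank.num_frames_ready)
]
features = torch.cat(frames, dim=0)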

View File

@@ -240,6 +240,7 @@ torch::Tensor ExtractWindow(int64_t sample_offset, const torch::Tensor &wave,
     p_window[s] = p_wave[s_in_wave];
   }
+  return window;
 }
 }

View File

@@ -56,7 +56,8 @@ class OnlineFeatureInterface {
   /// it's more efficient to do things in a batch).
   ///
   /// The returned tensor has shape (frames.size(), Dim()).
-  virtual torch::Tensor GetFrames(const std::vector<int32_t> &frames) {
+  virtual std::vector<torch::Tensor> GetFrames(
+      const std::vector<int32_t> &frames) {
     std::vector<torch::Tensor> features;
     features.reserve(frames.size());
@@ -64,8 +65,9 @@ class OnlineFeatureInterface {
       torch::Tensor f = GetFrame(i);
       features.push_back(std::move(f));
     }
-    return torch::cat(features, /*dim*/ 0);
+    return features;
+    // return torch::cat(features, [>dim<] 0);
   }

   /// This would be called from the application, when you get more wave data.
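Note the return type change above: GetFrames() now yields one tensor per requested frame instead of a single concatenated tensor, and the bound Python get_frames() therefore returns a list. A small sketch of rebuilding the old concatenated layout; it assumes torch is imported and that online_fbank is a kaldifeat.OnlineFbank with frames already ready and max_feature_vectors left at its default:

frame_indexes = list(range(online_fbank.num_frames_ready))
frames = online_fbank.get_frames(frame_indexes)  # list of (1, feature_dim) tensors
features = torch.cat(frames, dim=0)              # (num_frames_ready, feature_dim)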

View File

@@ -48,12 +48,7 @@ OnlineGenericBaseFeature<C>::OnlineGenericBaseFeature(
       window_function_(opts.frame_opts, opts.device),
       features_(opts.frame_opts.max_feature_vectors),
       input_finished_(false),
-      waveform_offset_(0) {
-  // Casting to uint32_t, an unsigned type, means that -1 would be treated
-  // as `very large`.
-  KALDIFEAT_ASSERT(static_cast<uint32_t>(opts.frame_opts.max_feature_vectors) >
-                   200);
-}
+      waveform_offset_(0) {}

 template <class C>
 void OnlineGenericBaseFeature<C>::AcceptWaveform(
@@ -61,6 +56,7 @@ void OnlineGenericBaseFeature<C>::AcceptWaveform(
   if (original_waveform.numel() == 0) return;  // Nothing to do.
   KALDIFEAT_ASSERT(original_waveform.dim() == 1);
+  KALDIFEAT_ASSERT(sampling_rate == computer_.GetFrameOptions().samp_freq);

   if (input_finished_)
     KALDIFEAT_ERR << "AcceptWaveform called after InputFinished() was called.";

View File

@@ -7,6 +7,7 @@ pybind11_add_module(_kaldifeat
   feature-window.cc
   kaldifeat.cc
   mel-computations.cc
+  online-feature.cc
   utils.cc
 )

 target_link_libraries(_kaldifeat PRIVATE kaldifeat_core)

View File

@@ -11,6 +11,7 @@
 #include "kaldifeat/python/csrc/feature-spectrogram.h"
 #include "kaldifeat/python/csrc/feature-window.h"
 #include "kaldifeat/python/csrc/mel-computations.h"
+#include "kaldifeat/python/csrc/online-feature.h"
 #include "torch/torch.h"

 namespace kaldifeat {
@@ -24,6 +25,7 @@ PYBIND11_MODULE(_kaldifeat, m) {
   PybindFeatureMfcc(m);
   PybindFeaturePlp(m);
   PybindFeatureSpectrogram(m);
+  PybindOnlineFeature(m);
 }

 }  // namespace kaldifeat

View File

@@ -0,0 +1,36 @@
// kaldifeat/python/csrc/online-feature.cc
//
// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)

#include "kaldifeat/python/csrc/online-feature.h"

#include "kaldifeat/csrc/online-feature.h"

namespace kaldifeat {

template <typename C>
void PybindOnlineFeatureTpl(py::module &m, const std::string &class_name,
                            const std::string &class_help_doc = "") {
  using PyClass = OnlineGenericBaseFeature<C>;
  using Options = typename C::Options;
  py::class_<PyClass>(m, class_name.c_str(), class_help_doc.c_str())
      .def(py::init<const Options &>(), py::arg("opts"))
      .def_property_readonly("dim", &PyClass::Dim)
      .def_property_readonly("frame_shift_in_seconds",
                             &PyClass::FrameShiftInSeconds)
      .def_property_readonly("num_frames_ready", &PyClass::NumFramesReady)
      .def("is_last_frame", &PyClass::IsLastFrame, py::arg("frame"))
      .def("get_frame", &PyClass::GetFrame, py::arg("frame"))
      .def("get_frames", &PyClass::GetFrames, py::arg("frames"))
      .def("accept_waveform", &PyClass::AcceptWaveform,
           py::arg("sampling_rate"), py::arg("waveform"))
      .def("input_finished", &PyClass::InputFinished);
}

void PybindOnlineFeature(py::module &m) {
  PybindOnlineFeatureTpl<Mfcc>(m, "OnlineMfcc");
  PybindOnlineFeatureTpl<Fbank>(m, "OnlineFbank");
  PybindOnlineFeatureTpl<Plp>(m, "OnlinePlp");
}

}  // namespace kaldifeat

View File

@@ -0,0 +1,16 @@
// kaldifeat/python/csrc/online-feature.h
//
// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)

#ifndef KALDIFEAT_PYTHON_CSRC_ONLINE_FEATURE_H_
#define KALDIFEAT_PYTHON_CSRC_ONLINE_FEATURE_H_

#include "kaldifeat/python/csrc/kaldifeat.h"

namespace kaldifeat {

void PybindOnlineFeature(py::module &m);

}  // namespace kaldifeat

#endif  // KALDIFEAT_PYTHON_CSRC_ONLINE_FEATURE_H_

View File

@@ -8,7 +8,7 @@ from _kaldifeat import (
     SpectrogramOptions,
 )

-from .fbank import Fbank
+from .fbank import Fbank, OnlineFbank
 from .mfcc import Mfcc
 from .plp import Plp
 from .spectrogram import Spectrogram

View File

@@ -4,9 +4,20 @@
 import _kaldifeat

 from .offline_feature import OfflineFeature
+from .online_feature import OnlineFeature


 class Fbank(OfflineFeature):
     def __init__(self, opts: _kaldifeat.FbankOptions):
         super().__init__(opts)
         self.computer = _kaldifeat.Fbank(opts)
+
+
+class OnlineFbank(OnlineFeature):
+    def __init__(self, opts: _kaldifeat.FbankOptions):
+        super().__init__(opts)
+        self.computer = _kaldifeat.OnlineFbank(opts)
+
+    def __setstate__(self, state):
+        self.opts = _kaldifeat.FbankOptions.from_dict(state)
+        self.computer = _kaldifeat.Fbank(self.opts)
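A brief note on the __setstate__ hook above: together with OnlineFeature.__getstate__ (added in online_feature.py below), it makes OnlineFbank picklable via the options dict. A minimal round-trip sketch mirroring the pickle test added at the end of this commit:

import pickle

import kaldifeat

opts = kaldifeat.FbankOptions()
opts.use_energy = True

fbank = kaldifeat.OnlineFbank(opts)
fbank2 = pickle.loads(pickle.dumps(fbank))
assert str(fbank.opts) == str(fbank2.opts)  # only the options are serialized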

View File

@@ -0,0 +1,95 @@
# Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)

from typing import List

import torch


class OnlineFeature(object):
    """OnlineFeature is a base class of online feature computers,
    e.g., OnlineFbank, OnlineMfcc.

    This class has two fields:

        (1) opts. It contains the options for the feature computer.
        (2) computer. The actual feature computer. It should be
            instantiated by subclasses.

    Caution:
      It supports only CPU at present.
    """

    def __init__(self, opts):
        assert opts.device.type == "cpu"

        self.opts = opts

        # self.computer is expected to be set by subclasses
        self.computer = None

    @property
    def num_frames_ready(self) -> int:
        """Return the number of ready frames.

        It can be updated by :meth:`accept_waveform`.

        Note:
          If you set ``opts.frame_opts.max_feature_vectors``, then
          the valid frame indexes are in the range
          ``[num_frames_ready - max_feature_vectors, num_frames_ready)``.

          If you leave ``opts.frame_opts.max_feature_vectors`` at its
          default value, then the range is ``[0, num_frames_ready)``.
        """
        return self.computer.num_frames_ready

    def is_last_frame(self, frame: int) -> bool:
        """Return True if the given frame is the last frame."""
        return self.computer.is_last_frame(frame)

    def get_frame(self, frame: int) -> torch.Tensor:
        """Get a frame by its index.

        Args:
          frame:
            The frame index. If ``opts.frame_opts.max_feature_vectors`` is
            -1, then its valid values are in the range
            ``[0, num_frames_ready)``. Otherwise, the range is
            ``[num_frames_ready - max_feature_vectors, num_frames_ready)``.
        Returns:
          Return a 2-D tensor with shape ``(1, feature_dim)``.
        """
        return self.computer.get_frame(frame)

    def get_frames(self, frames: List[int]) -> List[torch.Tensor]:
        """Get frames at the given frame indexes.

        Args:
          frames:
            Frames whose indexes are in this list are returned.
        Returns:
          Return a list of feature frames at the given indexes.
        """
        return self.computer.get_frames(frames)

    def accept_waveform(
        self, sampling_rate: float, waveform: torch.Tensor
    ) -> None:
        """Send audio samples to the extractor.

        Args:
          sampling_rate:
            The sampling rate of the given audio samples. It has to be equal
            to ``opts.frame_opts.samp_freq``.
          waveform:
            A 1-D tensor of shape (num_samples,). Its dtype is torch.float32
            and it has to be on CPU.
        """
        self.computer.accept_waveform(sampling_rate, waveform)

    def input_finished(self) -> None:
        """Tell the extractor that no more audio samples will be available.

        After calling this function, you cannot invoke ``accept_waveform``
        again.
        """
        self.computer.input_finished()

    def __getstate__(self):
        return self.opts.as_dict()
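To make the num_frames_ready note above concrete, here is a sketch of addressing only the frames that are still kept when opts.frame_opts.max_feature_vectors is set. The value 100 matches the test below; the accept_waveform calls are elided.

import kaldifeat

opts = kaldifeat.FbankOptions()
opts.frame_opts.max_feature_vectors = 100
online_fbank = kaldifeat.OnlineFbank(opts)

# ... feed audio with online_fbank.accept_waveform(...) ...

last = online_fbank.num_frames_ready  # exclusive upper bound
first = max(0, last - 100)            # oldest frame still retained
recent = [online_fbank.get_frame(i) for i in range(first, last)]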

View File

@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
-# Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang)
+# Copyright 2021-2022 Xiaomi Corporation (authors: Fangjun Kuang)

 import pickle
 from pathlib import Path
@@ -13,30 +13,88 @@ import kaldifeat
 cur_dir = Path(__file__).resolve().parent


+def test_online_fbank(
+    opts: kaldifeat.FbankOptions,
+    wave: torch.Tensor,
+    cpu_features: torch.Tensor,
+):
+    """
+    Args:
+      opts:
+        The options to create the online fbank extractor.
+      wave:
+        The input 1-D waveform.
+      cpu_features:
+        The ground truth features that are computed offline.
+    """
+    online_fbank = kaldifeat.OnlineFbank(opts)
+
+    num_processed_frames = 0
+    i = 0  # current sample index to feed
+    while not online_fbank.is_last_frame(num_processed_frames - 1):
+        while num_processed_frames < online_fbank.num_frames_ready:
+            # There are new frames to be processed
+            frame = online_fbank.get_frame(num_processed_frames)
+            assert torch.allclose(
+                frame.squeeze(0), cpu_features[num_processed_frames]
+            )
+            num_processed_frames += 1
+
+        # Simulate streaming. Send a random number of audio samples
+        # to the extractor.
+        num_samples = torch.randint(300, 1000, (1,)).item()
+        samples = wave[i : (i + num_samples)]  # noqa
+        i += num_samples
+        if len(samples) == 0:
+            online_fbank.input_finished()
+            continue
+
+        online_fbank.accept_waveform(16000, samples)
+
+    assert num_processed_frames == online_fbank.num_frames_ready
+    assert num_processed_frames == cpu_features.size(0)


 def test_fbank_default():
     print("=====test_fbank_default=====")
+    filename = cur_dir / "test_data/test.wav"
+    wave = read_wave(filename)
+    gt = read_ark_txt(cur_dir / "test_data/test.txt")
+
+    cpu_features = None
     for device in get_devices():
         print("device", device)
         opts = kaldifeat.FbankOptions()
         opts.device = device
         opts.frame_opts.dither = 0
         fbank = kaldifeat.Fbank(opts)

-        filename = cur_dir / "test_data/test.wav"
-        wave = read_wave(filename)

         features = fbank(wave)
         assert features.device.type == "cpu"

-        gt = read_ark_txt(cur_dir / "test_data/test.txt")
         assert torch.allclose(features, gt, rtol=1e-1)
+        if cpu_features is None:
+            cpu_features = features

-        wave = wave.to(device)
-        features = fbank(wave)
+        features = fbank(wave.to(device))
         assert features.device == device
         assert torch.allclose(features.cpu(), gt, rtol=1e-1)

+    # Now for online fbank
+    opts = kaldifeat.FbankOptions()
+    opts.frame_opts.dither = 0
+    opts.frame_opts.max_feature_vectors = 100
+    test_online_fbank(opts, wave, cpu_features)


 def test_fbank_htk():
     print("=====test_fbank_htk=====")
+    filename = cur_dir / "test_data/test.wav"
+    wave = read_wave(filename)
+    gt = read_ark_txt(cur_dir / "test_data/test-htk.txt")
+
+    cpu_features = None
     for device in get_devices():
         print("device", device)
         opts = kaldifeat.FbankOptions()
@@ -46,22 +104,32 @@ def test_fbank_htk():
         opts.htk_compat = True
         fbank = kaldifeat.Fbank(opts)

-        filename = cur_dir / "test_data/test.wav"
-        wave = read_wave(filename)

         features = fbank(wave)
         assert features.device.type == "cpu"

-        gt = read_ark_txt(cur_dir / "test_data/test-htk.txt")
         assert torch.allclose(features, gt, rtol=1e-1)
+        if cpu_features is None:
+            cpu_features = features

-        wave = wave.to(device)
-        features = fbank(wave)
+        features = fbank(wave.to(device))
         assert features.device == device
         assert torch.allclose(features.cpu(), gt, rtol=1e-1)

+    opts = kaldifeat.FbankOptions()
+    opts.frame_opts.dither = 0
+    opts.use_energy = True
+    opts.htk_compat = True
+    test_online_fbank(opts, wave, cpu_features)


 def test_fbank_with_energy():
     print("=====test_fbank_with_energy=====")
+    filename = cur_dir / "test_data/test.wav"
+    wave = read_wave(filename)
+    gt = read_ark_txt(cur_dir / "test_data/test-with-energy.txt")
+
+    cpu_features = None
     for device in get_devices():
         print("device", device)
         opts = kaldifeat.FbankOptions()
@@ -70,22 +138,31 @@ def test_fbank_with_energy():
         opts.use_energy = True
         fbank = kaldifeat.Fbank(opts)

-        filename = cur_dir / "test_data/test.wav"
-        wave = read_wave(filename)

         features = fbank(wave)

-        gt = read_ark_txt(cur_dir / "test_data/test-with-energy.txt")
         assert torch.allclose(features, gt, rtol=1e-1)
         assert features.device.type == "cpu"
+        if cpu_features is None:
+            cpu_features = features

-        wave = wave.to(device)
-        features = fbank(wave)
+        features = fbank(wave.to(device))
         assert features.device == device
         assert torch.allclose(features.cpu(), gt, rtol=1e-1)

+    opts = kaldifeat.FbankOptions()
+    opts.frame_opts.dither = 0
+    opts.use_energy = True
+    test_online_fbank(opts, wave, cpu_features)


 def test_fbank_40_bins():
     print("=====test_fbank_40_bins=====")
+    filename = cur_dir / "test_data/test.wav"
+    wave = read_wave(filename)
+    gt = read_ark_txt(cur_dir / "test_data/test-40.txt")
+
+    cpu_features = None
     for device in get_devices():
         print("device", device)
         opts = kaldifeat.FbankOptions()
@@ -94,22 +171,31 @@ def test_fbank_40_bins():
         opts.mel_opts.num_bins = 40
         fbank = kaldifeat.Fbank(opts)

-        filename = cur_dir / "test_data/test.wav"
-        wave = read_wave(filename)

         features = fbank(wave)
         assert features.device.type == "cpu"

-        gt = read_ark_txt(cur_dir / "test_data/test-40.txt")
         assert torch.allclose(features, gt, rtol=1e-1)
+        if cpu_features is None:
+            cpu_features = features

-        wave = wave.to(device)
-        features = fbank(wave)
+        features = fbank(wave.to(device))
         assert features.device == device
         assert torch.allclose(features.cpu(), gt, rtol=1e-1)

+    opts = kaldifeat.FbankOptions()
+    opts.frame_opts.dither = 0
+    opts.mel_opts.num_bins = 40
+    test_online_fbank(opts, wave, cpu_features)


 def test_fbank_40_bins_no_snip_edges():
     print("=====test_fbank_40_bins_no_snip_edges=====")
+    filename = cur_dir / "test_data/test.wav"
+    wave = read_wave(filename)
+    gt = read_ark_txt(cur_dir / "test_data/test-40-no-snip-edges.txt")
+
+    cpu_features = None
     for device in get_devices():
         print("device", device)
         opts = kaldifeat.FbankOptions()
@@ -119,19 +205,24 @@ def test_fbank_40_bins_no_snip_edges():
         opts.frame_opts.snip_edges = False
         fbank = kaldifeat.Fbank(opts)

-        filename = cur_dir / "test_data/test.wav"
-        wave = read_wave(filename)

         features = fbank(wave)
         assert features.device.type == "cpu"

-        gt = read_ark_txt(cur_dir / "test_data/test-40-no-snip-edges.txt")
         assert torch.allclose(features, gt, rtol=1e-1)
+        if cpu_features is None:
+            cpu_features = features

-        wave = wave.to(device)
-        features = fbank(wave)
+        features = fbank(wave.to(device))
         assert features.device == device
         assert torch.allclose(features.cpu(), gt, rtol=1e-1)

+    opts = kaldifeat.FbankOptions()
+    opts.frame_opts.dither = 0
+    opts.mel_opts.num_bins = 40
+    opts.frame_opts.snip_edges = False
+    test_online_fbank(opts, wave, cpu_features)


 def test_fbank_chunk():
     print("=====test_fbank_chunk=====")
@@ -223,6 +314,16 @@ def test_pickle():
     assert str(fbank.opts) == str(fbank2.opts)

+    opts = kaldifeat.FbankOptions()
+    opts.use_energy = True
+    opts.use_power = False
+
+    fbank = kaldifeat.OnlineFbank(opts)
+    data = pickle.dumps(fbank)
+
+    fbank2 = pickle.loads(data)
+    assert str(fbank.opts) == str(fbank2.opts)


 if __name__ == "__main__":
     test_fbank_default()