Merge pull request #28 from csukuangfj/streaming-feature-extractor

Start to add streaming feature extractors.
This commit is contained in:
Fangjun Kuang 2022-04-02 20:50:31 +08:00 committed by GitHub
commit b72fc599fd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 1029 additions and 61 deletions

View File

@ -45,7 +45,9 @@ jobs:
- name: Install Python dependencies
run: |
python3 -m pip install --upgrade pip black flake8
python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2 click==8.0.4
# See https://github.com/psf/black/issues/2964
# The version of click should be selected from 8.0.0, 8.0.1, 8.0.2, 8.0.3, and 8.0.4
- name: Run flake8
shell: bash

View File

@ -31,6 +31,17 @@ features = fbank(wave)
</td>
</tr>
<tr>
<td>Streaming FBANK</td>
<td><code>kaldifeat.FbankOptions</code></td>
<td><code>kaldifeat.OnlineFbank</code></td>
<td>
See <a href="./kaldifeat/python/tests/test_fbank.py">
./kaldifeat/python/tests/test_fbank.py
</a>
</td>
</tr>
<tr>
<td>MFCC</td>
<td><code>kaldifeat.MfccOptions</code></td>
@ -45,6 +56,17 @@ features = mfcc(wave)
</td>
</tr>
<tr>
<td>Streaming MFCC</td>
<td><code>kaldifeat.MfccOptions</code></td>
<td><code>kaldifeat.OnlineMfcc</code></td>
<td>
See <a href="./kaldifeat/python/tests/test_mfcc.py">
./kaldifeat/python/tests/test_mfcc.py
</a>
</td>
</tr>
<tr>
<td>PLP</td>
<td><code>kaldifeat.PlpOptions</code></td>
@ -59,6 +81,17 @@ features = plp(wave)
</td>
</tr>
<tr>
<td>Streaming PLP</td>
<td><code>kaldifeat.PlpOptions</code></td>
<td><code>kaldifeat.OnlinePlp</code></td>
<td>
See <a href="./kaldifeat/python/tests/test_plp.py">
./kaldifeat/python/tests/test_plp.py
</a>
</td>
</tr>
<tr>
<td>Spectrogram</td>
<td><code>kaldifeat.SpectrogramOptions</code></td>
@ -88,6 +121,8 @@ The following kaldi-compatible commandline tools are implemented:
(**NOTE**: We will implement other types of features, e.g., Pitch, ivector, etc, soon.)
**HINT**: It also supports streaming feature extractors for FBANK, MFCC, and PLP.
# Usage
Let us first generate a test wave using sox:

View File

@ -9,6 +9,7 @@ set(kaldifeat_srcs
feature-window.cc
matrix-functions.cc
mel-computations.cc
online-feature.cc
)
add_library(kaldifeat_core SHARED ${kaldifeat_srcs})
@ -40,6 +41,7 @@ if(kaldifeat_BUILD_TESTS)
# please sort the source files alphabetically
set(test_srcs
feature-window-test.cc
online-feature-test.cc
)
foreach(source IN LISTS test_srcs)

View File

@ -62,6 +62,10 @@ class OfflineFeatureTpl {
int32_t Dim() const { return computer_.Dim(); }
const Options &GetOptions() const { return computer_.GetOptions(); }
const FrameExtractionOptions &GetFrameOptions() const {
return GetOptions().frame_opts;
}
// Copy constructor.
OfflineFeatureTpl(const OfflineFeatureTpl<F> &) = delete;
OfflineFeatureTpl<F> &operator=(const OfflineFeatureTpl<F> &) = delete;

View File

@ -161,19 +161,20 @@ torch::Tensor Dither(const torch::Tensor &wave, float dither_value) {
#if 1
return wave + rand_gauss * dither_value;
#else
// use in-place version of wave and change its to pointer type
// use in-place version of wave and change it to pointer type
wave_->add_(rand_gauss, dither_value);
#endif
}
torch::Tensor Preemphasize(float preemph_coeff, const torch::Tensor &wave) {
using namespace torch::indexing; // It imports: Slice, None // NOLINT
if (preemph_coeff == 0.0f) return wave;
KALDIFEAT_ASSERT(preemph_coeff >= 0.0f && preemph_coeff <= 1.0f);
torch::Tensor ans = torch::empty_like(wave);
using torch::indexing::None;
using torch::indexing::Slice;
// right = wave[:, 1:]
torch::Tensor right = wave.index({"...", Slice(1, None, None)});
@ -188,4 +189,59 @@ torch::Tensor Preemphasize(float preemph_coeff, const torch::Tensor &wave) {
return ans;
}
// Returns the raw (unprocessed) samples that make up frame `f`.
//
// `wave` is the part of the signal that is still held in memory and
// `sample_offset` is the number of samples that came before it. The result
// is a 1-D tensor with opts.WindowSize() elements. When the frame lies
// completely inside `wave` the result is a slice of `wave`; otherwise a new
// float tensor is filled sample by sample, mirroring the signal around its
// first/last sample for positions that fall outside `wave`.
torch::Tensor ExtractWindow(int64_t sample_offset, const torch::Tensor &wave,
                            int32_t f, const FrameExtractionOptions &opts) {
  KALDIFEAT_ASSERT(sample_offset >= 0 && wave.numel() != 0);

  const int32_t frame_length = opts.WindowSize();
  const int64_t num_samples = sample_offset + wave.numel();
  const int64_t start_sample = FirstSampleOfFrame(f, opts);
  const int64_t end_sample = start_sample + frame_length;

  if (opts.snip_edges) {
    KALDIFEAT_ASSERT(start_sample >= sample_offset &&
                     end_sample <= num_samples);
  } else {
    KALDIFEAT_ASSERT(sample_offset == 0 || start_sample >= sample_offset);
  }

  // Position of the requested frame expressed as indexes into `wave`.
  const int32_t wave_start =
      static_cast<int32_t>(start_sample - sample_offset);
  const int32_t wave_end = wave_start + frame_length;

  const bool fully_inside = wave_start >= 0 && wave_end <= wave.numel();
  if (fully_inside) {
    // Common case, no edge handling needed: wave[wave_start:wave_end]
    return wave.index({torch::indexing::Slice(wave_start, wave_end)});
  }

  // Edge case. It is hit for only a couple of frames per utterance, so a
  // plain element-wise copy is good enough.
  torch::Tensor window = torch::empty({frame_length}, torch::kFloat);
  auto dst = window.accessor<float, 1>();
  auto src = wave.accessor<float, 1>();

  const int32_t wave_dim = wave.numel();
  for (int32_t k = 0; k != frame_length; ++k) {
    int32_t pos = wave_start + k;
    // Mirror out-of-range positions back into [0, wave_dim):
    //   -1 -> 0, -2 -> 1, wave_dim -> wave_dim - 1, wave_dim + 1 -> wave_dim - 2
    // The loop allows repeated reflections for pathological inputs.
    while (pos < 0 || pos >= wave_dim) {
      pos = (pos < 0) ? (-pos - 1) : (2 * wave_dim - 1 - pos);
    }
    dst[k] = src[pos];
  }
  return window;
}
} // namespace kaldifeat

View File

@ -44,7 +44,11 @@ struct FrameExtractionOptions {
bool snip_edges = true;
// bool allow_downsample = false;
// bool allow_upsample = false;
// int32_t max_feature_vectors = -1;
// Used for streaming feature extraction. It indicates the number
// of feature frames to keep in the recycling vector. -1 means to
// keep all feature frames.
int32_t max_feature_vectors = -1;
int32_t WindowShift() const {
return static_cast<int32_t>(samp_freq * 0.001f * frame_shift_ms);
@ -71,7 +75,7 @@ struct FrameExtractionOptions {
KALDIFEAT_PRINT(snip_edges);
// KALDIFEAT_PRINT(allow_downsample);
// KALDIFEAT_PRINT(allow_upsample);
// KALDIFEAT_PRINT(max_feature_vectors);
KALDIFEAT_PRINT(max_feature_vectors);
#undef KALDIFEAT_PRINT
return os.str();
}
@ -100,11 +104,11 @@ class FeatureWindowFunction {
@param [in] flush True if we are asserting that this number of samples
is 'all there is', false if we expecting more data to possibly come in. This
only makes a difference to the answer if opts.snips_edges
== false. For offline feature extraction you always want flush ==
true. In an online-decoding context, once you know (or decide)
that no more data is coming in, you'd call it with flush == true at the end
to flush out any remaining data.
only makes a difference to the answer
if opts.snip_edges == false. For offline feature extraction you always want
flush == true. In an online-decoding context, once you know (or decide) that
no more data is coming in, you'd call it with flush == true at the end to
flush out any remaining data.
*/
int32_t NumFrames(int64_t num_samples, const FrameExtractionOptions &opts,
bool flush = true);
@ -133,6 +137,29 @@ torch::Tensor Dither(const torch::Tensor &wave, float dither_value);
torch::Tensor Preemphasize(float preemph_coeff, const torch::Tensor &wave);
/*
  ExtractWindow() extracts "frame_length" samples from the given waveform.

  Note: This function only extracts "frame_length" samples
  from the input waveform, without any further processing.

  @param [in] sample_offset  If 'wave' is not the entire waveform, but
                   part of it to the left has been discarded, then the
                   number of samples prior to 'wave' that we have
                   already discarded.  Set this to zero if you are
                   processing the entire waveform in one piece.
  @param [in] wave  The waveform. It must be a non-empty 1-D tensor.
  @param [in] f     The frame index to be extracted, with
          0 <= f < NumFrames(sample_offset + wave.numel(), opts, true)
  @param [in] opts  The options class to be used
  @return Return a 1-D tensor containing "frame_length" samples extracted
          from `wave`, without any further processing. Its shape is
          (frame_length,); callers that need a batch dimension have to add
          it themselves, e.g., with unsqueeze(0).
*/
torch::Tensor ExtractWindow(int64_t sample_offset, const torch::Tensor &wave,
                            int32_t f, const FrameExtractionOptions &opts);
} // namespace kaldifeat
#endif // KALDIFEAT_CSRC_FEATURE_WINDOW_H_

View File

@ -0,0 +1,89 @@
// kaldifeat/csrc/online-feature-itf.h
//
// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
// This file is copied/modified from kaldi/src/itf/online-feature-itf.h
#ifndef KALDIFEAT_CSRC_ONLINE_FEATURE_ITF_H_
#define KALDIFEAT_CSRC_ONLINE_FEATURE_ITF_H_
#include <utility>
#include <vector>
#include "torch/script.h"
namespace kaldifeat {
class OnlineFeatureInterface {
public:
virtual ~OnlineFeatureInterface() = default;
virtual int32_t Dim() const = 0; /// returns the feature dimension.
//
// Returns frame shift in seconds. Helps to estimate duration from frame
// counts.
virtual float FrameShiftInSeconds() const = 0;
/// Returns the total number of frames, since the start of the utterance, that
/// are now available. In an online-decoding context, this will likely
/// increase with time as more data becomes available.
virtual int32_t NumFramesReady() const = 0;
/// Returns true if this is the last frame. Frame indices are zero-based, so
/// the first frame is zero. IsLastFrame(-1) will return false, unless the
/// file is empty (which is a case that I'm not sure all the code will handle,
/// so be careful). This function may return false for some frame if we
/// haven't yet decided to terminate decoding, but later true if we decide to
/// terminate decoding. This function exists mainly to correctly handle end
/// effects in feature extraction, and is not a mechanism to determine how
/// many frames are in the decodable object (as it used to be, and for
/// backward compatibility, still is, in the Decodable interface).
virtual bool IsLastFrame(int32_t frame) const = 0;
/// Gets the feature vector for this frame. Before calling this for a given
/// frame, it is assumed that you called NumFramesReady() and it returned a
/// number greater than "frame". Otherwise this call will likely crash with
/// an assert failure. This function is not declared const, in case there is
/// some kind of caching going on, but most of the time it shouldn't modify
/// the class.
///
/// The returned tensor has shape (1, Dim()).
virtual torch::Tensor GetFrame(int32_t frame) = 0;
/// This is like GetFrame() but for a collection of frames. There is a
/// default implementation that just gets the frames one by one, but it
/// may be overridden for efficiency by child classes (since sometimes
/// it's more efficient to do things in a batch).
///
/// The returned tensor has shape (frames.size(), Dim()).
virtual std::vector<torch::Tensor> GetFrames(
const std::vector<int32_t> &frames) {
std::vector<torch::Tensor> features;
features.reserve(frames.size());
for (auto i : frames) {
torch::Tensor f = GetFrame(i);
features.push_back(std::move(f));
}
return features;
#if 0
return torch::cat(features, /*dim*/ 0);
#endif
}
/// This would be called from the application, when you get more wave data.
/// Note: the sampling_rate is typically only provided so the code can assert
/// that it matches the sampling rate expected in the options.
virtual void AcceptWaveform(float sampling_rate,
const torch::Tensor &waveform) = 0;
/// InputFinished() tells the class you won't be providing any
/// more waveform. This will help flush out the last few frames
/// of delta or LDA features (it will typically affect the return value
/// of IsLastFrame.
virtual void InputFinished() = 0;
};
} // namespace kaldifeat
#endif // KALDIFEAT_CSRC_ONLINE_FEATURE_ITF_H_

View File

@ -0,0 +1,49 @@
// kaldifeat/csrc/online-feature-test.cc
//
// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
#include "kaldifeat/csrc/online-feature.h"
#include "gtest/gtest.h"
namespace kaldifeat {
// A RecyclingVector built with -1 keeps every pushed item, so all indexes
// in [0, kNumItems) must stay retrievable.
TEST(RecyclingVector, TestUnlimited) {
  constexpr int32_t kNumItems = 100;
  RecyclingVector vec(-1);

  for (int32_t n = 0; n != kNumItems; ++n) {
    vec.PushBack(torch::tensor({n, n + 1, n + 2}));
  }
  ASSERT_EQ(vec.Size(), kNumItems);

  for (int32_t n = 0; n != kNumItems; ++n) {
    const torch::Tensor expected = torch::tensor({n, n + 1, n + 2});
    const torch::Tensor actual = vec.At(n);
    EXPECT_TRUE(actual.equal(expected));
  }
}
// A RecyclingVector limited to kKeep items reports the total number of
// pushes in Size(), but only the most recent kKeep items can be read back;
// reading an older index must terminate the process.
TEST(RecyclingVector, Testlimited) {
  constexpr int32_t kKeep = 3;
  constexpr int32_t kNumItems = 10;
  constexpr int32_t kFirstAlive = kNumItems - kKeep;

  RecyclingVector vec(kKeep);
  for (int32_t n = 0; n != kNumItems; ++n) {
    vec.PushBack(torch::tensor({n, n + 1, n + 2}));
  }
  ASSERT_EQ(vec.Size(), kNumItems);

  // Items that have been recycled.
  for (int32_t n = 0; n < kFirstAlive; ++n) {
    ASSERT_DEATH(vec.At(n), "");
  }

  // Items that are still held.
  for (int32_t n = kFirstAlive; n != kNumItems; ++n) {
    const torch::Tensor expected = torch::tensor({n, n + 1, n + 2});
    const torch::Tensor actual = vec.At(n);
    EXPECT_TRUE(actual.equal(expected));
  }
}
} // namespace kaldifeat

View File

@ -0,0 +1,133 @@
// kaldifeat/csrc/online-feature.cc
//
// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
// This file is copied/modified from kaldi/src/feat/online-feature.cc
#include "kaldifeat/csrc/online-feature.h"
#include "kaldifeat/csrc/feature-window.h"
#include "kaldifeat/csrc/log.h"
namespace kaldifeat {
// `items_to_hold` is the maximum number of items kept at any time. A value
// of 0 is mapped to -1, which PushBack() never matches against the current
// size, i.e., nothing is ever removed.
RecyclingVector::RecyclingVector(int32_t items_to_hold)
    : items_to_hold_(items_to_hold != 0 ? items_to_hold : -1),
      first_available_index_(0) {}
// Returns the item that was pushed as the `index`-th element (0-based,
// counted over all pushes). Reports an error through KALDIFEAT_ERR if that
// item has already been removed.
torch::Tensor RecyclingVector::At(int32_t index) const {
  const bool already_removed = index < first_available_index_;
  if (already_removed) {
    KALDIFEAT_ERR << "Attempted to retrieve feature vector that was "
                     "already removed by the RecyclingVector (index = "
                  << index << "; "
                  << "first_available_index = " << first_available_index_
                  << "; "
                  << "size = " << Size() << ")";
  }

  // Translate the global index into a position inside the deque.
  // std::deque::at() performs the bounds check.
  const int32_t pos = index - first_available_index_;
  return items_.at(pos);
}
// Appends `item`. If the vector already holds items_to_hold_ items, the
// oldest one is dropped first and first_available_index_ advances by one.
void RecyclingVector::PushBack(torch::Tensor item) {
  // When items_to_hold_ is -1 the cast yields the largest size_t value, so
  // the comparison below never succeeds and nothing is dropped.
  const auto capacity = static_cast<size_t>(items_to_hold_);
  const bool is_full = items_.size() == capacity;
  if (is_full) {
    ++first_available_index_;
    items_.pop_front();
  }
  items_.push_back(item);
}
// Number of items pushed so far: the ones already removed
// (first_available_index_) plus the ones still held.
int32_t RecyclingVector::Size() const {
  const auto num_held = static_cast<int32_t>(items_.size());
  return num_held + first_available_index_;
}
// Constructs a streaming extractor from the options of the wrapped feature
// class C. No audio has been seen yet: the input is not finished and no
// sample has been discarded.
template <class C>
OnlineGenericBaseFeature<C>::OnlineGenericBaseFeature(
    const typename C::Options &opts)
    : computer_(opts),
      window_function_(opts.frame_opts, opts.device),
      // max_feature_vectors limits how many computed frames are cached.
      features_(opts.frame_opts.max_feature_vectors),
      input_finished_(false),
      waveform_offset_(0) {}
// Appends a chunk of audio to the samples that have not been consumed yet
// and computes every frame that has become available.
//
// @param sampling_rate      Must equal samp_freq of the frame options.
// @param original_waveform  1-D tensor of samples. An empty tensor is
//                           ignored.
template <class C>
void OnlineGenericBaseFeature<C>::AcceptWaveform(
    float sampling_rate, const torch::Tensor &original_waveform) {
  if (original_waveform.numel() == 0) {
    return;  // Nothing to do.
  }

  KALDIFEAT_ASSERT(original_waveform.dim() == 1);
  KALDIFEAT_ASSERT(sampling_rate == computer_.GetFrameOptions().samp_freq);

  if (input_finished_) {
    KALDIFEAT_ERR << "AcceptWaveform called after InputFinished() was called.";
  }

  const bool has_leftover = waveform_remainder_.numel() != 0;
  if (has_leftover) {
    // Glue the new chunk onto the samples left over from earlier calls.
    waveform_remainder_ =
        torch::cat({waveform_remainder_, original_waveform}, /*dim*/ 0);
  } else {
    waveform_remainder_ = original_waveform;
  }

  ComputeFeatures();
}
// Marks the input as complete and runs ComputeFeatures() once more, this
// time with the flush flag of NumFrames() set, so that frames which only
// exist when no more data is expected get computed.
template <class C>
void OnlineGenericBaseFeature<C>::InputFinished() {
  input_finished_ = true;
  ComputeFeatures();
}
// Computes every feature frame that can be computed from the samples
// received so far, appends them to features_, and then drops the samples at
// the beginning of waveform_remainder_ that no future frame needs, advancing
// waveform_offset_ by the number of dropped samples.
template <class C>
void OnlineGenericBaseFeature<C>::ComputeFeatures() {
  const FrameExtractionOptions &frame_opts = computer_.GetFrameOptions();

  int64_t num_samples_total = waveform_offset_ + waveform_remainder_.numel();
  int32_t num_frames_old = features_.Size();
  int32_t num_frames_new =
      NumFrames(num_samples_total, frame_opts, input_finished_);

  KALDIFEAT_ASSERT(num_frames_new >= num_frames_old);

  // note: this online feature-extraction code does not support VTLN.
  float vtln_warp = 1.0;

  for (int32_t frame = num_frames_old; frame < num_frames_new; ++frame) {
    torch::Tensor window =
        ExtractWindow(waveform_offset_, waveform_remainder_, frame, frame_opts);

    // TODO(fangjun): We can compute all feature frames at once
    // ExtractWindow() returns a 1-D tensor; add the leading frame dimension.
    torch::Tensor this_feature =
        computer_.ComputeFeatures(window.unsqueeze(0), vtln_warp);
    features_.PushBack(this_feature);
  }

  // OK, we will now discard any portion of the signal that will not be
  // necessary to compute frames in the future.
  int64_t first_sample_of_next_frame =
      FirstSampleOfFrame(num_frames_new, frame_opts);

  // Keep the arithmetic in 64 bits: both operands are int64_t.
  int64_t samples_to_discard = first_sample_of_next_frame - waveform_offset_;
  if (samples_to_discard > 0) {
    // discard the leftmost part of the waveform that we no longer need.
    int64_t new_num_samples = waveform_remainder_.numel() - samples_to_discard;
    if (new_num_samples <= 0) {
      // odd, but we'll try to handle it.
      waveform_offset_ += waveform_remainder_.numel();

      // Rebind to a fresh empty tensor instead of calling resize_():
      // waveform_remainder_ may be the very tensor object the caller passed
      // to AcceptWaveform(), and an in-place resize would truncate the
      // caller's tensor as well.
      waveform_remainder_ = torch::empty({0}, torch::kFloat);
    } else {
      using torch::indexing::None;
      using torch::indexing::Slice;

      waveform_remainder_ =
          waveform_remainder_.index({Slice(samples_to_discard, None)});

      waveform_offset_ += samples_to_discard;
    }
  }
}
// instantiate the templates defined here for MFCC, PLP and filterbank classes.
template class OnlineGenericBaseFeature<Mfcc>;
template class OnlineGenericBaseFeature<Plp>;
template class OnlineGenericBaseFeature<Fbank>;
} // namespace kaldifeat

View File

@ -0,0 +1,127 @@
// kaldifeat/csrc/online-feature.h
//
// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
// This file is copied/modified from kaldi/src/feat/online-feature.h
#ifndef KALDIFEAT_CSRC_ONLINE_FEATURE_H_
#define KALDIFEAT_CSRC_ONLINE_FEATURE_H_
#include <deque>
#include "kaldifeat/csrc/feature-fbank.h"
#include "kaldifeat/csrc/feature-mfcc.h"
#include "kaldifeat/csrc/feature-plp.h"
#include "kaldifeat/csrc/feature-window.h"
#include "kaldifeat/csrc/online-feature-itf.h"
namespace kaldifeat {
/// This class serves as a storage for feature vectors with an option to limit
/// the memory usage by removing old elements. The deleted frames indices are
/// "remembered" so that regardless of the MAX_ITEMS setting, the user always
/// provides the indices as if no deletion was being performed.
/// This is useful when processing very long recordings which would otherwise
/// cause the memory to eventually blow up when the features are not being
/// removed.
/// Storage for feature tensors with an optional limit on how many are kept.
/// Items are addressed by the index they had when they were pushed,
/// regardless of how many older items have been removed in the meantime.
class RecyclingVector {
 public:
  /// @param items_to_hold  Maximum number of items to keep. By default (-1)
  ///                       it does not remove any elements.
  explicit RecyclingVector(int32_t items_to_hold = -1);

  ~RecyclingVector() = default;
  RecyclingVector(const RecyclingVector &) = delete;
  RecyclingVector &operator=(const RecyclingVector &) = delete;

  /// Returns the item pushed as the `index`-th element (0-based). It is an
  /// error to ask for an item that has already been removed.
  torch::Tensor At(int32_t index) const;

  /// Appends `item`, removing the oldest held item first if the limit has
  /// been reached.
  void PushBack(torch::Tensor item);

  /// This method returns the size as if no "recycling" had happened,
  /// i.e. equivalent to the number of times the PushBack method has been
  /// called.
  int32_t Size() const;

 private:
  // Items that are currently held, oldest first.
  std::deque<torch::Tensor> items_;
  // Maximum number of items in items_; -1 means no limit.
  int32_t items_to_hold_;
  // Index (as seen by callers of At()) of items_.front(); it equals the
  // number of items removed so far.
  int32_t first_available_index_;
};
/// This is a templated class for online feature extraction;
/// it's templated on a class like MfccComputer or PlpComputer
/// that does the basic feature extraction.
/// This is a templated class for online feature extraction;
/// it's templated on a feature class such as Mfcc, Plp or Fbank
/// that does the basic feature extraction. The class C has to provide
/// C::Options, Dim(), GetFrameOptions() and ComputeFeatures().
template <class C>
class OnlineGenericBaseFeature : public OnlineFeatureInterface {
 public:
  // Constructor from options class
  explicit OnlineGenericBaseFeature(const typename C::Options &opts);

  // Feature dimension, as reported by the wrapped feature class.
  int32_t Dim() const override { return computer_.Dim(); }

  // frame_shift_ms of the frame options, converted to seconds.
  float FrameShiftInSeconds() const override {
    return computer_.GetFrameOptions().frame_shift_ms / 1000.0f;
  }

  // Number of frames computed since the start, including frames that have
  // already been dropped from the cache.
  int32_t NumFramesReady() const override { return features_.Size(); }

  // Note: IsLastFrame() will only ever return true if you have called
  // InputFinished() (and this frame is the last frame).
  bool IsLastFrame(int32_t frame) const override {
    return input_finished_ && frame == NumFramesReady() - 1;
  }

  // Returns the cached feature frame with the given index.
  torch::Tensor GetFrame(int32_t frame) override { return features_.At(frame); }

  // This would be called from the application, when you get
  // more wave data.  Note: the sampling_rate is only provided so
  // the code can assert that it matches the sampling rate
  // expected in the options.
  void AcceptWaveform(float sampling_rate,
                      const torch::Tensor &waveform) override;

  // InputFinished() tells the class you won't be providing any
  // more waveform.  This will help flush out the last frame or two
  // of features, in the case where snip-edges == false; it also
  // affects the return value of IsLastFrame().
  void InputFinished() override;

 private:
  // This function computes any additional feature frames that it is possible to
  // compute from 'waveform_remainder_', which at this point may contain more
  // than just a remainder-sized quantity (because AcceptWaveform() appends to
  // waveform_remainder_ before calling this function).  It adds these feature
  // frames to features_, and shifts off any now-unneeded samples of input from
  // waveform_remainder_ while incrementing waveform_offset_ by the same amount.
  void ComputeFeatures();

  C computer_;  // class that does the MFCC or PLP or filterbank computation

  // NOTE(review): constructed from the options, but not referenced by any
  // member function visible in this class -- confirm whether it is needed.
  FeatureWindowFunction window_function_;

  // features_ is the Mfcc or Plp or Fbank features that we have already
  // computed.
  RecyclingVector features_;

  // True if the user has called "InputFinished()"
  bool input_finished_;

  // waveform_offset_ is the number of samples of waveform that we have
  // already discarded, i.e. that were prior to 'waveform_remainder_'.
  int64_t waveform_offset_;

  // waveform_remainder_ is a short piece of waveform that we may need to keep
  // after extracting all the whole frames we can (whatever length of feature
  // will be required for the next phase of computation).
  // It is a 1-D tensor
  torch::Tensor waveform_remainder_;
};
using OnlineMfcc = OnlineGenericBaseFeature<Mfcc>;
using OnlinePlp = OnlineGenericBaseFeature<Plp>;
using OnlineFbank = OnlineGenericBaseFeature<Fbank>;
} // namespace kaldifeat
#endif // KALDIFEAT_CSRC_ONLINE_FEATURE_H_

View File

@ -7,6 +7,7 @@ pybind11_add_module(_kaldifeat
feature-window.cc
kaldifeat.cc
mel-computations.cc
online-feature.cc
utils.cc
)
target_link_libraries(_kaldifeat PRIVATE kaldifeat_core)

View File

@ -25,6 +25,7 @@ static void PybindFrameExtractionOptions(py::module &m) {
.def_readwrite("round_to_power_of_two", &PyClass::round_to_power_of_two)
.def_readwrite("blackman_coeff", &PyClass::blackman_coeff)
.def_readwrite("snip_edges", &PyClass::snip_edges)
.def_readwrite("max_feature_vectors", &PyClass::max_feature_vectors)
.def("as_dict",
[](const PyClass &self) -> py::dict { return AsDict(self); })
.def_static("from_dict",
@ -35,8 +36,6 @@ static void PybindFrameExtractionOptions(py::module &m) {
.def_readwrite("allow_downsample",
&PyClass::allow_downsample)
.def_readwrite("allow_upsample", &PyClass::allow_upsample)
.def_readwrite("max_feature_vectors",
&PyClass::max_feature_vectors)
#endif
.def("__str__",
[](const PyClass &self) -> std::string { return self.ToString(); })

View File

@ -11,6 +11,7 @@
#include "kaldifeat/python/csrc/feature-spectrogram.h"
#include "kaldifeat/python/csrc/feature-window.h"
#include "kaldifeat/python/csrc/mel-computations.h"
#include "kaldifeat/python/csrc/online-feature.h"
#include "torch/torch.h"
namespace kaldifeat {
@ -24,6 +25,7 @@ PYBIND11_MODULE(_kaldifeat, m) {
PybindFeatureMfcc(m);
PybindFeaturePlp(m);
PybindFeatureSpectrogram(m);
PybindOnlineFeature(m);
}
} // namespace kaldifeat

View File

@ -0,0 +1,37 @@
// kaldifeat/python/csrc/online-feature.cc
//
// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
#include "kaldifeat/python/csrc/online-feature.h"
#include <string>
#include "kaldifeat/csrc/online-feature.h"
namespace kaldifeat {
// Registers the Python class `class_name` in module `m`, wrapping
// OnlineGenericBaseFeature<C>. The C++ getters Dim(), FrameShiftInSeconds()
// and NumFramesReady() are exposed as read-only properties; everything else
// is exposed as a method.
template <typename C>
void PybindOnlineFeatureTpl(py::module &m, const std::string &class_name,
                            const std::string &class_help_doc = "") {
  using PyClass = OnlineGenericBaseFeature<C>;
  using Options = typename C::Options;

  py::class_<PyClass> cls(m, class_name.c_str(), class_help_doc.c_str());

  cls.def(py::init<const Options &>(), py::arg("opts"));

  // Read-only properties.
  cls.def_property_readonly("dim", &PyClass::Dim);
  cls.def_property_readonly("frame_shift_in_seconds",
                            &PyClass::FrameShiftInSeconds);
  cls.def_property_readonly("num_frames_ready", &PyClass::NumFramesReady);

  // Methods.
  cls.def("is_last_frame", &PyClass::IsLastFrame, py::arg("frame"));
  cls.def("get_frame", &PyClass::GetFrame, py::arg("frame"));
  cls.def("get_frames", &PyClass::GetFrames, py::arg("frames"));
  cls.def("accept_waveform", &PyClass::AcceptWaveform,
          py::arg("sampling_rate"), py::arg("waveform"));
  cls.def("input_finished", &PyClass::InputFinished);
}
// Registers the streaming feature extractors OnlineMfcc, OnlineFbank and
// OnlinePlp in the Python module `m`.
void PybindOnlineFeature(py::module &m) {
  PybindOnlineFeatureTpl<Mfcc>(m, "OnlineMfcc");
  PybindOnlineFeatureTpl<Fbank>(m, "OnlineFbank");
  PybindOnlineFeatureTpl<Plp>(m, "OnlinePlp");
}
} // namespace kaldifeat

View File

@ -0,0 +1,16 @@
// kaldifeat/python/csrc/online-feature.h
//
// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
#ifndef KALDIFEAT_PYTHON_CSRC_ONLINE_FEATURE_H_
#define KALDIFEAT_PYTHON_CSRC_ONLINE_FEATURE_H_
#include "kaldifeat/python/csrc/kaldifeat.h"
namespace kaldifeat {
// Adds the Python bindings of the streaming feature extractors to `m`.
void PybindOnlineFeature(py::module &m);
} // namespace kaldifeat
#endif // KALDIFEAT_PYTHON_CSRC_ONLINE_FEATURE_H_

View File

@ -30,6 +30,7 @@ FrameExtractionOptions FrameExtractionOptionsFromDict(py::dict dict) {
FROM_DICT(bool_, round_to_power_of_two);
FROM_DICT(float_, blackman_coeff);
FROM_DICT(bool_, snip_edges);
FROM_DICT(int_, max_feature_vectors);
return opts;
}
@ -47,6 +48,7 @@ py::dict AsDict(const FrameExtractionOptions &opts) {
AS_DICT(round_to_power_of_two);
AS_DICT(blackman_coeff);
AS_DICT(snip_edges);
AS_DICT(max_feature_vectors);
return dict;
}

View File

@ -8,7 +8,7 @@ from _kaldifeat import (
SpectrogramOptions,
)
from .fbank import Fbank
from .mfcc import Mfcc
from .plp import Plp
from .fbank import Fbank, OnlineFbank
from .mfcc import Mfcc, OnlineMfcc
from .plp import OnlinePlp, Plp
from .spectrogram import Spectrogram

View File

@ -4,9 +4,20 @@
import _kaldifeat
from .offline_feature import OfflineFeature
from .online_feature import OnlineFeature
class Fbank(OfflineFeature):
    """Offline FBANK feature extractor backed by ``_kaldifeat.Fbank``."""

    def __init__(self, opts: _kaldifeat.FbankOptions):
        super().__init__(opts)
        self.computer = _kaldifeat.Fbank(opts)
class OnlineFbank(OnlineFeature):
    """Streaming FBANK feature extractor backed by
    ``_kaldifeat.OnlineFbank``.
    """

    def __init__(self, opts: _kaldifeat.FbankOptions):
        super().__init__(opts)
        self.computer = _kaldifeat.OnlineFbank(opts)

    def __setstate__(self, state):
        # Only the options are restored from ``state``; a brand-new
        # extractor is created from them, so audio and frames held by the
        # extractor that was pickled are not carried over.
        self.opts = _kaldifeat.FbankOptions.from_dict(state)
        self.computer = _kaldifeat.OnlineFbank(self.opts)

View File

@ -4,9 +4,20 @@
import _kaldifeat
from .offline_feature import OfflineFeature
from .online_feature import OnlineFeature
class Mfcc(OfflineFeature):
    """Offline MFCC feature extractor backed by ``_kaldifeat.Mfcc``."""

    def __init__(self, opts: _kaldifeat.MfccOptions):
        super().__init__(opts)
        self.computer = _kaldifeat.Mfcc(opts)
class OnlineMfcc(OnlineFeature):
    """Streaming MFCC feature extractor backed by
    ``_kaldifeat.OnlineMfcc``.
    """

    def __init__(self, opts: _kaldifeat.MfccOptions):
        super().__init__(opts)
        self.computer = _kaldifeat.OnlineMfcc(opts)

    def __setstate__(self, state):
        # Only the options are restored from ``state``; a brand-new
        # extractor is created from them, so audio and frames held by the
        # extractor that was pickled are not carried over.
        self.opts = _kaldifeat.MfccOptions.from_dict(state)
        self.computer = _kaldifeat.OnlineMfcc(self.opts)

View File

@ -0,0 +1,95 @@
# Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
from typing import List
import torch
class OnlineFeature(object):
    """Online feature is a base class of streaming feature computers,
    e.g., OnlineFbank, OnlineMfcc.

    This class has two fields:

      (1) opts. It contains the options for the feature computer.
      (2) computer. The actual feature computer. It should be
          instantiated by subclasses.

    Caution:
      It supports only CPU at present.
    """

    def __init__(self, opts):
        assert opts.device.type == "cpu"
        self.opts = opts
        # self.computer is expected to be set by subclasses
        self.computer = None

    @property
    def num_frames_ready(self) -> int:
        """Return the number of ready frames.

        It can be updated by :meth:`accept_waveform`.

        Note:
          If you set ``opts.frame_opts.max_feature_vectors`` to a positive
          value, then only the most recent ``max_feature_vectors`` frames
          are kept and the valid frame indexes are in the range
          ``[max(0, num_frames_ready - max_feature_vectors),
          num_frames_ready)``.
          If you leave ``opts.frame_opts.max_feature_vectors`` to its default
          value, then the range is ``[0, num_frames_ready)``
        """
        return self.computer.num_frames_ready

    def is_last_frame(self, frame: int) -> bool:
        """Return True if the given frame is the last frame.

        It can only return True after :meth:`input_finished` has been
        called.
        """
        return self.computer.is_last_frame(frame)

    def get_frame(self, frame: int) -> torch.Tensor:
        """Get the frame by its index.

        Args:
          frame:
            The frame index. If ``opts.frame_opts.max_feature_vectors`` is
            -1, then its valid values are in the range
            ``[0, num_frames_ready)``. Otherwise, the range is
            ``[max(0, num_frames_ready - max_feature_vectors),
            num_frames_ready)``.
        Returns:
          Return a 2-D tensor with shape ``(1, feature_dim)``
        """
        return self.computer.get_frame(frame)

    def get_frames(self, frames: List[int]) -> List[torch.Tensor]:
        """Get frames at the given frame indexes.

        Args:
          frames:
            Frames whose indexes are in this list are returned.
        Returns:
          Return a list of feature frames at the given indexes, one tensor
          per requested index.
        """
        return self.computer.get_frames(frames)

    def accept_waveform(
        self, sampling_rate: float, waveform: torch.Tensor
    ) -> None:
        """Send audio samples to the extractor.

        Args:
          sampling_rate:
            The sampling rate of the given audio samples. It has to be equal
            to ``opts.frame_opts.samp_freq``.
          waveform:
            A 1-D tensor of shape (num_samples,). Its dtype is torch.float32
            and has to be on CPU.
        """
        self.computer.accept_waveform(sampling_rate, waveform)

    def input_finished(self) -> None:
        """Tell the extractor that no more audio samples will be available.

        After calling this function, you cannot invoke ``accept_waveform``
        again.
        """
        self.computer.input_finished()

    def __getstate__(self):
        # Only the options are pickled; subclasses rebuild ``computer``
        # from them in ``__setstate__``.
        return self.opts.as_dict()

View File

@ -4,9 +4,20 @@
import _kaldifeat
from .offline_feature import OfflineFeature
from .online_feature import OnlineFeature
class Plp(OfflineFeature):
    """Offline PLP feature extractor backed by ``_kaldifeat.Plp``."""

    def __init__(self, opts: _kaldifeat.PlpOptions):
        super().__init__(opts)
        self.computer = _kaldifeat.Plp(opts)
class OnlinePlp(OnlineFeature):
    """Streaming PLP feature extractor backed by
    ``_kaldifeat.OnlinePlp``.
    """

    def __init__(self, opts: _kaldifeat.PlpOptions):
        super().__init__(opts)
        self.computer = _kaldifeat.OnlinePlp(opts)

    def __setstate__(self, state):
        # Only the options are restored from ``state``; a brand-new
        # extractor is created from them, so audio and frames held by the
        # extractor that was pickled are not carried over.
        self.opts = _kaldifeat.PlpOptions.from_dict(state)
        self.computer = _kaldifeat.OnlinePlp(self.opts)

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang)
# Copyright 2021-2022 Xiaomi Corporation (authors: Fangjun Kuang)
import pickle
from pathlib import Path
@ -13,30 +13,88 @@ import kaldifeat
cur_dir = Path(__file__).resolve().parent
def test_online_fbank(
    opts: kaldifeat.FbankOptions,
    wave: torch.Tensor,
    cpu_features: torch.Tensor,
):
    """Feed ``wave`` to an online fbank extractor in randomly sized chunks
    and check that every frame it produces matches the offline result.

    Note:
      This is a helper that is called by the other tests in this file with
      explicit arguments.

    Args:
      opts:
        The options to create the online fbank extractor.
      wave:
        The input 1-D waveform.
      cpu_features:
        The ground truth features that are computed offline
    """
    online_fbank = kaldifeat.OnlineFbank(opts)

    num_processed_frames = 0
    i = 0  # current sample index to feed
    # Keep going until the last frame that we have checked is reported
    # as the last frame of the stream.
    while not online_fbank.is_last_frame(num_processed_frames - 1):
        while num_processed_frames < online_fbank.num_frames_ready:
            # There are new frames to be processed
            frame = online_fbank.get_frame(num_processed_frames)
            assert torch.allclose(
                frame.squeeze(0), cpu_features[num_processed_frames]
            )
            num_processed_frames += 1

        # Simulate streaming. Send a random number of audio samples
        # (between 300 and 999) to the extractor
        num_samples = torch.randint(300, 1000, (1,)).item()

        samples = wave[i : (i + num_samples)]  # noqa
        i += num_samples
        if len(samples) == 0:
            # The whole waveform has been sent. Flush the extractor and go
            # back to the top of the loop to check any newly ready frames.
            online_fbank.input_finished()
            continue

        online_fbank.accept_waveform(16000, samples)

    assert num_processed_frames == online_fbank.num_frames_ready
    assert num_processed_frames == cpu_features.size(0)
def test_fbank_default():
print("=====test_fbank_default=====")
filename = cur_dir / "test_data/test.wav"
wave = read_wave(filename)
gt = read_ark_txt(cur_dir / "test_data/test.txt")
cpu_features = None
for device in get_devices():
print("device", device)
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
fbank = kaldifeat.Fbank(opts)
filename = cur_dir / "test_data/test.wav"
wave = read_wave(filename)
features = fbank(wave)
assert features.device.type == "cpu"
gt = read_ark_txt(cur_dir / "test_data/test.txt")
assert torch.allclose(features, gt, rtol=1e-1)
if cpu_features is None:
cpu_features = features
wave = wave.to(device)
features = fbank(wave)
features = fbank(wave.to(device))
assert features.device == device
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
# Now for online fbank
opts = kaldifeat.FbankOptions()
opts.frame_opts.dither = 0
opts.frame_opts.max_feature_vectors = 100
test_online_fbank(opts, wave, cpu_features)
# Compares kaldifeat.Fbank output with the reference stored in
# test_data/test-htk.txt for each device from get_devices(), then runs
# test_online_fbank with use_energy and htk_compat enabled.
def test_fbank_htk():
print("=====test_fbank_htk=====")
filename = cur_dir / "test_data/test.wav"
wave = read_wave(filename)
gt = read_ark_txt(cur_dir / "test_data/test-htk.txt")
cpu_features = None
for device in get_devices():
print("device", device)
opts = kaldifeat.FbankOptions()
# NOTE(review): a diff hunk header follows; any option setup it elides is
# not visible in this view.
@ -46,22 +104,32 @@ def test_fbank_htk():
opts.htk_compat = True
fbank = kaldifeat.Fbank(opts)
# NOTE(review): the wav file was already read before the loop; this reload
# looks redundant - confirm.
filename = cur_dir / "test_data/test.wav"
wave = read_wave(filename)
features = fbank(wave)
assert features.device.type == "cpu"
# NOTE(review): gt was already loaded before the loop - confirm the reload
# is needed.
gt = read_ark_txt(cur_dir / "test_data/test-htk.txt")
assert torch.allclose(features, gt, rtol=1e-1)
# Keep the first offline result; it is passed to test_online_fbank below.
if cpu_features is None:
cpu_features = features
# NOTE(review): if the next two lines are both live, `wave` is rebound to
# the device copy and the features are computed twice - confirm which form
# is intended.
wave = wave.to(device)
features = fbank(wave)
features = fbank(wave.to(device))
assert features.device == device
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
# Options for the online extractor; no device is set on them here.
opts = kaldifeat.FbankOptions()
opts.frame_opts.dither = 0
opts.use_energy = True
opts.htk_compat = True
test_online_fbank(opts, wave, cpu_features)
# Compares kaldifeat.Fbank output (use_energy enabled) with the reference in
# test_data/test-with-energy.txt for each device from get_devices(), then
# runs test_online_fbank with use_energy enabled.
def test_fbank_with_energy():
print("=====test_fbank_with_energy=====")
filename = cur_dir / "test_data/test.wav"
wave = read_wave(filename)
gt = read_ark_txt(cur_dir / "test_data/test-with-energy.txt")
cpu_features = None
for device in get_devices():
print("device", device)
opts = kaldifeat.FbankOptions()
# NOTE(review): a diff hunk header follows; any option setup it elides is
# not visible in this view.
@ -70,22 +138,31 @@ def test_fbank_with_energy():
opts.use_energy = True
fbank = kaldifeat.Fbank(opts)
# NOTE(review): the wav file was already read before the loop; this reload
# looks redundant - confirm.
filename = cur_dir / "test_data/test.wav"
wave = read_wave(filename)
features = fbank(wave)
# NOTE(review): gt was already loaded before the loop - confirm the reload
# is needed.
gt = read_ark_txt(cur_dir / "test_data/test-with-energy.txt")
assert torch.allclose(features, gt, rtol=1e-1)
assert features.device.type == "cpu"
# Keep the first offline result; it is passed to test_online_fbank below.
if cpu_features is None:
cpu_features = features
# NOTE(review): if the next two lines are both live, `wave` is rebound to
# the device copy and the features are computed twice - confirm which form
# is intended.
wave = wave.to(device)
features = fbank(wave)
features = fbank(wave.to(device))
assert features.device == device
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
# Options for the online extractor; no device is set on them here.
opts = kaldifeat.FbankOptions()
opts.frame_opts.dither = 0
opts.use_energy = True
test_online_fbank(opts, wave, cpu_features)
# Compares kaldifeat.Fbank output (40 mel bins) with the reference in
# test_data/test-40.txt for each device from get_devices(), then runs
# test_online_fbank with the same number of bins.
def test_fbank_40_bins():
print("=====test_fbank_40_bins=====")
filename = cur_dir / "test_data/test.wav"
wave = read_wave(filename)
gt = read_ark_txt(cur_dir / "test_data/test-40.txt")
cpu_features = None
for device in get_devices():
print("device", device)
opts = kaldifeat.FbankOptions()
# NOTE(review): a diff hunk header follows; any option setup it elides is
# not visible in this view.
@ -94,22 +171,31 @@ def test_fbank_40_bins():
opts.mel_opts.num_bins = 40
fbank = kaldifeat.Fbank(opts)
# NOTE(review): the wav file was already read before the loop; this reload
# looks redundant - confirm.
filename = cur_dir / "test_data/test.wav"
wave = read_wave(filename)
features = fbank(wave)
assert features.device.type == "cpu"
# NOTE(review): gt was already loaded before the loop - confirm the reload
# is needed.
gt = read_ark_txt(cur_dir / "test_data/test-40.txt")
assert torch.allclose(features, gt, rtol=1e-1)
# Keep the first offline result; it is passed to test_online_fbank below.
if cpu_features is None:
cpu_features = features
# NOTE(review): if the next two lines are both live, `wave` is rebound to
# the device copy and the features are computed twice - confirm which form
# is intended.
wave = wave.to(device)
features = fbank(wave)
features = fbank(wave.to(device))
assert features.device == device
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
# Options for the online extractor; no device is set on them here.
opts = kaldifeat.FbankOptions()
opts.frame_opts.dither = 0
opts.mel_opts.num_bins = 40
test_online_fbank(opts, wave, cpu_features)
# Compares kaldifeat.Fbank output (40 mel bins, snip_edges disabled) with the
# reference in test_data/test-40-no-snip-edges.txt for each device from
# get_devices(), then runs test_online_fbank with the same settings.
def test_fbank_40_bins_no_snip_edges():
print("=====test_fbank_40_bins_no_snip_edges=====")
filename = cur_dir / "test_data/test.wav"
wave = read_wave(filename)
gt = read_ark_txt(cur_dir / "test_data/test-40-no-snip-edges.txt")
cpu_features = None
for device in get_devices():
print("device", device)
opts = kaldifeat.FbankOptions()
# NOTE(review): a diff hunk header follows; any option setup it elides is
# not visible in this view.
@ -119,19 +205,24 @@ def test_fbank_40_bins_no_snip_edges():
opts.frame_opts.snip_edges = False
fbank = kaldifeat.Fbank(opts)
# NOTE(review): the wav file was already read before the loop; this reload
# looks redundant - confirm.
filename = cur_dir / "test_data/test.wav"
wave = read_wave(filename)
features = fbank(wave)
assert features.device.type == "cpu"
# NOTE(review): gt was already loaded before the loop - confirm the reload
# is needed.
gt = read_ark_txt(cur_dir / "test_data/test-40-no-snip-edges.txt")
assert torch.allclose(features, gt, rtol=1e-1)
# Keep the first offline result; it is passed to test_online_fbank below.
if cpu_features is None:
cpu_features = features
# NOTE(review): if the next two lines are both live, `wave` is rebound to
# the device copy and the features are computed twice - confirm which form
# is intended.
wave = wave.to(device)
features = fbank(wave)
features = fbank(wave.to(device))
assert features.device == device
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
# Options for the online extractor; no device is set on them here.
opts = kaldifeat.FbankOptions()
opts.frame_opts.dither = 0
opts.mel_opts.num_bins = 40
opts.frame_opts.snip_edges = False
test_online_fbank(opts, wave, cpu_features)
def test_fbank_chunk():
print("=====test_fbank_chunk=====")
@ -223,6 +314,16 @@ def test_pickle():
assert str(fbank.opts) == str(fbank2.opts)
opts = kaldifeat.FbankOptions()
opts.use_energy = True
opts.use_power = False
fbank = kaldifeat.OnlineFbank(opts)
data = pickle.dumps(fbank)
fbank2 = pickle.loads(data)
assert str(fbank.opts) == str(fbank2.opts)
if __name__ == "__main__":
test_fbank_default()

View File

@ -13,24 +13,82 @@ import kaldifeat
cur_dir = Path(__file__).resolve().parent
def test_online_mfcc(
    opts: kaldifeat.MfccOptions,
    wave: torch.Tensor,
    cpu_features: torch.Tensor,
):
    """Stream ``wave`` through an online MFCC extractor in random-sized
    pieces and compare each emitted frame with the offline result.

    Args:
      opts:
        Options used to construct ``kaldifeat.OnlineMfcc``.
      wave:
        The 1-D input waveform.
      cpu_features:
        Ground-truth features computed offline.  Row ``k`` is compared
        with streamed frame ``k``.

    NOTE(review): because of the ``test_`` prefix, pytest-style collectors
    may pick this helper up and treat its parameters as fixtures - confirm
    how the suite is run.
    """
    extractor = kaldifeat.OnlineMfcc(opts)
    checked = 0  # number of frames compared so far
    start = 0  # index of the next sample to feed

    while True:
        # On the first pass this queries index -1, as no frame was checked yet
        if extractor.is_last_frame(checked - 1):
            break

        # Compare every frame that has become available
        while checked < extractor.num_frames_ready:
            streamed = extractor.get_frame(checked)
            expected = cpu_features[checked]
            assert torch.allclose(streamed.squeeze(0), expected, atol=1e-3)
            checked += 1

        # Simulate streaming: the chunk length is drawn from [300, 1000)
        chunk_len = torch.randint(300, 1000, (1,)).item()
        stop = start + chunk_len
        chunk = wave[start:stop]
        start = stop

        if len(chunk) == 0:
            # The waveform is exhausted; tell the extractor so
            extractor.input_finished()
        else:
            extractor.accept_waveform(16000, chunk)

    assert checked == extractor.num_frames_ready
    assert checked == cpu_features.size(0)
def test_mfcc_default():
    """Check the default MFCC configuration, offline and online.

    For every device returned by ``get_devices()`` an offline
    ``kaldifeat.Mfcc`` is built with dither disabled and its output is
    compared with the reference in ``test_data/test-mfcc.txt``.  The
    features computed on the CPU are then used as ground truth for
    ``test_online_mfcc``.
    """
    print("=====test_mfcc_default=====")

    # The waveform and the reference do not depend on the device, so they
    # are loaded once here rather than on every loop iteration.
    filename = cur_dir / "test_data/test.wav"
    wave = read_wave(filename)
    gt = read_ark_txt(cur_dir / "test_data/test-mfcc.txt")

    cpu_features = None
    for device in get_devices():
        print("device", device)
        opts = kaldifeat.MfccOptions()
        opts.device = device
        opts.frame_opts.dither = 0
        mfcc = kaldifeat.Mfcc(opts)

        # `wave` itself is not rebound to the device copy, so the online
        # check below always receives the original waveform.
        features = mfcc(wave.to(device))
        if device.type == "cpu":
            cpu_features = features

        assert torch.allclose(features.cpu(), gt, atol=1e-1)

    opts = kaldifeat.MfccOptions()
    opts.frame_opts.dither = 0

    test_online_mfcc(opts, wave, cpu_features)
# Compares kaldifeat.Mfcc output (snip_edges disabled) with the reference in
# test_data/test-mfcc-no-snip-edges.txt for each device from get_devices(),
# then runs test_online_mfcc with snip_edges disabled.
def test_mfcc_no_snip_edges():
print("=====test_mfcc_no_snip_edges=====")
filename = cur_dir / "test_data/test.wav"
wave = read_wave(filename)
gt = read_ark_txt(cur_dir / "test_data/test-mfcc-no-snip-edges.txt")
cpu_features = None
for device in get_devices():
print("device", device)
opts = kaldifeat.MfccOptions()
# NOTE(review): a diff hunk header follows; any option setup it elides is
# not visible in this view.
@ -39,13 +97,19 @@ def test_mfcc_no_snip_edges():
opts.frame_opts.snip_edges = False
mfcc = kaldifeat.Mfcc(opts)
# NOTE(review): the wav file was already read before the loop; if the next
# three lines are live, `wave` is reloaded and moved to `device` and the
# features are computed once more further down - confirm.
filename = cur_dir / "test_data/test.wav"
wave = read_wave(filename).to(device)
features = mfcc(wave)
# NOTE(review): gt was already loaded before the loop - confirm the reload
# is needed.
gt = read_ark_txt(cur_dir / "test_data/test-mfcc-no-snip-edges.txt")
features = mfcc(wave.to(device))
# Features from the CPU device are passed to test_online_mfcc below.
if device.type == "cpu":
cpu_features = features
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
# Options for the online extractor; no device is set on them here.
opts = kaldifeat.MfccOptions()
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
test_online_mfcc(opts, wave, cpu_features)
# Pickle round-trip checks for the MFCC extractors: the string form of
# `opts` must be the same before and after pickle.dumps / pickle.loads.
def test_pickle():
for device in get_devices():
# NOTE(review): a diff hunk header follows; the body of the per-device part
# is not visible in this view.
@ -60,6 +124,16 @@ def test_pickle():
assert str(mfcc.opts) == str(mfcc2.opts)
# Same round-trip check for the online extractor.
opts = kaldifeat.MfccOptions()
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
mfcc = kaldifeat.OnlineMfcc(opts)
data = pickle.dumps(mfcc)
mfcc2 = pickle.loads(data)
assert str(mfcc.opts) == str(mfcc2.opts)
if __name__ == "__main__":
test_mfcc_default()

View File

@ -13,24 +13,82 @@ import kaldifeat
cur_dir = Path(__file__).resolve().parent
def test_online_plp(
    opts: kaldifeat.PlpOptions,
    wave: torch.Tensor,
    cpu_features: torch.Tensor,
):
    """Stream ``wave`` through an online PLP extractor in random-sized
    pieces and compare each emitted frame with the offline result.

    Args:
      opts:
        Options used to construct ``kaldifeat.OnlinePlp``.
      wave:
        The 1-D input waveform.
      cpu_features:
        Ground-truth features computed offline.  Row ``k`` is compared
        with streamed frame ``k``.

    NOTE(review): because of the ``test_`` prefix, pytest-style collectors
    may pick this helper up and treat its parameters as fixtures - confirm
    how the suite is run.
    """
    extractor = kaldifeat.OnlinePlp(opts)
    checked = 0  # number of frames compared so far
    start = 0  # index of the next sample to feed

    while True:
        # On the first pass this queries index -1, as no frame was checked yet
        if extractor.is_last_frame(checked - 1):
            break

        # Compare every frame that has become available
        while checked < extractor.num_frames_ready:
            streamed = extractor.get_frame(checked)
            expected = cpu_features[checked]
            assert torch.allclose(streamed.squeeze(0), expected, atol=1e-3)
            checked += 1

        # Simulate streaming: the chunk length is drawn from [300, 1000)
        chunk_len = torch.randint(300, 1000, (1,)).item()
        stop = start + chunk_len
        chunk = wave[start:stop]
        start = stop

        if len(chunk) == 0:
            # The waveform is exhausted; tell the extractor so
            extractor.input_finished()
        else:
            extractor.accept_waveform(16000, chunk)

    assert checked == extractor.num_frames_ready
    assert checked == cpu_features.size(0)
def test_plp_default():
    """Check the default PLP configuration, offline and online.

    For every device returned by ``get_devices()`` an offline
    ``kaldifeat.Plp`` is built with dither disabled and its output is
    compared with the reference in ``test_data/test-plp.txt``.  The
    features computed on the CPU are then used as ground truth for
    ``test_online_plp``.
    """
    print("=====test_plp_default=====")

    # The waveform and the reference do not depend on the device, so they
    # are loaded once here rather than on every loop iteration.
    filename = cur_dir / "test_data/test.wav"
    wave = read_wave(filename)
    gt = read_ark_txt(cur_dir / "test_data/test-plp.txt")

    cpu_features = None
    for device in get_devices():
        print("device", device)
        opts = kaldifeat.PlpOptions()
        opts.frame_opts.dither = 0
        opts.device = device
        plp = kaldifeat.Plp(opts)

        # `wave` itself is not rebound to the device copy, so the online
        # check below always receives the original waveform.
        features = plp(wave.to(device))
        if device.type == "cpu":
            cpu_features = features

        assert torch.allclose(features.cpu(), gt, rtol=1e-1)

    opts = kaldifeat.PlpOptions()
    opts.frame_opts.dither = 0

    test_online_plp(opts, wave, cpu_features)
# Compares kaldifeat.Plp output (snip_edges disabled) with the reference in
# test_data/test-plp-no-snip-edges.txt for each device from get_devices(),
# then runs test_online_plp with snip_edges disabled.
def test_plp_no_snip_edges():
print("=====test_plp_no_snip_edges=====")
filename = cur_dir / "test_data/test.wav"
wave = read_wave(filename)
gt = read_ark_txt(cur_dir / "test_data/test-plp-no-snip-edges.txt")
cpu_features = None
for device in get_devices():
print("device", device)
opts = kaldifeat.PlpOptions()
# NOTE(review): a diff hunk header follows; any option setup it elides is
# not visible in this view.
@ -39,16 +97,26 @@ def test_plp_no_snip_edges():
opts.frame_opts.snip_edges = False
plp = kaldifeat.Plp(opts)
# NOTE(review): the wav file was already read before the loop; if the next
# three lines are live, `wave` is reloaded and moved to `device` and the
# features are computed once more further down - confirm.
filename = cur_dir / "test_data/test.wav"
wave = read_wave(filename).to(device)
features = plp(wave)
# NOTE(review): gt was already loaded before the loop - confirm the reload
# is needed.
gt = read_ark_txt(cur_dir / "test_data/test-plp-no-snip-edges.txt")
features = plp(wave.to(device))
# Features from the CPU device are passed to test_online_plp below.
if device.type == "cpu":
cpu_features = features
assert torch.allclose(features.cpu(), gt, atol=1e-1)
# Options for the online extractor; no device is set on them here.
opts = kaldifeat.PlpOptions()
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
test_online_plp(opts, wave, cpu_features)
# Compares kaldifeat.Plp output with the reference in
# test_data/test-plp-htk-10-ceps.txt for each device from get_devices(),
# then runs test_online_plp with htk_compat enabled and num_ceps set to 10.
def test_plp_htk_10_ceps():
print("=====test_plp_htk_10_ceps=====")
filename = cur_dir / "test_data/test.wav"
wave = read_wave(filename)
gt = read_ark_txt(cur_dir / "test_data/test-plp-htk-10-ceps.txt")
cpu_features = None
for device in get_devices():
print("device", device)
opts = kaldifeat.PlpOptions()
# NOTE(review): a diff hunk header follows; any option setup it elides is
# not visible in this view.
@ -58,13 +126,19 @@ def test_plp_htk_10_ceps():
opts.frame_opts.dither = 0
plp = kaldifeat.Plp(opts)
# NOTE(review): the wav file was already read before the loop; if the next
# three lines are live, `wave` is reloaded and moved to `device` and the
# features are computed once more further down - confirm.
filename = cur_dir / "test_data/test.wav"
wave = read_wave(filename).to(device)
features = plp(wave)
# NOTE(review): gt was already loaded before the loop - confirm the reload
# is needed.
gt = read_ark_txt(cur_dir / "test_data/test-plp-htk-10-ceps.txt")
features = plp(wave.to(device))
# Features from the CPU device are passed to test_online_plp below.
if device.type == "cpu":
cpu_features = features
assert torch.allclose(features.cpu(), gt, atol=1e-1)
# Options for the online extractor; no device is set on them here.
opts = kaldifeat.PlpOptions()
opts.htk_compat = True
opts.num_ceps = 10
opts.frame_opts.dither = 0
test_online_plp(opts, wave, cpu_features)
# Pickle round-trip checks for the PLP extractors: the string form of
# `opts` must be the same before and after pickle.dumps / pickle.loads.
def test_pickle():
for device in get_devices():
# NOTE(review): a diff hunk header follows; the body of the per-device part
# is not visible in this view.
@ -79,6 +153,16 @@ def test_pickle():
assert str(plp.opts) == str(plp2.opts)
# Same round-trip check for the online extractor.
opts = kaldifeat.PlpOptions()
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
plp = kaldifeat.OnlinePlp(opts)
data = pickle.dumps(plp)
plp2 = pickle.loads(data)
assert str(plp.opts) == str(plp2.opts)
if __name__ == "__main__":
test_plp_default()