diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml
index e38f9d0..462a2db 100644
--- a/.github/workflows/style_check.yml
+++ b/.github/workflows/style_check.yml
@@ -45,7 +45,9 @@ jobs:
- name: Install Python dependencies
run: |
- python3 -m pip install --upgrade pip black flake8
+ python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2 click==8.0.4
+ # See https://github.com/psf/black/issues/2964
+        # The version of click must be one of 8.0.0, 8.0.1, 8.0.2, 8.0.3, or 8.0.4; click >= 8.1 is incompatible with this version of black
- name: Run flake8
shell: bash
diff --git a/README.md b/README.md
index 841af4d..e4effa5 100644
--- a/README.md
+++ b/README.md
@@ -31,6 +31,17 @@ features = fbank(wave)
+
+  <tr>
+    <td>Streaming FBANK</td>
+    <td>kaldifeat.FbankOptions</td>
+    <td>kaldifeat.OnlineFbank</td>
+    <td>
+      See
+      <code>./kaldifeat/python/tests/test_fbank.py</code>
+    </td>
+  </tr>
+
    <td>MFCC</td>
    <td>kaldifeat.MfccOptions</td>
@@ -45,6 +56,17 @@ features = mfcc(wave)
+
+  <tr>
+    <td>Streaming MFCC</td>
+    <td>kaldifeat.MfccOptions</td>
+    <td>kaldifeat.OnlineMfcc</td>
+    <td>
+      See
+      <code>./kaldifeat/python/tests/test_mfcc.py</code>
+    </td>
+  </tr>
+
    <td>PLP</td>
    <td>kaldifeat.PlpOptions</td>
@@ -59,6 +81,17 @@ features = plp(wave)
+
+  <tr>
+    <td>Streaming PLP</td>
+    <td>kaldifeat.PlpOptions</td>
+    <td>kaldifeat.OnlinePlp</td>
+    <td>
+      See
+      <code>./kaldifeat/python/tests/test_plp.py</code>
+    </td>
+  </tr>
+
    <td>Spectrogram</td>
    <td>kaldifeat.SpectrogramOptions</td>
@@ -88,6 +121,8 @@ The following kaldi-compatible commandline tools are implemented:
 (**NOTE**: We will implement other types of features, e.g., Pitch, ivector, etc., soon.)
+**HINT**: Streaming feature extractors are also supported for Fbank, MFCC, and Plp.
+
# Usage
Let us first generate a test wave using sox:
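
To make the new streaming API concrete, here is a minimal usage sketch in the spirit of ./kaldifeat/python/tests/test_fbank.py (the chunk size and the synthetic waveform are illustrative, not part of this patch):

```python
import torch
import kaldifeat

opts = kaldifeat.FbankOptions()
opts.frame_opts.dither = 0

online_fbank = kaldifeat.OnlineFbank(opts)
wave = torch.rand(16000)  # 1 second of synthetic audio at 16 kHz

num_consumed = 0
for start in range(0, wave.numel(), 1000):
    # Feed the waveform chunk by chunk, as it would arrive in a stream.
    online_fbank.accept_waveform(16000, wave[start : start + 1000])
    while num_consumed < online_fbank.num_frames_ready:
        frame = online_fbank.get_frame(num_consumed)  # shape: (1, dim)
        num_consumed += 1

online_fbank.input_finished()  # flush the last frame or two
while num_consumed < online_fbank.num_frames_ready:
    frame = online_fbank.get_frame(num_consumed)
    num_consumed += 1
```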
diff --git a/kaldifeat/csrc/CMakeLists.txt b/kaldifeat/csrc/CMakeLists.txt
index 8dce57c..7e6f943 100644
--- a/kaldifeat/csrc/CMakeLists.txt
+++ b/kaldifeat/csrc/CMakeLists.txt
@@ -9,6 +9,7 @@ set(kaldifeat_srcs
feature-window.cc
matrix-functions.cc
mel-computations.cc
+ online-feature.cc
)
add_library(kaldifeat_core SHARED ${kaldifeat_srcs})
@@ -40,6 +41,7 @@ if(kaldifeat_BUILD_TESTS)
# please sort the source files alphabetically
set(test_srcs
feature-window-test.cc
+ online-feature-test.cc
)
foreach(source IN LISTS test_srcs)
diff --git a/kaldifeat/csrc/feature-common.h b/kaldifeat/csrc/feature-common.h
index 5710c22..24e7ec1 100644
--- a/kaldifeat/csrc/feature-common.h
+++ b/kaldifeat/csrc/feature-common.h
@@ -62,6 +62,10 @@ class OfflineFeatureTpl {
int32_t Dim() const { return computer_.Dim(); }
const Options &GetOptions() const { return computer_.GetOptions(); }
+ const FrameExtractionOptions &GetFrameOptions() const {
+ return GetOptions().frame_opts;
+ }
+
// Copy constructor.
OfflineFeatureTpl(const OfflineFeatureTpl &) = delete;
OfflineFeatureTpl &operator=(const OfflineFeatureTpl &) = delete;
diff --git a/kaldifeat/csrc/feature-window.cc b/kaldifeat/csrc/feature-window.cc
index 4ec5ac2..6880f7e 100644
--- a/kaldifeat/csrc/feature-window.cc
+++ b/kaldifeat/csrc/feature-window.cc
@@ -161,19 +161,20 @@ torch::Tensor Dither(const torch::Tensor &wave, float dither_value) {
#if 1
return wave + rand_gauss * dither_value;
#else
- // use in-place version of wave and change its to pointer type
+ // use in-place version of wave and change it to pointer type
wave_->add_(rand_gauss, dither_value);
#endif
}
torch::Tensor Preemphasize(float preemph_coeff, const torch::Tensor &wave) {
- using namespace torch::indexing; // It imports: Slice, None // NOLINT
if (preemph_coeff == 0.0f) return wave;
KALDIFEAT_ASSERT(preemph_coeff >= 0.0f && preemph_coeff <= 1.0f);
torch::Tensor ans = torch::empty_like(wave);
+ using torch::indexing::None;
+ using torch::indexing::Slice;
// right = wave[:, 1:]
torch::Tensor right = wave.index({"...", Slice(1, None, None)});
@@ -188,4 +189,59 @@ torch::Tensor Preemphasize(float preemph_coeff, const torch::Tensor &wave) {
return ans;
}
+torch::Tensor ExtractWindow(int64_t sample_offset, const torch::Tensor &wave,
+ int32_t f, const FrameExtractionOptions &opts) {
+ KALDIFEAT_ASSERT(sample_offset >= 0 && wave.numel() != 0);
+
+ int32_t frame_length = opts.WindowSize();
+ int64_t num_samples = sample_offset + wave.numel();
+ int64_t start_sample = FirstSampleOfFrame(f, opts);
+ int64_t end_sample = start_sample + frame_length;
+
+ if (opts.snip_edges) {
+ KALDIFEAT_ASSERT(start_sample >= sample_offset &&
+ end_sample <= num_samples);
+ } else {
+ KALDIFEAT_ASSERT(sample_offset == 0 || start_sample >= sample_offset);
+ }
+
+ // wave_start and wave_end are start and end indexes into 'wave', for the
+ // piece of wave that we're trying to extract.
+  int32_t wave_start = static_cast<int32_t>(start_sample - sample_offset);
+ int32_t wave_end = wave_start + frame_length;
+
+ if (wave_start >= 0 && wave_end <= wave.numel()) {
+ // the normal case -- no edge effects to consider.
+ // return wave[wave_start:wave_end]
+ return wave.index({torch::indexing::Slice(wave_start, wave_end)});
+ } else {
+ torch::Tensor window = torch::empty({frame_length}, torch::kFloat);
+    auto p_window = window.accessor<float, 1>();
+    auto p_wave = wave.accessor<float, 1>();
+
+ // Deal with any end effects by reflection, if needed. This code will only
+ // be reached for about two frames per utterance, so we don't concern
+ // ourselves excessively with efficiency.
+ int32_t wave_dim = wave.numel();
+ for (int32_t s = 0; s != frame_length; ++s) {
+ int32_t s_in_wave = s + wave_start;
+ while (s_in_wave < 0 || s_in_wave >= wave_dim) {
+ // reflect around the beginning or end of the wave.
+ // e.g. -1 -> 0, -2 -> 1.
+ // dim -> dim - 1, dim + 1 -> dim - 2.
+ // the code supports repeated reflections, although this
+ // would only be needed in pathological cases.
+ if (s_in_wave < 0) {
+ s_in_wave = -s_in_wave - 1;
+ } else {
+ s_in_wave = 2 * wave_dim - 1 - s_in_wave;
+ }
+ }
+
+ p_window[s] = p_wave[s_in_wave];
+ }
+ return window;
+ }
+}
+
} // namespace kaldifeat
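
For reference, the slicing scheme used by Preemphasize() corresponds to this torch sketch (a paraphrase of the C++; the helper name `preemphasize` is ours, and it assumes Kaldi's convention that sample 0 is pre-emphasized against itself):

```python
import torch

def preemphasize(coeff: float, wave: torch.Tensor) -> torch.Tensor:
    """wave: a 2-D tensor of shape (num_frames, frame_length)."""
    if coeff == 0.0:
        return wave
    ans = torch.empty_like(wave)
    # ans[:, 1:] = wave[:, 1:] - coeff * wave[:, :-1]
    ans[:, 1:] = wave[:, 1:] - coeff * wave[:, :-1]
    # The first sample has no left neighbor, so it uses itself.
    ans[:, 0] = wave[:, 0] - coeff * wave[:, 0]
    return ans
```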
diff --git a/kaldifeat/csrc/feature-window.h b/kaldifeat/csrc/feature-window.h
index 255c274..26d743a 100644
--- a/kaldifeat/csrc/feature-window.h
+++ b/kaldifeat/csrc/feature-window.h
@@ -44,7 +44,11 @@ struct FrameExtractionOptions {
bool snip_edges = true;
// bool allow_downsample = false;
// bool allow_upsample = false;
- // int32_t max_feature_vectors = -1;
+
+ // Used for streaming feature extraction. It indicates the number
+ // of feature frames to keep in the recycling vector. -1 means to
+ // keep all feature frames.
+ int32_t max_feature_vectors = -1;
int32_t WindowShift() const {
    return static_cast<int32_t>(samp_freq * 0.001f * frame_shift_ms);
@@ -71,7 +75,7 @@ struct FrameExtractionOptions {
KALDIFEAT_PRINT(snip_edges);
// KALDIFEAT_PRINT(allow_downsample);
// KALDIFEAT_PRINT(allow_upsample);
- // KALDIFEAT_PRINT(max_feature_vectors);
+ KALDIFEAT_PRINT(max_feature_vectors);
#undef KALDIFEAT_PRINT
return os.str();
}
@@ -100,11 +104,11 @@ class FeatureWindowFunction {
@param [in] flush True if we are asserting that this number of samples
                is 'all there is', false if we are expecting more data to possibly come in. This
- only makes a difference to the answer if opts.snips_edges
- == false. For offline feature extraction you always want flush ==
- true. In an online-decoding context, once you know (or decide)
- that no more data is coming in, you'd call it with flush == true at the end
- to flush out any remaining data.
+                only makes a difference to the answer
+                if opts.snip_edges == false. For offline feature extraction you
+                always want flush == true. In an online-decoding context, once
+                you know (or decide) that no more data is coming in, you'd call
+                it with flush == true at the end to flush out any remaining data.
*/
int32_t NumFrames(int64_t num_samples, const FrameExtractionOptions &opts,
bool flush = true);
@@ -133,6 +137,29 @@ torch::Tensor Dither(const torch::Tensor &wave, float dither_value);
torch::Tensor Preemphasize(float preemph_coeff, const torch::Tensor &wave);
+/*
+ ExtractWindow() extracts "frame_length" samples from the given waveform.
+ Note: This function only extracts "frame_length" samples
+ from the input waveform, without any further processing.
+
+ @param [in] sample_offset If 'wave' is not the entire waveform, but
+ part of it to the left has been discarded, then the
+ number of samples prior to 'wave' that we have
+ already discarded. Set this to zero if you are
+ processing the entire waveform in one piece, or
+ if you get 'no matching function' compilation
+ errors when updating the code.
+ @param [in] wave The waveform
+ @param [in] f The frame index to be extracted, with
+ 0 <= f < NumFrames(sample_offset + wave.numel(), opts, true)
+ @param [in] opts The options class to be used
+   @return Return a 1-D tensor containing "frame_length" samples extracted
+           from `wave`, without any further processing. Its shape is
+           (frame_length,).
+*/
+torch::Tensor ExtractWindow(int64_t sample_offset, const torch::Tensor &wave,
+ int32_t f, const FrameExtractionOptions &opts);
+
} // namespace kaldifeat
#endif // KALDIFEAT_CSRC_FEATURE_WINDOW_H_
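
The edge handling in ExtractWindow() reflects out-of-range sample indices back into the waveform. A small Python sketch of that index mapping (the helper name is ours, for illustration only):

```python
def reflect_index(s_in_wave: int, wave_dim: int) -> int:
    """Map an out-of-range sample index into [0, wave_dim) by reflecting
    around the edges; repeated reflections handle pathological cases."""
    while s_in_wave < 0 or s_in_wave >= wave_dim:
        if s_in_wave < 0:
            s_in_wave = -s_in_wave - 1
        else:
            s_in_wave = 2 * wave_dim - 1 - s_in_wave
    return s_in_wave


assert reflect_index(-1, 10) == 0   # -1 -> 0
assert reflect_index(-2, 10) == 1   # -2 -> 1
assert reflect_index(10, 10) == 9   # dim -> dim - 1
assert reflect_index(11, 10) == 8   # dim + 1 -> dim - 2
```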
diff --git a/kaldifeat/csrc/online-feature-itf.h b/kaldifeat/csrc/online-feature-itf.h
new file mode 100644
index 0000000..835e182
--- /dev/null
+++ b/kaldifeat/csrc/online-feature-itf.h
@@ -0,0 +1,89 @@
+// kaldifeat/csrc/online-feature-itf.h
+//
+// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
+
+// This file is copied/modified from kaldi/src/itf/online-feature-itf.h
+
+#ifndef KALDIFEAT_CSRC_ONLINE_FEATURE_ITF_H_
+#define KALDIFEAT_CSRC_ONLINE_FEATURE_ITF_H_
+
+#include <cstdint>
+#include <vector>
+
+#include "torch/script.h"
+
+namespace kaldifeat {
+
+class OnlineFeatureInterface {
+ public:
+ virtual ~OnlineFeatureInterface() = default;
+
+  /// Returns the feature dimension.
+  virtual int32_t Dim() const = 0;
+
+  /// Returns the frame shift in seconds. Helps to estimate duration from
+  /// frame counts.
+  virtual float FrameShiftInSeconds() const = 0;
+
+ /// Returns the total number of frames, since the start of the utterance, that
+ /// are now available. In an online-decoding context, this will likely
+ /// increase with time as more data becomes available.
+ virtual int32_t NumFramesReady() const = 0;
+
+ /// Returns true if this is the last frame. Frame indices are zero-based, so
+ /// the first frame is zero. IsLastFrame(-1) will return false, unless the
+ /// file is empty (which is a case that I'm not sure all the code will handle,
+ /// so be careful). This function may return false for some frame if we
+ /// haven't yet decided to terminate decoding, but later true if we decide to
+ /// terminate decoding. This function exists mainly to correctly handle end
+ /// effects in feature extraction, and is not a mechanism to determine how
+ /// many frames are in the decodable object (as it used to be, and for
+ /// backward compatibility, still is, in the Decodable interface).
+ virtual bool IsLastFrame(int32_t frame) const = 0;
+
+ /// Gets the feature vector for this frame. Before calling this for a given
+ /// frame, it is assumed that you called NumFramesReady() and it returned a
+ /// number greater than "frame". Otherwise this call will likely crash with
+ /// an assert failure. This function is not declared const, in case there is
+ /// some kind of caching going on, but most of the time it shouldn't modify
+ /// the class.
+ ///
+ /// The returned tensor has shape (1, Dim()).
+ virtual torch::Tensor GetFrame(int32_t frame) = 0;
+
+ /// This is like GetFrame() but for a collection of frames. There is a
+ /// default implementation that just gets the frames one by one, but it
+ /// may be overridden for efficiency by child classes (since sometimes
+ /// it's more efficient to do things in a batch).
+ ///
+ /// The returned tensor has shape (frames.size(), Dim()).
+  virtual std::vector<torch::Tensor> GetFrames(
+      const std::vector<int32_t> &frames) {
+    std::vector<torch::Tensor> features;
+ features.reserve(frames.size());
+
+ for (auto i : frames) {
+ torch::Tensor f = GetFrame(i);
+ features.push_back(std::move(f));
+ }
+ return features;
+#if 0
+ return torch::cat(features, /*dim*/ 0);
+#endif
+ }
+
+ /// This would be called from the application, when you get more wave data.
+ /// Note: the sampling_rate is typically only provided so the code can assert
+ /// that it matches the sampling rate expected in the options.
+ virtual void AcceptWaveform(float sampling_rate,
+ const torch::Tensor &waveform) = 0;
+
+ /// InputFinished() tells the class you won't be providing any
+ /// more waveform. This will help flush out the last few frames
+  /// of delta or LDA features (it will typically affect the return value
+  /// of IsLastFrame()).
+ virtual void InputFinished() = 0;
+};
+
+} // namespace kaldifeat
+
+#endif // KALDIFEAT_CSRC_ONLINE_FEATURE_ITF_H_
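
The contract of this interface, seen through the Python bindings added later in this patch: IsLastFrame() can only turn true after InputFinished(), and GetFrames() is just a batched GetFrame() by default. A sketch (the waveform is synthetic):

```python
import torch
import kaldifeat

opts = kaldifeat.FbankOptions()
opts.frame_opts.dither = 0
extractor = kaldifeat.OnlineFbank(opts)

assert not extractor.is_last_frame(-1)  # nothing has been decided yet

extractor.accept_waveform(16000, torch.rand(16000))
n = extractor.num_frames_ready
assert not extractor.is_last_frame(n - 1)  # more data may still arrive

extractor.input_finished()
n = extractor.num_frames_ready  # may have grown: edge frames get flushed
assert extractor.is_last_frame(n - 1)

# Batched access; the default implementation just loops over get_frame().
frames = extractor.get_frames(list(range(n)))
assert len(frames) == n
```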
diff --git a/kaldifeat/csrc/online-feature-test.cc b/kaldifeat/csrc/online-feature-test.cc
new file mode 100644
index 0000000..786c1c1
--- /dev/null
+++ b/kaldifeat/csrc/online-feature-test.cc
@@ -0,0 +1,49 @@
+// kaldifeat/csrc/online-feature-test.cc
+//
+// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
+
+#include "kaldifeat/csrc/online-feature.h"
+
+#include "gtest/gtest.h"
+
+namespace kaldifeat {
+
+TEST(RecyclingVector, TestUnlimited) {
+ RecyclingVector v(-1);
+ constexpr int32_t N = 100;
+ for (int32_t i = 0; i != N; ++i) {
+ torch::Tensor t = torch::tensor({i, i + 1, i + 2});
+ v.PushBack(t);
+ }
+ ASSERT_EQ(v.Size(), N);
+
+ for (int32_t i = 0; i != N; ++i) {
+ torch::Tensor t = v.At(i);
+ torch::Tensor expected = torch::tensor({i, i + 1, i + 2});
+ EXPECT_TRUE(t.equal(expected));
+ }
+}
+
+TEST(RecyclingVector, TestLimited) {
+ constexpr int32_t K = 3;
+ constexpr int32_t N = 10;
+ RecyclingVector v(K);
+ for (int32_t i = 0; i != N; ++i) {
+ torch::Tensor t = torch::tensor({i, i + 1, i + 2});
+ v.PushBack(t);
+ }
+
+ ASSERT_EQ(v.Size(), N);
+
+ for (int32_t i = 0; i < N - K; ++i) {
+ ASSERT_DEATH(v.At(i), "");
+ }
+
+ for (int32_t i = N - K; i != N; ++i) {
+ torch::Tensor t = v.At(i);
+ torch::Tensor expected = torch::tensor({i, i + 1, i + 2});
+ EXPECT_TRUE(t.equal(expected));
+ }
+}
+
+} // namespace kaldifeat
diff --git a/kaldifeat/csrc/online-feature.cc b/kaldifeat/csrc/online-feature.cc
new file mode 100644
index 0000000..43fc1b1
--- /dev/null
+++ b/kaldifeat/csrc/online-feature.cc
@@ -0,0 +1,133 @@
+// kaldifeat/csrc/online-feature.cc
+//
+// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
+
+// This file is copied/modified from kaldi/src/feat/online-feature.cc
+
+#include "kaldifeat/csrc/online-feature.h"
+
+#include "kaldifeat/csrc/feature-window.h"
+#include "kaldifeat/csrc/log.h"
+
+namespace kaldifeat {
+
+RecyclingVector::RecyclingVector(int32_t items_to_hold)
+ : items_to_hold_(items_to_hold == 0 ? -1 : items_to_hold),
+ first_available_index_(0) {}
+
+torch::Tensor RecyclingVector::At(int32_t index) const {
+ if (index < first_available_index_) {
+ KALDIFEAT_ERR << "Attempted to retrieve feature vector that was "
+ "already removed by the RecyclingVector (index = "
+ << index << "; "
+ << "first_available_index = " << first_available_index_
+ << "; "
+ << "size = " << Size() << ")";
+ }
+ // 'at' does size checking.
+ return items_.at(index - first_available_index_);
+}
+
+void RecyclingVector::PushBack(torch::Tensor item) {
+  // Note: items_to_hold_ == -1 becomes a huge number when cast to size_t,
+  // so the comparison below never triggers and nothing is recycled.
+  if (items_.size() == static_cast<size_t>(items_to_hold_)) {
+ items_.pop_front();
+ ++first_available_index_;
+ }
+ items_.push_back(item);
+}
+
+int32_t RecyclingVector::Size() const {
+  return first_available_index_ + static_cast<int32_t>(items_.size());
+}
+
+template <class C>
+OnlineGenericBaseFeature::OnlineGenericBaseFeature(
+ const typename C::Options &opts)
+ : computer_(opts),
+ window_function_(opts.frame_opts, opts.device),
+ features_(opts.frame_opts.max_feature_vectors),
+ input_finished_(false),
+ waveform_offset_(0) {}
+
+template <class C>
+void OnlineGenericBaseFeature::AcceptWaveform(
+ float sampling_rate, const torch::Tensor &original_waveform) {
+ if (original_waveform.numel() == 0) return; // Nothing to do.
+
+ KALDIFEAT_ASSERT(original_waveform.dim() == 1);
+ KALDIFEAT_ASSERT(sampling_rate == computer_.GetFrameOptions().samp_freq);
+
+ if (input_finished_)
+ KALDIFEAT_ERR << "AcceptWaveform called after InputFinished() was called.";
+
+ if (waveform_remainder_.numel() == 0) {
+ waveform_remainder_ = original_waveform;
+ } else {
+ waveform_remainder_ =
+ torch::cat({waveform_remainder_, original_waveform}, /*dim*/ 0);
+ }
+
+ ComputeFeatures();
+}
+
+template <class C>
+void OnlineGenericBaseFeature::InputFinished() {
+ input_finished_ = true;
+ ComputeFeatures();
+}
+
+template <class C>
+void OnlineGenericBaseFeature::ComputeFeatures() {
+ const FrameExtractionOptions &frame_opts = computer_.GetFrameOptions();
+
+ int64_t num_samples_total = waveform_offset_ + waveform_remainder_.numel();
+ int32_t num_frames_old = features_.Size();
+ int32_t num_frames_new =
+ NumFrames(num_samples_total, frame_opts, input_finished_);
+
+ KALDIFEAT_ASSERT(num_frames_new >= num_frames_old);
+
+ // note: this online feature-extraction code does not support VTLN.
+ float vtln_warp = 1.0;
+
+ for (int32_t frame = num_frames_old; frame < num_frames_new; ++frame) {
+ torch::Tensor window =
+ ExtractWindow(waveform_offset_, waveform_remainder_, frame, frame_opts);
+
+ // TODO(fangjun): We can compute all feature frames at once
+ torch::Tensor this_feature =
+ computer_.ComputeFeatures(window.unsqueeze(0), vtln_warp);
+ features_.PushBack(this_feature);
+ }
+
+ // OK, we will now discard any portion of the signal that will not be
+ // necessary to compute frames in the future.
+ int64_t first_sample_of_next_frame =
+ FirstSampleOfFrame(num_frames_new, frame_opts);
+ int32_t samples_to_discard = first_sample_of_next_frame - waveform_offset_;
+ if (samples_to_discard > 0) {
+ // discard the leftmost part of the waveform that we no longer need.
+ int32_t new_num_samples = waveform_remainder_.numel() - samples_to_discard;
+ if (new_num_samples <= 0) {
+ // odd, but we'll try to handle it.
+ waveform_offset_ += waveform_remainder_.numel();
+ waveform_remainder_.resize_({0});
+ } else {
+ using torch::indexing::None;
+ using torch::indexing::Slice;
+
+ waveform_remainder_ =
+ waveform_remainder_.index({Slice(samples_to_discard, None)});
+
+ waveform_offset_ += samples_to_discard;
+ }
+ }
+}
+
+// instantiate the templates defined here for MFCC, PLP and filterbank classes.
+template class OnlineGenericBaseFeature<MfccComputer>;
+template class OnlineGenericBaseFeature<PlpComputer>;
+template class OnlineGenericBaseFeature<FbankComputer>;
+
+} // namespace kaldifeat
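
The discard logic at the end of ComputeFeatures() is easiest to follow with concrete numbers. Below is a sketch using Kaldi's defaults (25 ms windows, 10 ms shift, 16 kHz, snip_edges == true); the helpers mirror NumFrames() and FirstSampleOfFrame() for this configuration only:

```python
WINDOW_SIZE = 400   # 25 ms at 16 kHz
WINDOW_SHIFT = 160  # 10 ms at 16 kHz

def num_frames(num_samples: int) -> int:
    # NumFrames() for the snip_edges == true case
    if num_samples < WINDOW_SIZE:
        return 0
    return 1 + (num_samples - WINDOW_SIZE) // WINDOW_SHIFT

def first_sample_of_frame(f: int) -> int:
    # FirstSampleOfFrame() for the snip_edges == true case
    return f * WINDOW_SHIFT

waveform_offset = 0  # samples already discarded
remainder = 1000     # samples currently buffered

total = waveform_offset + remainder
frames = num_frames(total)                          # -> 4, i.e. frames 0..3
next_start = first_sample_of_frame(frames)          # -> 640
samples_to_discard = next_start - waveform_offset   # -> 640

# Frame 4 will need samples [640, 1040), so everything before sample 640
# can be dropped; 360 buffered samples remain.
remainder -= samples_to_discard        # -> 360
waveform_offset += samples_to_discard  # -> 640
```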
diff --git a/kaldifeat/csrc/online-feature.h b/kaldifeat/csrc/online-feature.h
new file mode 100644
index 0000000..f234b5c
--- /dev/null
+++ b/kaldifeat/csrc/online-feature.h
@@ -0,0 +1,127 @@
+// kaldifeat/csrc/online-feature.h
+//
+// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
+
+// This file is copied/modified from kaldi/src/feat/online-feature.h
+
+#ifndef KALDIFEAT_CSRC_ONLINE_FEATURE_H_
+#define KALDIFEAT_CSRC_ONLINE_FEATURE_H_
+
+#include <deque>
+
+#include "kaldifeat/csrc/feature-fbank.h"
+#include "kaldifeat/csrc/feature-mfcc.h"
+#include "kaldifeat/csrc/feature-plp.h"
+#include "kaldifeat/csrc/feature-window.h"
+#include "kaldifeat/csrc/online-feature-itf.h"
+
+namespace kaldifeat {
+
+/// This class serves as storage for feature vectors with an option to limit
+/// the memory usage by removing old elements. The indices of deleted frames
+/// are "remembered" so that, regardless of the items_to_hold setting, the
+/// user always provides indices as if no deletion were being performed.
+/// This is useful when processing very long recordings which would otherwise
+/// cause the memory to eventually blow up when the features are not being
+/// removed.
+class RecyclingVector {
+ public:
+ /// By default it does not remove any elements.
+ explicit RecyclingVector(int32_t items_to_hold = -1);
+
+ ~RecyclingVector() = default;
+ RecyclingVector(const RecyclingVector &) = delete;
+ RecyclingVector &operator=(const RecyclingVector &) = delete;
+
+ torch::Tensor At(int32_t index) const;
+
+ void PushBack(torch::Tensor item);
+
+ /// This method returns the size as if no "recycling" had happened,
+ /// i.e. equivalent to the number of times the PushBack method has been
+ /// called.
+ int32_t Size() const;
+
+ private:
+  std::deque<torch::Tensor> items_;
+ int32_t items_to_hold_;
+ int32_t first_available_index_;
+};
+
+/// This is a templated class for online feature extraction;
+/// it's templated on a class like MfccComputer or PlpComputer
+/// that does the basic feature extraction.
+template <class C>
+class OnlineGenericBaseFeature : public OnlineFeatureInterface {
+ public:
+ // Constructor from options class
+ explicit OnlineGenericBaseFeature(const typename C::Options &opts);
+
+ int32_t Dim() const override { return computer_.Dim(); }
+
+ float FrameShiftInSeconds() const override {
+ return computer_.GetFrameOptions().frame_shift_ms / 1000.0f;
+ }
+
+ int32_t NumFramesReady() const override { return features_.Size(); }
+
+ // Note: IsLastFrame() will only ever return true if you have called
+ // InputFinished() (and this frame is the last frame).
+ bool IsLastFrame(int32_t frame) const override {
+ return input_finished_ && frame == NumFramesReady() - 1;
+ }
+
+ torch::Tensor GetFrame(int32_t frame) override { return features_.At(frame); }
+
+ // This would be called from the application, when you get
+ // more wave data. Note: the sampling_rate is only provided so
+ // the code can assert that it matches the sampling rate
+ // expected in the options.
+ void AcceptWaveform(float sampling_rate,
+ const torch::Tensor &waveform) override;
+
+ // InputFinished() tells the class you won't be providing any
+ // more waveform. This will help flush out the last frame or two
+ // of features, in the case where snip-edges == false; it also
+ // affects the return value of IsLastFrame().
+ void InputFinished() override;
+
+ private:
+ // This function computes any additional feature frames that it is possible to
+ // compute from 'waveform_remainder_', which at this point may contain more
+ // than just a remainder-sized quantity (because AcceptWaveform() appends to
+ // waveform_remainder_ before calling this function). It adds these feature
+ // frames to features_, and shifts off any now-unneeded samples of input from
+ // waveform_remainder_ while incrementing waveform_offset_ by the same amount.
+ void ComputeFeatures();
+
+ C computer_; // class that does the MFCC or PLP or filterbank computation
+
+ FeatureWindowFunction window_function_;
+
+ // features_ is the Mfcc or Plp or Fbank features that we have already
+ // computed.
+
+ RecyclingVector features_;
+
+ // True if the user has called "InputFinished()"
+ bool input_finished_;
+
+ // waveform_offset_ is the number of samples of waveform that we have
+ // already discarded, i.e. that were prior to 'waveform_remainder_'.
+ int64_t waveform_offset_;
+
+ // waveform_remainder_ is a short piece of waveform that we may need to keep
+ // after extracting all the whole frames we can (whatever length of feature
+ // will be required for the next phase of computation).
+ // It is a 1-D tensor
+ torch::Tensor waveform_remainder_;
+};
+
+using OnlineMfcc = OnlineGenericBaseFeature<MfccComputer>;
+using OnlinePlp = OnlineGenericBaseFeature<PlpComputer>;
+using OnlineFbank = OnlineGenericBaseFeature<FbankComputer>;
+
+} // namespace kaldifeat
+
+#endif // KALDIFEAT_CSRC_ONLINE_FEATURE_H_
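
What max_feature_vectors buys you in practice: memory stays bounded, frame indices keep counting from the start of the utterance, and only the trailing window of frames stays addressable. A sketch (reading an already recycled frame aborts inside RecyclingVector::At(), which is what the C++ test checks with ASSERT_DEATH):

```python
import torch
import kaldifeat

opts = kaldifeat.FbankOptions()
opts.frame_opts.dither = 0
opts.frame_opts.max_feature_vectors = 100  # keep only the last 100 frames

extractor = kaldifeat.OnlineFbank(opts)
extractor.accept_waveform(16000, torch.rand(10 * 16000))  # ~10 s of audio
extractor.input_finished()

n = extractor.num_frames_ready  # counts every frame ever produced (~998)
# Only indices in [n - 100, n) are still valid; frame 0 was recycled
# long ago, and requesting it would abort the process.
last = extractor.get_frame(n - 1)
oldest = extractor.get_frame(n - 100)
```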
diff --git a/kaldifeat/python/csrc/CMakeLists.txt b/kaldifeat/python/csrc/CMakeLists.txt
index affb69c..956263f 100644
--- a/kaldifeat/python/csrc/CMakeLists.txt
+++ b/kaldifeat/python/csrc/CMakeLists.txt
@@ -7,6 +7,7 @@ pybind11_add_module(_kaldifeat
feature-window.cc
kaldifeat.cc
mel-computations.cc
+ online-feature.cc
utils.cc
)
target_link_libraries(_kaldifeat PRIVATE kaldifeat_core)
diff --git a/kaldifeat/python/csrc/feature-window.cc b/kaldifeat/python/csrc/feature-window.cc
index 0fd2ea8..5abaf36 100644
--- a/kaldifeat/python/csrc/feature-window.cc
+++ b/kaldifeat/python/csrc/feature-window.cc
@@ -25,6 +25,7 @@ static void PybindFrameExtractionOptions(py::module &m) {
.def_readwrite("round_to_power_of_two", &PyClass::round_to_power_of_two)
.def_readwrite("blackman_coeff", &PyClass::blackman_coeff)
.def_readwrite("snip_edges", &PyClass::snip_edges)
+ .def_readwrite("max_feature_vectors", &PyClass::max_feature_vectors)
.def("as_dict",
[](const PyClass &self) -> py::dict { return AsDict(self); })
.def_static("from_dict",
@@ -35,8 +36,6 @@ static void PybindFrameExtractionOptions(py::module &m) {
.def_readwrite("allow_downsample",
&PyClass::allow_downsample)
.def_readwrite("allow_upsample", &PyClass::allow_upsample)
- .def_readwrite("max_feature_vectors",
- &PyClass::max_feature_vectors)
#endif
.def("__str__",
[](const PyClass &self) -> std::string { return self.ToString(); })
diff --git a/kaldifeat/python/csrc/kaldifeat.cc b/kaldifeat/python/csrc/kaldifeat.cc
index 93e66ac..0a4b8c2 100644
--- a/kaldifeat/python/csrc/kaldifeat.cc
+++ b/kaldifeat/python/csrc/kaldifeat.cc
@@ -11,6 +11,7 @@
#include "kaldifeat/python/csrc/feature-spectrogram.h"
#include "kaldifeat/python/csrc/feature-window.h"
#include "kaldifeat/python/csrc/mel-computations.h"
+#include "kaldifeat/python/csrc/online-feature.h"
#include "torch/torch.h"
namespace kaldifeat {
@@ -24,6 +25,7 @@ PYBIND11_MODULE(_kaldifeat, m) {
PybindFeatureMfcc(m);
PybindFeaturePlp(m);
PybindFeatureSpectrogram(m);
+ PybindOnlineFeature(m);
}
} // namespace kaldifeat
diff --git a/kaldifeat/python/csrc/online-feature.cc b/kaldifeat/python/csrc/online-feature.cc
new file mode 100644
index 0000000..13e4a4f
--- /dev/null
+++ b/kaldifeat/python/csrc/online-feature.cc
@@ -0,0 +1,37 @@
+// kaldifeat/python/csrc/online-feature.cc
+//
+// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
+
+#include "kaldifeat/python/csrc/online-feature.h"
+
+#include <string>
+
+#include "kaldifeat/csrc/online-feature.h"
+namespace kaldifeat {
+
+template <typename C>
+void PybindOnlineFeatureTpl(py::module &m, const std::string &class_name,
+ const std::string &class_help_doc = "") {
+ using PyClass = OnlineGenericBaseFeature;
+ using Options = typename C::Options;
+  py::class_<PyClass>(m, class_name.c_str(), class_help_doc.c_str())
+      .def(py::init<const Options &>(), py::arg("opts"))
+ .def_property_readonly("dim", &PyClass::Dim)
+ .def_property_readonly("frame_shift_in_seconds",
+ &PyClass::FrameShiftInSeconds)
+ .def_property_readonly("num_frames_ready", &PyClass::NumFramesReady)
+ .def("is_last_frame", &PyClass::IsLastFrame, py::arg("frame"))
+ .def("get_frame", &PyClass::GetFrame, py::arg("frame"))
+ .def("get_frames", &PyClass::GetFrames, py::arg("frames"))
+ .def("accept_waveform", &PyClass::AcceptWaveform,
+ py::arg("sampling_rate"), py::arg("waveform"))
+ .def("input_finished", &PyClass::InputFinished);
+}
+
+void PybindOnlineFeature(py::module &m) {
+  PybindOnlineFeatureTpl<MfccComputer>(m, "OnlineMfcc");
+  PybindOnlineFeatureTpl<FbankComputer>(m, "OnlineFbank");
+  PybindOnlineFeatureTpl<PlpComputer>(m, "OnlinePlp");
+}
+
+} // namespace kaldifeat
diff --git a/kaldifeat/python/csrc/online-feature.h b/kaldifeat/python/csrc/online-feature.h
new file mode 100644
index 0000000..c363f42
--- /dev/null
+++ b/kaldifeat/python/csrc/online-feature.h
@@ -0,0 +1,16 @@
+// kaldifeat/python/csrc/online-feature.h
+//
+// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
+
+#ifndef KALDIFEAT_PYTHON_CSRC_ONLINE_FEATURE_H_
+#define KALDIFEAT_PYTHON_CSRC_ONLINE_FEATURE_H_
+
+#include "kaldifeat/python/csrc/kaldifeat.h"
+
+namespace kaldifeat {
+
+void PybindOnlineFeature(py::module &m);
+
+} // namespace kaldifeat
+
+#endif // KALDIFEAT_PYTHON_CSRC_ONLINE_FEATURE_H_
diff --git a/kaldifeat/python/csrc/utils.cc b/kaldifeat/python/csrc/utils.cc
index 76f47aa..0259afc 100644
--- a/kaldifeat/python/csrc/utils.cc
+++ b/kaldifeat/python/csrc/utils.cc
@@ -30,6 +30,7 @@ FrameExtractionOptions FrameExtractionOptionsFromDict(py::dict dict) {
FROM_DICT(bool_, round_to_power_of_two);
FROM_DICT(float_, blackman_coeff);
FROM_DICT(bool_, snip_edges);
+ FROM_DICT(int_, max_feature_vectors);
return opts;
}
@@ -47,6 +48,7 @@ py::dict AsDict(const FrameExtractionOptions &opts) {
AS_DICT(round_to_power_of_two);
AS_DICT(blackman_coeff);
AS_DICT(snip_edges);
+ AS_DICT(max_feature_vectors);
return dict;
}
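
These dict conversions are what the pickling path below relies on; max_feature_vectors now survives a round trip. A sketch, assuming the dict key matches the field name (as the FROM_DICT/AS_DICT macros suggest):

```python
import _kaldifeat
import kaldifeat

opts = kaldifeat.FbankOptions()
opts.frame_opts.max_feature_vectors = 100

d = opts.frame_opts.as_dict()
assert d["max_feature_vectors"] == 100  # key name assumed from the macro

opts2 = _kaldifeat.FrameExtractionOptions.from_dict(d)
assert opts2.max_feature_vectors == 100
```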
diff --git a/kaldifeat/python/kaldifeat/__init__.py b/kaldifeat/python/kaldifeat/__init__.py
index 6b2f088..60c6443 100644
--- a/kaldifeat/python/kaldifeat/__init__.py
+++ b/kaldifeat/python/kaldifeat/__init__.py
@@ -8,7 +8,7 @@ from _kaldifeat import (
SpectrogramOptions,
)
-from .fbank import Fbank
-from .mfcc import Mfcc
-from .plp import Plp
+from .fbank import Fbank, OnlineFbank
+from .mfcc import Mfcc, OnlineMfcc
+from .plp import OnlinePlp, Plp
from .spectrogram import Spectrogram
diff --git a/kaldifeat/python/kaldifeat/fbank.py b/kaldifeat/python/kaldifeat/fbank.py
index 8f73911..45bc3ef 100644
--- a/kaldifeat/python/kaldifeat/fbank.py
+++ b/kaldifeat/python/kaldifeat/fbank.py
@@ -4,9 +4,20 @@
import _kaldifeat
from .offline_feature import OfflineFeature
+from .online_feature import OnlineFeature
class Fbank(OfflineFeature):
def __init__(self, opts: _kaldifeat.FbankOptions):
super().__init__(opts)
self.computer = _kaldifeat.Fbank(opts)
+
+
+class OnlineFbank(OnlineFeature):
+ def __init__(self, opts: _kaldifeat.FbankOptions):
+ super().__init__(opts)
+ self.computer = _kaldifeat.OnlineFbank(opts)
+
+ def __setstate__(self, state):
+ self.opts = _kaldifeat.FbankOptions.from_dict(state)
+ self.computer = _kaldifeat.OnlineFbank(self.opts)
diff --git a/kaldifeat/python/kaldifeat/mfcc.py b/kaldifeat/python/kaldifeat/mfcc.py
index fa1e225..f76f2f4 100644
--- a/kaldifeat/python/kaldifeat/mfcc.py
+++ b/kaldifeat/python/kaldifeat/mfcc.py
@@ -4,9 +4,20 @@
import _kaldifeat
from .offline_feature import OfflineFeature
+from .online_feature import OnlineFeature
class Mfcc(OfflineFeature):
def __init__(self, opts: _kaldifeat.MfccOptions):
super().__init__(opts)
self.computer = _kaldifeat.Mfcc(opts)
+
+
+class OnlineMfcc(OnlineFeature):
+ def __init__(self, opts: _kaldifeat.MfccOptions):
+ super().__init__(opts)
+ self.computer = _kaldifeat.OnlineMfcc(opts)
+
+ def __setstate__(self, state):
+ self.opts = _kaldifeat.MfccOptions.from_dict(state)
+ self.computer = _kaldifeat.OnlineMfcc(self.opts)
diff --git a/kaldifeat/python/kaldifeat/online_feature.py b/kaldifeat/python/kaldifeat/online_feature.py
new file mode 100644
index 0000000..cf687cf
--- /dev/null
+++ b/kaldifeat/python/kaldifeat/online_feature.py
@@ -0,0 +1,95 @@
+# Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
+
+from typing import List
+
+import torch
+
+
+class OnlineFeature(object):
+    """Online feature is a base class of other online feature computers,
+    e.g., OnlineFbank, OnlineMfcc.
+
+ This class has two fields:
+
+ (1) opts. It contains the options for the feature computer.
+ (2) computer. The actual feature computer. It should be
+ instantiated by subclasses.
+
+ Caution:
+ It supports only CPU at present.
+ """
+
+ def __init__(self, opts):
+ assert opts.device.type == "cpu"
+
+ self.opts = opts
+
+ # self.computer is expected to be set by subclasses
+ self.computer = None
+
+ @property
+ def num_frames_ready(self) -> int:
+ """Return the number of ready frames.
+
+        It can be updated by :meth:`accept_waveform`.
+
+        Note:
+          If you set ``opts.frame_opts.max_feature_vectors``, then
+          the valid frame indexes are in the range
+          ``[num_frames_ready - max_feature_vectors, num_frames_ready)``.
+
+          If you leave ``opts.frame_opts.max_feature_vectors`` at its
+          default value, then the range is ``[0, num_frames_ready)``.
+ """
+ return self.computer.num_frames_ready
+
+ def is_last_frame(self, frame: int) -> bool:
+ """Return True if the given frame is the last frame."""
+ return self.computer.is_last_frame(frame)
+
+ def get_frame(self, frame: int) -> torch.Tensor:
+ """Get the frame by its index.
+ Args:
+ frame:
+ The frame index. If ``opts.frame_opts.max_feature_vectors`` is
+ -1, then its valid values are in the range
+ ``[0, num_frames_ready)``. Otherwise, the range is
+ ``[num_frames_ready - max_feature_vectors, num_frames_ready)``.
+ Returns:
+ Return a 2-D tensor with shape ``(1, feature_dim)``
+ """
+ return self.computer.get_frame(frame)
+
+ def get_frames(self, frames: List[int]) -> List[torch.Tensor]:
+ """Get frames at the given frame indexes.
+ Args:
+ frames:
+ Frames whose indexes are in this list are returned.
+ Returns:
+ Return a list of feature frames at the given indexes.
+ """
+ return self.computer.get_frames(frames)
+
+ def accept_waveform(
+ self, sampling_rate: float, waveform: torch.Tensor
+ ) -> None:
+ """Send audio samples to the extractor.
+ Args:
+ sampling_rate:
+ The sampling rate of the given audio samples. It has to be equal
+ to ``opts.frame_opts.samp_freq``.
+ waveform:
+            A 1-D tensor of shape (num_samples,). Its dtype must be
+            torch.float32 and it has to be on CPU.
+ """
+ self.computer.accept_waveform(sampling_rate, waveform)
+
+ def input_finished(self) -> None:
+ """Tell the extractor that no more audio samples will be available.
+ After calling this function, you cannot invoke ``accept_waveform``
+ again.
+ """
+ self.computer.input_finished()
+
+ def __getstate__(self):
+ return self.opts.as_dict()
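
Since ``__getstate__()`` serializes only the options (via ``as_dict()``) and each subclass's ``__setstate__()`` rebuilds the computer from them, pickling preserves the configuration but not any buffered audio or computed frames. A sketch:

```python
import pickle

import kaldifeat

opts = kaldifeat.FbankOptions()
opts.frame_opts.max_feature_vectors = 100

fbank = kaldifeat.OnlineFbank(opts)
fbank2 = pickle.loads(pickle.dumps(fbank))

# The options survive the round trip ...
assert str(fbank.opts) == str(fbank2.opts)

# ... but the computer is rebuilt from scratch: waveform fed to `fbank`
# before pickling is not carried over to `fbank2`.
assert fbank2.num_frames_ready == 0
```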
diff --git a/kaldifeat/python/kaldifeat/plp.py b/kaldifeat/python/kaldifeat/plp.py
index 219e2d4..d99dbc2 100644
--- a/kaldifeat/python/kaldifeat/plp.py
+++ b/kaldifeat/python/kaldifeat/plp.py
@@ -4,9 +4,20 @@
import _kaldifeat
from .offline_feature import OfflineFeature
+from .online_feature import OnlineFeature
class Plp(OfflineFeature):
def __init__(self, opts: _kaldifeat.PlpOptions):
super().__init__(opts)
self.computer = _kaldifeat.Plp(opts)
+
+
+class OnlinePlp(OnlineFeature):
+ def __init__(self, opts: _kaldifeat.PlpOptions):
+ super().__init__(opts)
+ self.computer = _kaldifeat.OnlinePlp(opts)
+
+ def __setstate__(self, state):
+ self.opts = _kaldifeat.PlpOptions.from_dict(state)
+ self.computer = _kaldifeat.OnlinePlp(self.opts)
diff --git a/kaldifeat/python/tests/test_fbank.py b/kaldifeat/python/tests/test_fbank.py
index 57092b6..1c06438 100755
--- a/kaldifeat/python/tests/test_fbank.py
+++ b/kaldifeat/python/tests/test_fbank.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python3
-# Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang)
+# Copyright 2021-2022 Xiaomi Corporation (authors: Fangjun Kuang)
import pickle
from pathlib import Path
@@ -13,30 +13,88 @@ import kaldifeat
cur_dir = Path(__file__).resolve().parent
+def test_online_fbank(
+ opts: kaldifeat.FbankOptions,
+ wave: torch.Tensor,
+ cpu_features: torch.Tensor,
+):
+ """
+ Args:
+ opts:
+ The options to create the online fbank extractor.
+ wave:
+ The input 1-D waveform.
+ cpu_features:
+      The ground truth features that are computed offline.
+ """
+ online_fbank = kaldifeat.OnlineFbank(opts)
+
+ num_processed_frames = 0
+ i = 0 # current sample index to feed
+ while not online_fbank.is_last_frame(num_processed_frames - 1):
+ while num_processed_frames < online_fbank.num_frames_ready:
+ # There are new frames to be processed
+ frame = online_fbank.get_frame(num_processed_frames)
+ assert torch.allclose(
+ frame.squeeze(0), cpu_features[num_processed_frames]
+ )
+ num_processed_frames += 1
+
+        # Simulate streaming: send a random number of audio samples
+        # to the extractor.
+ num_samples = torch.randint(300, 1000, (1,)).item()
+
+ samples = wave[i : (i + num_samples)] # noqa
+ i += num_samples
+ if len(samples) == 0:
+ online_fbank.input_finished()
+ continue
+
+ online_fbank.accept_waveform(16000, samples)
+
+ assert num_processed_frames == online_fbank.num_frames_ready
+ assert num_processed_frames == cpu_features.size(0)
+
+
def test_fbank_default():
print("=====test_fbank_default=====")
+ filename = cur_dir / "test_data/test.wav"
+ wave = read_wave(filename)
+ gt = read_ark_txt(cur_dir / "test_data/test.txt")
+
+ cpu_features = None
for device in get_devices():
print("device", device)
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
fbank = kaldifeat.Fbank(opts)
- filename = cur_dir / "test_data/test.wav"
- wave = read_wave(filename)
features = fbank(wave)
assert features.device.type == "cpu"
- gt = read_ark_txt(cur_dir / "test_data/test.txt")
assert torch.allclose(features, gt, rtol=1e-1)
+ if cpu_features is None:
+ cpu_features = features
- wave = wave.to(device)
- features = fbank(wave)
+ features = fbank(wave.to(device))
assert features.device == device
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
+ # Now for online fbank
+ opts = kaldifeat.FbankOptions()
+ opts.frame_opts.dither = 0
+ opts.frame_opts.max_feature_vectors = 100
+
+ test_online_fbank(opts, wave, cpu_features)
+
def test_fbank_htk():
print("=====test_fbank_htk=====")
+ filename = cur_dir / "test_data/test.wav"
+ wave = read_wave(filename)
+ gt = read_ark_txt(cur_dir / "test_data/test-htk.txt")
+
+ cpu_features = None
for device in get_devices():
print("device", device)
opts = kaldifeat.FbankOptions()
@@ -46,22 +104,32 @@ def test_fbank_htk():
opts.htk_compat = True
fbank = kaldifeat.Fbank(opts)
- filename = cur_dir / "test_data/test.wav"
- wave = read_wave(filename)
features = fbank(wave)
assert features.device.type == "cpu"
- gt = read_ark_txt(cur_dir / "test_data/test-htk.txt")
assert torch.allclose(features, gt, rtol=1e-1)
+ if cpu_features is None:
+ cpu_features = features
- wave = wave.to(device)
- features = fbank(wave)
+ features = fbank(wave.to(device))
assert features.device == device
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
+ opts = kaldifeat.FbankOptions()
+ opts.frame_opts.dither = 0
+ opts.use_energy = True
+ opts.htk_compat = True
+
+ test_online_fbank(opts, wave, cpu_features)
+
def test_fbank_with_energy():
print("=====test_fbank_with_energy=====")
+ filename = cur_dir / "test_data/test.wav"
+ wave = read_wave(filename)
+ gt = read_ark_txt(cur_dir / "test_data/test-with-energy.txt")
+
+ cpu_features = None
for device in get_devices():
print("device", device)
opts = kaldifeat.FbankOptions()
@@ -70,22 +138,31 @@ def test_fbank_with_energy():
opts.use_energy = True
fbank = kaldifeat.Fbank(opts)
- filename = cur_dir / "test_data/test.wav"
- wave = read_wave(filename)
features = fbank(wave)
- gt = read_ark_txt(cur_dir / "test_data/test-with-energy.txt")
assert torch.allclose(features, gt, rtol=1e-1)
assert features.device.type == "cpu"
+ if cpu_features is None:
+ cpu_features = features
- wave = wave.to(device)
- features = fbank(wave)
+ features = fbank(wave.to(device))
assert features.device == device
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
+ opts = kaldifeat.FbankOptions()
+ opts.frame_opts.dither = 0
+ opts.use_energy = True
+
+ test_online_fbank(opts, wave, cpu_features)
+
def test_fbank_40_bins():
print("=====test_fbank_40_bins=====")
+ filename = cur_dir / "test_data/test.wav"
+ wave = read_wave(filename)
+ gt = read_ark_txt(cur_dir / "test_data/test-40.txt")
+
+ cpu_features = None
for device in get_devices():
print("device", device)
opts = kaldifeat.FbankOptions()
@@ -94,22 +171,31 @@ def test_fbank_40_bins():
opts.mel_opts.num_bins = 40
fbank = kaldifeat.Fbank(opts)
- filename = cur_dir / "test_data/test.wav"
- wave = read_wave(filename)
features = fbank(wave)
assert features.device.type == "cpu"
- gt = read_ark_txt(cur_dir / "test_data/test-40.txt")
assert torch.allclose(features, gt, rtol=1e-1)
+ if cpu_features is None:
+ cpu_features = features
- wave = wave.to(device)
- features = fbank(wave)
+ features = fbank(wave.to(device))
assert features.device == device
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
+ opts = kaldifeat.FbankOptions()
+ opts.frame_opts.dither = 0
+ opts.mel_opts.num_bins = 40
+
+ test_online_fbank(opts, wave, cpu_features)
+
def test_fbank_40_bins_no_snip_edges():
print("=====test_fbank_40_bins_no_snip_edges=====")
+ filename = cur_dir / "test_data/test.wav"
+ wave = read_wave(filename)
+ gt = read_ark_txt(cur_dir / "test_data/test-40-no-snip-edges.txt")
+
+ cpu_features = None
for device in get_devices():
print("device", device)
opts = kaldifeat.FbankOptions()
@@ -119,19 +205,24 @@ def test_fbank_40_bins_no_snip_edges():
opts.frame_opts.snip_edges = False
fbank = kaldifeat.Fbank(opts)
- filename = cur_dir / "test_data/test.wav"
- wave = read_wave(filename)
features = fbank(wave)
assert features.device.type == "cpu"
- gt = read_ark_txt(cur_dir / "test_data/test-40-no-snip-edges.txt")
assert torch.allclose(features, gt, rtol=1e-1)
+ if cpu_features is None:
+ cpu_features = features
- wave = wave.to(device)
- features = fbank(wave)
+ features = fbank(wave.to(device))
assert features.device == device
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
+ opts = kaldifeat.FbankOptions()
+ opts.frame_opts.dither = 0
+ opts.mel_opts.num_bins = 40
+ opts.frame_opts.snip_edges = False
+
+ test_online_fbank(opts, wave, cpu_features)
+
def test_fbank_chunk():
print("=====test_fbank_chunk=====")
@@ -223,6 +314,16 @@ def test_pickle():
assert str(fbank.opts) == str(fbank2.opts)
+ opts = kaldifeat.FbankOptions()
+ opts.use_energy = True
+ opts.use_power = False
+
+ fbank = kaldifeat.OnlineFbank(opts)
+ data = pickle.dumps(fbank)
+ fbank2 = pickle.loads(data)
+
+ assert str(fbank.opts) == str(fbank2.opts)
+
if __name__ == "__main__":
test_fbank_default()
diff --git a/kaldifeat/python/tests/test_mfcc.py b/kaldifeat/python/tests/test_mfcc.py
index 33407b5..5665da4 100755
--- a/kaldifeat/python/tests/test_mfcc.py
+++ b/kaldifeat/python/tests/test_mfcc.py
@@ -13,24 +13,82 @@ import kaldifeat
cur_dir = Path(__file__).resolve().parent
+def test_online_mfcc(
+ opts: kaldifeat.MfccOptions,
+ wave: torch.Tensor,
+ cpu_features: torch.Tensor,
+):
+ """
+ Args:
+ opts:
+ The options to create the online mfcc extractor.
+ wave:
+ The input 1-D waveform.
+ cpu_features:
+      The ground truth features that are computed offline.
+ """
+ online_mfcc = kaldifeat.OnlineMfcc(opts)
+
+ num_processed_frames = 0
+ i = 0 # current sample index to feed
+ while not online_mfcc.is_last_frame(num_processed_frames - 1):
+ while num_processed_frames < online_mfcc.num_frames_ready:
+ # There are new frames to be processed
+ frame = online_mfcc.get_frame(num_processed_frames)
+ assert torch.allclose(
+ frame.squeeze(0), cpu_features[num_processed_frames], atol=1e-3
+ )
+ num_processed_frames += 1
+
+        # Simulate streaming: send a random number of audio samples
+        # to the extractor.
+ num_samples = torch.randint(300, 1000, (1,)).item()
+
+ samples = wave[i : (i + num_samples)] # noqa
+ i += num_samples
+ if len(samples) == 0:
+ online_mfcc.input_finished()
+ continue
+
+ online_mfcc.accept_waveform(16000, samples)
+
+ assert num_processed_frames == online_mfcc.num_frames_ready
+ assert num_processed_frames == cpu_features.size(0)
+
+
def test_mfcc_default():
print("=====test_mfcc_default=====")
+ filename = cur_dir / "test_data/test.wav"
+ wave = read_wave(filename)
+ gt = read_ark_txt(cur_dir / "test_data/test-mfcc.txt")
+
+ cpu_features = None
for device in get_devices():
print("device", device)
opts = kaldifeat.MfccOptions()
opts.device = device
opts.frame_opts.dither = 0
mfcc = kaldifeat.Mfcc(opts)
- filename = cur_dir / "test_data/test.wav"
- wave = read_wave(filename).to(device)
- features = mfcc(wave)
- gt = read_ark_txt(cur_dir / "test_data/test-mfcc.txt")
+ features = mfcc(wave.to(device))
+ if device.type == "cpu":
+ cpu_features = features
+
assert torch.allclose(features.cpu(), gt, atol=1e-1)
+ opts = kaldifeat.MfccOptions()
+ opts.frame_opts.dither = 0
+
+ test_online_mfcc(opts, wave, cpu_features)
+
def test_mfcc_no_snip_edges():
print("=====test_mfcc_no_snip_edges=====")
+ filename = cur_dir / "test_data/test.wav"
+ wave = read_wave(filename)
+ gt = read_ark_txt(cur_dir / "test_data/test-mfcc-no-snip-edges.txt")
+
+ cpu_features = None
for device in get_devices():
print("device", device)
opts = kaldifeat.MfccOptions()
@@ -39,13 +97,19 @@ def test_mfcc_no_snip_edges():
opts.frame_opts.snip_edges = False
mfcc = kaldifeat.Mfcc(opts)
- filename = cur_dir / "test_data/test.wav"
- wave = read_wave(filename).to(device)
- features = mfcc(wave)
- gt = read_ark_txt(cur_dir / "test_data/test-mfcc-no-snip-edges.txt")
+ features = mfcc(wave.to(device))
+ if device.type == "cpu":
+ cpu_features = features
+
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
+ opts = kaldifeat.MfccOptions()
+ opts.frame_opts.dither = 0
+ opts.frame_opts.snip_edges = False
+
+ test_online_mfcc(opts, wave, cpu_features)
+
def test_pickle():
for device in get_devices():
@@ -60,6 +124,16 @@ def test_pickle():
assert str(mfcc.opts) == str(mfcc2.opts)
+ opts = kaldifeat.MfccOptions()
+ opts.frame_opts.dither = 0
+ opts.frame_opts.snip_edges = False
+
+ mfcc = kaldifeat.OnlineMfcc(opts)
+ data = pickle.dumps(mfcc)
+ mfcc2 = pickle.loads(data)
+
+ assert str(mfcc.opts) == str(mfcc2.opts)
+
if __name__ == "__main__":
test_mfcc_default()
diff --git a/kaldifeat/python/tests/test_plp.py b/kaldifeat/python/tests/test_plp.py
index 4f20452..cf56d41 100755
--- a/kaldifeat/python/tests/test_plp.py
+++ b/kaldifeat/python/tests/test_plp.py
@@ -13,24 +13,82 @@ import kaldifeat
cur_dir = Path(__file__).resolve().parent
+def test_online_plp(
+ opts: kaldifeat.PlpOptions,
+ wave: torch.Tensor,
+ cpu_features: torch.Tensor,
+):
+ """
+ Args:
+ opts:
+ The options to create the online plp extractor.
+ wave:
+ The input 1-D waveform.
+ cpu_features:
+      The ground truth features that are computed offline.
+ """
+ online_plp = kaldifeat.OnlinePlp(opts)
+
+ num_processed_frames = 0
+ i = 0 # current sample index to feed
+ while not online_plp.is_last_frame(num_processed_frames - 1):
+ while num_processed_frames < online_plp.num_frames_ready:
+ # There are new frames to be processed
+ frame = online_plp.get_frame(num_processed_frames)
+ assert torch.allclose(
+ frame.squeeze(0), cpu_features[num_processed_frames], atol=1e-3
+ )
+ num_processed_frames += 1
+
+        # Simulate streaming: send a random number of audio samples
+        # to the extractor.
+ num_samples = torch.randint(300, 1000, (1,)).item()
+
+ samples = wave[i : (i + num_samples)] # noqa
+ i += num_samples
+ if len(samples) == 0:
+ online_plp.input_finished()
+ continue
+
+ online_plp.accept_waveform(16000, samples)
+
+ assert num_processed_frames == online_plp.num_frames_ready
+ assert num_processed_frames == cpu_features.size(0)
+
+
def test_plp_default():
print("=====test_plp_default=====")
+ filename = cur_dir / "test_data/test.wav"
+ wave = read_wave(filename)
+ gt = read_ark_txt(cur_dir / "test_data/test-plp.txt")
+
+ cpu_features = None
for device in get_devices():
print("device", device)
opts = kaldifeat.PlpOptions()
opts.frame_opts.dither = 0
opts.device = device
plp = kaldifeat.Plp(opts)
- filename = cur_dir / "test_data/test.wav"
- wave = read_wave(filename).to(device)
- features = plp(wave)
- gt = read_ark_txt(cur_dir / "test_data/test-plp.txt")
+ features = plp(wave.to(device))
+ if device.type == "cpu":
+ cpu_features = features
+
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
+ opts = kaldifeat.PlpOptions()
+ opts.frame_opts.dither = 0
+
+ test_online_plp(opts, wave, cpu_features)
+
def test_plp_no_snip_edges():
print("=====test_plp_no_snip_edges=====")
+ filename = cur_dir / "test_data/test.wav"
+ wave = read_wave(filename)
+ gt = read_ark_txt(cur_dir / "test_data/test-plp-no-snip-edges.txt")
+
+ cpu_features = None
for device in get_devices():
print("device", device)
opts = kaldifeat.PlpOptions()
@@ -39,16 +97,26 @@ def test_plp_no_snip_edges():
opts.frame_opts.snip_edges = False
plp = kaldifeat.Plp(opts)
- filename = cur_dir / "test_data/test.wav"
- wave = read_wave(filename).to(device)
- features = plp(wave)
- gt = read_ark_txt(cur_dir / "test_data/test-plp-no-snip-edges.txt")
+ features = plp(wave.to(device))
+ if device.type == "cpu":
+ cpu_features = features
assert torch.allclose(features.cpu(), gt, atol=1e-1)
+ opts = kaldifeat.PlpOptions()
+ opts.frame_opts.dither = 0
+ opts.frame_opts.snip_edges = False
+
+ test_online_plp(opts, wave, cpu_features)
+
def test_plp_htk_10_ceps():
print("=====test_plp_htk_10_ceps=====")
+ filename = cur_dir / "test_data/test.wav"
+ wave = read_wave(filename)
+ gt = read_ark_txt(cur_dir / "test_data/test-plp-htk-10-ceps.txt")
+
+ cpu_features = None
for device in get_devices():
print("device", device)
opts = kaldifeat.PlpOptions()
@@ -58,13 +126,19 @@ def test_plp_htk_10_ceps():
opts.frame_opts.dither = 0
plp = kaldifeat.Plp(opts)
- filename = cur_dir / "test_data/test.wav"
- wave = read_wave(filename).to(device)
- features = plp(wave)
- gt = read_ark_txt(cur_dir / "test_data/test-plp-htk-10-ceps.txt")
+ features = plp(wave.to(device))
+ if device.type == "cpu":
+ cpu_features = features
assert torch.allclose(features.cpu(), gt, atol=1e-1)
+ opts = kaldifeat.PlpOptions()
+ opts.htk_compat = True
+ opts.num_ceps = 10
+ opts.frame_opts.dither = 0
+
+ test_online_plp(opts, wave, cpu_features)
+
def test_pickle():
for device in get_devices():
@@ -79,6 +153,16 @@ def test_pickle():
assert str(plp.opts) == str(plp2.opts)
+ opts = kaldifeat.PlpOptions()
+ opts.frame_opts.dither = 0
+ opts.frame_opts.snip_edges = False
+
+ plp = kaldifeat.OnlinePlp(opts)
+ data = pickle.dumps(plp)
+ plp2 = pickle.loads(data)
+
+ assert str(plp.opts) == str(plp2.opts)
+
if __name__ == "__main__":
test_plp_default()