diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml index e38f9d0..462a2db 100644 --- a/.github/workflows/style_check.yml +++ b/.github/workflows/style_check.yml @@ -45,7 +45,9 @@ jobs: - name: Install Python dependencies run: | - python3 -m pip install --upgrade pip black flake8 + python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2 click==8.0.4 + # See https://github.com/psf/black/issues/2964 + # The version of click should be selected from 8.0.0, 8.0.1, 8.0.2, 8.0.3, and 8.0.4 - name: Run flake8 shell: bash diff --git a/README.md b/README.md index 841af4d..e4effa5 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,17 @@ features = fbank(wave) + +Streaming FBANK +kaldifeat.FbankOptions +kaldifeat.OnlineFbank + +See +./kaldifeat/python/tests/test_fbank.py + + + + MFCC kaldifeat.MfccOptions @@ -45,6 +56,17 @@ features = mfcc(wave) + +Streaming MFCC +kaldifeat.MfccOptions +kaldifeat.OnlineMfcc + +See +./kaldifeat/python/tests/test_mfcc.py + + + + PLP kaldifeat.PlpOptions @@ -59,6 +81,17 @@ features = plp(wave) + +Streaming PLP +kaldifeat.PlpOptions +kaldifeat.OnlinePlp + +See +./kaldifeat/python/tests/test_plp.py + + + + Spectorgram kaldifeat.SpectrogramOptions @@ -88,6 +121,8 @@ The following kaldi-compatible commandline tools are implemented: (**NOTE**: We will implement other types of features, e.g., Pitch, ivector, etc, soon.) +**HINT**: It supports also streaming feature extractors for Fbank, MFCC, and Plp. + # Usage Let us first generate a test wave using sox: diff --git a/kaldifeat/csrc/CMakeLists.txt b/kaldifeat/csrc/CMakeLists.txt index 8dce57c..7e6f943 100644 --- a/kaldifeat/csrc/CMakeLists.txt +++ b/kaldifeat/csrc/CMakeLists.txt @@ -9,6 +9,7 @@ set(kaldifeat_srcs feature-window.cc matrix-functions.cc mel-computations.cc + online-feature.cc ) add_library(kaldifeat_core SHARED ${kaldifeat_srcs}) @@ -40,6 +41,7 @@ if(kaldifeat_BUILD_TESTS) # please sort the source files alphabetically set(test_srcs feature-window-test.cc + online-feature-test.cc ) foreach(source IN LISTS test_srcs) diff --git a/kaldifeat/csrc/feature-common.h b/kaldifeat/csrc/feature-common.h index 5710c22..24e7ec1 100644 --- a/kaldifeat/csrc/feature-common.h +++ b/kaldifeat/csrc/feature-common.h @@ -62,6 +62,10 @@ class OfflineFeatureTpl { int32_t Dim() const { return computer_.Dim(); } const Options &GetOptions() const { return computer_.GetOptions(); } + const FrameExtractionOptions &GetFrameOptions() const { + return GetOptions().frame_opts; + } + // Copy constructor. OfflineFeatureTpl(const OfflineFeatureTpl &) = delete; OfflineFeatureTpl &operator=(const OfflineFeatureTpl &) = delete; diff --git a/kaldifeat/csrc/feature-window.cc b/kaldifeat/csrc/feature-window.cc index 4ec5ac2..6880f7e 100644 --- a/kaldifeat/csrc/feature-window.cc +++ b/kaldifeat/csrc/feature-window.cc @@ -161,19 +161,20 @@ torch::Tensor Dither(const torch::Tensor &wave, float dither_value) { #if 1 return wave + rand_gauss * dither_value; #else - // use in-place version of wave and change its to pointer type + // use in-place version of wave and change it to pointer type wave_->add_(rand_gauss, dither_value); #endif } torch::Tensor Preemphasize(float preemph_coeff, const torch::Tensor &wave) { - using namespace torch::indexing; // It imports: Slice, None // NOLINT if (preemph_coeff == 0.0f) return wave; KALDIFEAT_ASSERT(preemph_coeff >= 0.0f && preemph_coeff <= 1.0f); torch::Tensor ans = torch::empty_like(wave); + using torch::indexing::None; + using torch::indexing::Slice; // right = wave[:, 1:] torch::Tensor right = wave.index({"...", Slice(1, None, None)}); @@ -188,4 +189,59 @@ torch::Tensor Preemphasize(float preemph_coeff, const torch::Tensor &wave) { return ans; } +torch::Tensor ExtractWindow(int64_t sample_offset, const torch::Tensor &wave, + int32_t f, const FrameExtractionOptions &opts) { + KALDIFEAT_ASSERT(sample_offset >= 0 && wave.numel() != 0); + + int32_t frame_length = opts.WindowSize(); + int64_t num_samples = sample_offset + wave.numel(); + int64_t start_sample = FirstSampleOfFrame(f, opts); + int64_t end_sample = start_sample + frame_length; + + if (opts.snip_edges) { + KALDIFEAT_ASSERT(start_sample >= sample_offset && + end_sample <= num_samples); + } else { + KALDIFEAT_ASSERT(sample_offset == 0 || start_sample >= sample_offset); + } + + // wave_start and wave_end are start and end indexes into 'wave', for the + // piece of wave that we're trying to extract. + int32_t wave_start = static_cast(start_sample - sample_offset); + int32_t wave_end = wave_start + frame_length; + + if (wave_start >= 0 && wave_end <= wave.numel()) { + // the normal case -- no edge effects to consider. + // return wave[wave_start:wave_end] + return wave.index({torch::indexing::Slice(wave_start, wave_end)}); + } else { + torch::Tensor window = torch::empty({frame_length}, torch::kFloat); + auto p_window = window.accessor(); + auto p_wave = wave.accessor(); + + // Deal with any end effects by reflection, if needed. This code will only + // be reached for about two frames per utterance, so we don't concern + // ourselves excessively with efficiency. + int32_t wave_dim = wave.numel(); + for (int32_t s = 0; s != frame_length; ++s) { + int32_t s_in_wave = s + wave_start; + while (s_in_wave < 0 || s_in_wave >= wave_dim) { + // reflect around the beginning or end of the wave. + // e.g. -1 -> 0, -2 -> 1. + // dim -> dim - 1, dim + 1 -> dim - 2. + // the code supports repeated reflections, although this + // would only be needed in pathological cases. + if (s_in_wave < 0) { + s_in_wave = -s_in_wave - 1; + } else { + s_in_wave = 2 * wave_dim - 1 - s_in_wave; + } + } + + p_window[s] = p_wave[s_in_wave]; + } + return window; + } +} + } // namespace kaldifeat diff --git a/kaldifeat/csrc/feature-window.h b/kaldifeat/csrc/feature-window.h index 255c274..26d743a 100644 --- a/kaldifeat/csrc/feature-window.h +++ b/kaldifeat/csrc/feature-window.h @@ -44,7 +44,11 @@ struct FrameExtractionOptions { bool snip_edges = true; // bool allow_downsample = false; // bool allow_upsample = false; - // int32_t max_feature_vectors = -1; + + // Used for streaming feature extraction. It indicates the number + // of feature frames to keep in the recycling vector. -1 means to + // keep all feature frames. + int32_t max_feature_vectors = -1; int32_t WindowShift() const { return static_cast(samp_freq * 0.001f * frame_shift_ms); @@ -71,7 +75,7 @@ struct FrameExtractionOptions { KALDIFEAT_PRINT(snip_edges); // KALDIFEAT_PRINT(allow_downsample); // KALDIFEAT_PRINT(allow_upsample); - // KALDIFEAT_PRINT(max_feature_vectors); + KALDIFEAT_PRINT(max_feature_vectors); #undef KALDIFEAT_PRINT return os.str(); } @@ -100,11 +104,11 @@ class FeatureWindowFunction { @param [in] flush True if we are asserting that this number of samples is 'all there is', false if we expecting more data to possibly come in. This - only makes a difference to the answer if opts.snips_edges - == false. For offline feature extraction you always want flush == - true. In an online-decoding context, once you know (or decide) - that no more data is coming in, you'd call it with flush == true at the end - to flush out any remaining data. + only makes a difference to the answer + if opts.snips_edges== false. For offline feature extraction you always want + flush == true. In an online-decoding context, once you know (or decide) that + no more data is coming in, you'd call it with flush == true at the end to + flush out any remaining data. */ int32_t NumFrames(int64_t num_samples, const FrameExtractionOptions &opts, bool flush = true); @@ -133,6 +137,29 @@ torch::Tensor Dither(const torch::Tensor &wave, float dither_value); torch::Tensor Preemphasize(float preemph_coeff, const torch::Tensor &wave); +/* + ExtractWindow() extracts "frame_length" samples from the given waveform. + Note: This function only extracts "frame_length" samples + from the input waveform, without any further processing. + + @param [in] sample_offset If 'wave' is not the entire waveform, but + part of it to the left has been discarded, then the + number of samples prior to 'wave' that we have + already discarded. Set this to zero if you are + processing the entire waveform in one piece, or + if you get 'no matching function' compilation + errors when updating the code. + @param [in] wave The waveform + @param [in] f The frame index to be extracted, with + 0 <= f < NumFrames(sample_offset + wave.numel(), opts, true) + @param [in] opts The options class to be used + @return Return a tensor containing "frame_length" samples extracted from + `wave`, without any further processing. Its shape is + (1, frame_length). +*/ +torch::Tensor ExtractWindow(int64_t sample_offset, const torch::Tensor &wave, + int32_t f, const FrameExtractionOptions &opts); + } // namespace kaldifeat #endif // KALDIFEAT_CSRC_FEATURE_WINDOW_H_ diff --git a/kaldifeat/csrc/online-feature-itf.h b/kaldifeat/csrc/online-feature-itf.h new file mode 100644 index 0000000..835e182 --- /dev/null +++ b/kaldifeat/csrc/online-feature-itf.h @@ -0,0 +1,89 @@ +// kaldifeat/csrc/online-feature-itf.h +// +// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + +// This file is copied/modified from kaldi/src/itf/online-feature-itf.h + +#ifndef KALDIFEAT_CSRC_ONLINE_FEATURE_ITF_H_ +#define KALDIFEAT_CSRC_ONLINE_FEATURE_ITF_H_ + +#include +#include + +#include "torch/script.h" + +namespace kaldifeat { + +class OnlineFeatureInterface { + public: + virtual ~OnlineFeatureInterface() = default; + + virtual int32_t Dim() const = 0; /// returns the feature dimension. + // + // Returns frame shift in seconds. Helps to estimate duration from frame + // counts. + virtual float FrameShiftInSeconds() const = 0; + + /// Returns the total number of frames, since the start of the utterance, that + /// are now available. In an online-decoding context, this will likely + /// increase with time as more data becomes available. + virtual int32_t NumFramesReady() const = 0; + + /// Returns true if this is the last frame. Frame indices are zero-based, so + /// the first frame is zero. IsLastFrame(-1) will return false, unless the + /// file is empty (which is a case that I'm not sure all the code will handle, + /// so be careful). This function may return false for some frame if we + /// haven't yet decided to terminate decoding, but later true if we decide to + /// terminate decoding. This function exists mainly to correctly handle end + /// effects in feature extraction, and is not a mechanism to determine how + /// many frames are in the decodable object (as it used to be, and for + /// backward compatibility, still is, in the Decodable interface). + virtual bool IsLastFrame(int32_t frame) const = 0; + + /// Gets the feature vector for this frame. Before calling this for a given + /// frame, it is assumed that you called NumFramesReady() and it returned a + /// number greater than "frame". Otherwise this call will likely crash with + /// an assert failure. This function is not declared const, in case there is + /// some kind of caching going on, but most of the time it shouldn't modify + /// the class. + /// + /// The returned tensor has shape (1, Dim()). + virtual torch::Tensor GetFrame(int32_t frame) = 0; + + /// This is like GetFrame() but for a collection of frames. There is a + /// default implementation that just gets the frames one by one, but it + /// may be overridden for efficiency by child classes (since sometimes + /// it's more efficient to do things in a batch). + /// + /// The returned tensor has shape (frames.size(), Dim()). + virtual std::vector GetFrames( + const std::vector &frames) { + std::vector features; + features.reserve(frames.size()); + + for (auto i : frames) { + torch::Tensor f = GetFrame(i); + features.push_back(std::move(f)); + } + return features; +#if 0 + return torch::cat(features, /*dim*/ 0); +#endif + } + + /// This would be called from the application, when you get more wave data. + /// Note: the sampling_rate is typically only provided so the code can assert + /// that it matches the sampling rate expected in the options. + virtual void AcceptWaveform(float sampling_rate, + const torch::Tensor &waveform) = 0; + + /// InputFinished() tells the class you won't be providing any + /// more waveform. This will help flush out the last few frames + /// of delta or LDA features (it will typically affect the return value + /// of IsLastFrame. + virtual void InputFinished() = 0; +}; + +} // namespace kaldifeat + +#endif // KALDIFEAT_CSRC_ONLINE_FEATURE_ITF_H_ diff --git a/kaldifeat/csrc/online-feature-test.cc b/kaldifeat/csrc/online-feature-test.cc new file mode 100644 index 0000000..786c1c1 --- /dev/null +++ b/kaldifeat/csrc/online-feature-test.cc @@ -0,0 +1,49 @@ +// kaldifeat/csrc/online-feature-test.h +// +// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + +#include "kaldifeat/csrc/online-feature.h" + +#include "gtest/gtest.h" + +namespace kaldifeat { + +TEST(RecyclingVector, TestUnlimited) { + RecyclingVector v(-1); + constexpr int32_t N = 100; + for (int32_t i = 0; i != N; ++i) { + torch::Tensor t = torch::tensor({i, i + 1, i + 2}); + v.PushBack(t); + } + ASSERT_EQ(v.Size(), N); + + for (int32_t i = 0; i != N; ++i) { + torch::Tensor t = v.At(i); + torch::Tensor expected = torch::tensor({i, i + 1, i + 2}); + EXPECT_TRUE(t.equal(expected)); + } +} + +TEST(RecyclingVector, Testlimited) { + constexpr int32_t K = 3; + constexpr int32_t N = 10; + RecyclingVector v(K); + for (int32_t i = 0; i != N; ++i) { + torch::Tensor t = torch::tensor({i, i + 1, i + 2}); + v.PushBack(t); + } + + ASSERT_EQ(v.Size(), N); + + for (int32_t i = 0; i < N - K; ++i) { + ASSERT_DEATH(v.At(i), ""); + } + + for (int32_t i = N - K; i != N; ++i) { + torch::Tensor t = v.At(i); + torch::Tensor expected = torch::tensor({i, i + 1, i + 2}); + EXPECT_TRUE(t.equal(expected)); + } +} + +} // namespace kaldifeat diff --git a/kaldifeat/csrc/online-feature.cc b/kaldifeat/csrc/online-feature.cc new file mode 100644 index 0000000..43fc1b1 --- /dev/null +++ b/kaldifeat/csrc/online-feature.cc @@ -0,0 +1,133 @@ +// kaldifeat/csrc/online-feature.cc +// +// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + +// This file is copied/modified from kaldi/src/feat/online-feature.cc + +#include "kaldifeat/csrc/online-feature.h" + +#include "kaldifeat/csrc/feature-window.h" +#include "kaldifeat/csrc/log.h" + +namespace kaldifeat { + +RecyclingVector::RecyclingVector(int32_t items_to_hold) + : items_to_hold_(items_to_hold == 0 ? -1 : items_to_hold), + first_available_index_(0) {} + +torch::Tensor RecyclingVector::At(int32_t index) const { + if (index < first_available_index_) { + KALDIFEAT_ERR << "Attempted to retrieve feature vector that was " + "already removed by the RecyclingVector (index = " + << index << "; " + << "first_available_index = " << first_available_index_ + << "; " + << "size = " << Size() << ")"; + } + // 'at' does size checking. + return items_.at(index - first_available_index_); +} + +void RecyclingVector::PushBack(torch::Tensor item) { + // Note: -1 is a larger number when treated as unsigned + if (items_.size() == static_cast(items_to_hold_)) { + items_.pop_front(); + ++first_available_index_; + } + items_.push_back(item); +} + +int32_t RecyclingVector::Size() const { + return first_available_index_ + static_cast(items_.size()); +} + +template +OnlineGenericBaseFeature::OnlineGenericBaseFeature( + const typename C::Options &opts) + : computer_(opts), + window_function_(opts.frame_opts, opts.device), + features_(opts.frame_opts.max_feature_vectors), + input_finished_(false), + waveform_offset_(0) {} + +template +void OnlineGenericBaseFeature::AcceptWaveform( + float sampling_rate, const torch::Tensor &original_waveform) { + if (original_waveform.numel() == 0) return; // Nothing to do. + + KALDIFEAT_ASSERT(original_waveform.dim() == 1); + KALDIFEAT_ASSERT(sampling_rate == computer_.GetFrameOptions().samp_freq); + + if (input_finished_) + KALDIFEAT_ERR << "AcceptWaveform called after InputFinished() was called."; + + if (waveform_remainder_.numel() == 0) { + waveform_remainder_ = original_waveform; + } else { + waveform_remainder_ = + torch::cat({waveform_remainder_, original_waveform}, /*dim*/ 0); + } + + ComputeFeatures(); +} + +template +void OnlineGenericBaseFeature::InputFinished() { + input_finished_ = true; + ComputeFeatures(); +} + +template +void OnlineGenericBaseFeature::ComputeFeatures() { + const FrameExtractionOptions &frame_opts = computer_.GetFrameOptions(); + + int64_t num_samples_total = waveform_offset_ + waveform_remainder_.numel(); + int32_t num_frames_old = features_.Size(); + int32_t num_frames_new = + NumFrames(num_samples_total, frame_opts, input_finished_); + + KALDIFEAT_ASSERT(num_frames_new >= num_frames_old); + + // note: this online feature-extraction code does not support VTLN. + float vtln_warp = 1.0; + + for (int32_t frame = num_frames_old; frame < num_frames_new; ++frame) { + torch::Tensor window = + ExtractWindow(waveform_offset_, waveform_remainder_, frame, frame_opts); + + // TODO(fangjun): We can compute all feature frames at once + torch::Tensor this_feature = + computer_.ComputeFeatures(window.unsqueeze(0), vtln_warp); + features_.PushBack(this_feature); + } + + // OK, we will now discard any portion of the signal that will not be + // necessary to compute frames in the future. + int64_t first_sample_of_next_frame = + FirstSampleOfFrame(num_frames_new, frame_opts); + int32_t samples_to_discard = first_sample_of_next_frame - waveform_offset_; + if (samples_to_discard > 0) { + // discard the leftmost part of the waveform that we no longer need. + int32_t new_num_samples = waveform_remainder_.numel() - samples_to_discard; + if (new_num_samples <= 0) { + // odd, but we'll try to handle it. + waveform_offset_ += waveform_remainder_.numel(); + waveform_remainder_.resize_({0}); + } else { + using torch::indexing::None; + using torch::indexing::Slice; + + waveform_remainder_ = + waveform_remainder_.index({Slice(samples_to_discard, None)}); + + waveform_offset_ += samples_to_discard; + } + } +} + +// instantiate the templates defined here for MFCC, PLP and filterbank classes. +template class OnlineGenericBaseFeature; +template class OnlineGenericBaseFeature; +template class OnlineGenericBaseFeature; + +} // namespace kaldifeat diff --git a/kaldifeat/csrc/online-feature.h b/kaldifeat/csrc/online-feature.h new file mode 100644 index 0000000..f234b5c --- /dev/null +++ b/kaldifeat/csrc/online-feature.h @@ -0,0 +1,127 @@ +// kaldifeat/csrc/online-feature.h +// +// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + +// This file is copied/modified from kaldi/src/feat/online-feature.h + +#ifndef KALDIFEAT_CSRC_ONLINE_FEATURE_H_ +#define KALDIFEAT_CSRC_ONLINE_FEATURE_H_ + +#include + +#include "kaldifeat/csrc/feature-fbank.h" +#include "kaldifeat/csrc/feature-mfcc.h" +#include "kaldifeat/csrc/feature-plp.h" +#include "kaldifeat/csrc/feature-window.h" +#include "kaldifeat/csrc/online-feature-itf.h" + +namespace kaldifeat { + +/// This class serves as a storage for feature vectors with an option to limit +/// the memory usage by removing old elements. The deleted frames indices are +/// "remembered" so that regardless of the MAX_ITEMS setting, the user always +/// provides the indices as if no deletion was being performed. +/// This is useful when processing very long recordings which would otherwise +/// cause the memory to eventually blow up when the features are not being +/// removed. +class RecyclingVector { + public: + /// By default it does not remove any elements. + explicit RecyclingVector(int32_t items_to_hold = -1); + + ~RecyclingVector() = default; + RecyclingVector(const RecyclingVector &) = delete; + RecyclingVector &operator=(const RecyclingVector &) = delete; + + torch::Tensor At(int32_t index) const; + + void PushBack(torch::Tensor item); + + /// This method returns the size as if no "recycling" had happened, + /// i.e. equivalent to the number of times the PushBack method has been + /// called. + int32_t Size() const; + + private: + std::deque items_; + int32_t items_to_hold_; + int32_t first_available_index_; +}; + +/// This is a templated class for online feature extraction; +/// it's templated on a class like MfccComputer or PlpComputer +/// that does the basic feature extraction. +template +class OnlineGenericBaseFeature : public OnlineFeatureInterface { + public: + // Constructor from options class + explicit OnlineGenericBaseFeature(const typename C::Options &opts); + + int32_t Dim() const override { return computer_.Dim(); } + + float FrameShiftInSeconds() const override { + return computer_.GetFrameOptions().frame_shift_ms / 1000.0f; + } + + int32_t NumFramesReady() const override { return features_.Size(); } + + // Note: IsLastFrame() will only ever return true if you have called + // InputFinished() (and this frame is the last frame). + bool IsLastFrame(int32_t frame) const override { + return input_finished_ && frame == NumFramesReady() - 1; + } + + torch::Tensor GetFrame(int32_t frame) override { return features_.At(frame); } + + // This would be called from the application, when you get + // more wave data. Note: the sampling_rate is only provided so + // the code can assert that it matches the sampling rate + // expected in the options. + void AcceptWaveform(float sampling_rate, + const torch::Tensor &waveform) override; + + // InputFinished() tells the class you won't be providing any + // more waveform. This will help flush out the last frame or two + // of features, in the case where snip-edges == false; it also + // affects the return value of IsLastFrame(). + void InputFinished() override; + + private: + // This function computes any additional feature frames that it is possible to + // compute from 'waveform_remainder_', which at this point may contain more + // than just a remainder-sized quantity (because AcceptWaveform() appends to + // waveform_remainder_ before calling this function). It adds these feature + // frames to features_, and shifts off any now-unneeded samples of input from + // waveform_remainder_ while incrementing waveform_offset_ by the same amount. + void ComputeFeatures(); + + C computer_; // class that does the MFCC or PLP or filterbank computation + + FeatureWindowFunction window_function_; + + // features_ is the Mfcc or Plp or Fbank features that we have already + // computed. + + RecyclingVector features_; + + // True if the user has called "InputFinished()" + bool input_finished_; + + // waveform_offset_ is the number of samples of waveform that we have + // already discarded, i.e. that were prior to 'waveform_remainder_'. + int64_t waveform_offset_; + + // waveform_remainder_ is a short piece of waveform that we may need to keep + // after extracting all the whole frames we can (whatever length of feature + // will be required for the next phase of computation). + // It is a 1-D tensor + torch::Tensor waveform_remainder_; +}; + +using OnlineMfcc = OnlineGenericBaseFeature; +using OnlinePlp = OnlineGenericBaseFeature; +using OnlineFbank = OnlineGenericBaseFeature; + +} // namespace kaldifeat + +#endif // KALDIFEAT_CSRC_ONLINE_FEATURE_H_ diff --git a/kaldifeat/python/csrc/CMakeLists.txt b/kaldifeat/python/csrc/CMakeLists.txt index affb69c..956263f 100644 --- a/kaldifeat/python/csrc/CMakeLists.txt +++ b/kaldifeat/python/csrc/CMakeLists.txt @@ -7,6 +7,7 @@ pybind11_add_module(_kaldifeat feature-window.cc kaldifeat.cc mel-computations.cc + online-feature.cc utils.cc ) target_link_libraries(_kaldifeat PRIVATE kaldifeat_core) diff --git a/kaldifeat/python/csrc/feature-window.cc b/kaldifeat/python/csrc/feature-window.cc index 0fd2ea8..5abaf36 100644 --- a/kaldifeat/python/csrc/feature-window.cc +++ b/kaldifeat/python/csrc/feature-window.cc @@ -25,6 +25,7 @@ static void PybindFrameExtractionOptions(py::module &m) { .def_readwrite("round_to_power_of_two", &PyClass::round_to_power_of_two) .def_readwrite("blackman_coeff", &PyClass::blackman_coeff) .def_readwrite("snip_edges", &PyClass::snip_edges) + .def_readwrite("max_feature_vectors", &PyClass::max_feature_vectors) .def("as_dict", [](const PyClass &self) -> py::dict { return AsDict(self); }) .def_static("from_dict", @@ -35,8 +36,6 @@ static void PybindFrameExtractionOptions(py::module &m) { .def_readwrite("allow_downsample", &PyClass::allow_downsample) .def_readwrite("allow_upsample", &PyClass::allow_upsample) - .def_readwrite("max_feature_vectors", - &PyClass::max_feature_vectors) #endif .def("__str__", [](const PyClass &self) -> std::string { return self.ToString(); }) diff --git a/kaldifeat/python/csrc/kaldifeat.cc b/kaldifeat/python/csrc/kaldifeat.cc index 93e66ac..0a4b8c2 100644 --- a/kaldifeat/python/csrc/kaldifeat.cc +++ b/kaldifeat/python/csrc/kaldifeat.cc @@ -11,6 +11,7 @@ #include "kaldifeat/python/csrc/feature-spectrogram.h" #include "kaldifeat/python/csrc/feature-window.h" #include "kaldifeat/python/csrc/mel-computations.h" +#include "kaldifeat/python/csrc/online-feature.h" #include "torch/torch.h" namespace kaldifeat { @@ -24,6 +25,7 @@ PYBIND11_MODULE(_kaldifeat, m) { PybindFeatureMfcc(m); PybindFeaturePlp(m); PybindFeatureSpectrogram(m); + PybindOnlineFeature(m); } } // namespace kaldifeat diff --git a/kaldifeat/python/csrc/online-feature.cc b/kaldifeat/python/csrc/online-feature.cc new file mode 100644 index 0000000..13e4a4f --- /dev/null +++ b/kaldifeat/python/csrc/online-feature.cc @@ -0,0 +1,37 @@ +// kaldifeat/python/csrc/online-feature.cc +// +// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + +#include "kaldifeat/python/csrc/online-feature.h" + +#include + +#include "kaldifeat/csrc/online-feature.h" +namespace kaldifeat { + +template +void PybindOnlineFeatureTpl(py::module &m, const std::string &class_name, + const std::string &class_help_doc = "") { + using PyClass = OnlineGenericBaseFeature; + using Options = typename C::Options; + py::class_(m, class_name.c_str(), class_help_doc.c_str()) + .def(py::init(), py::arg("opts")) + .def_property_readonly("dim", &PyClass::Dim) + .def_property_readonly("frame_shift_in_seconds", + &PyClass::FrameShiftInSeconds) + .def_property_readonly("num_frames_ready", &PyClass::NumFramesReady) + .def("is_last_frame", &PyClass::IsLastFrame, py::arg("frame")) + .def("get_frame", &PyClass::GetFrame, py::arg("frame")) + .def("get_frames", &PyClass::GetFrames, py::arg("frames")) + .def("accept_waveform", &PyClass::AcceptWaveform, + py::arg("sampling_rate"), py::arg("waveform")) + .def("input_finished", &PyClass::InputFinished); +} + +void PybindOnlineFeature(py::module &m) { + PybindOnlineFeatureTpl(m, "OnlineMfcc"); + PybindOnlineFeatureTpl(m, "OnlineFbank"); + PybindOnlineFeatureTpl(m, "OnlinePlp"); +} + +} // namespace kaldifeat diff --git a/kaldifeat/python/csrc/online-feature.h b/kaldifeat/python/csrc/online-feature.h new file mode 100644 index 0000000..c363f42 --- /dev/null +++ b/kaldifeat/python/csrc/online-feature.h @@ -0,0 +1,16 @@ +// kaldifeat/python/csrc/online-feature.h +// +// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + +#ifndef KALDIFEAT_PYTHON_CSRC_ONLINE_FEATURE_H_ +#define KALDIFEAT_PYTHON_CSRC_ONLINE_FEATURE_H_ + +#include "kaldifeat/python/csrc/kaldifeat.h" + +namespace kaldifeat { + +void PybindOnlineFeature(py::module &m); + +} // namespace kaldifeat + +#endif // KALDIFEAT_PYTHON_CSRC_ONLINE_FEATURE_H_ diff --git a/kaldifeat/python/csrc/utils.cc b/kaldifeat/python/csrc/utils.cc index 76f47aa..0259afc 100644 --- a/kaldifeat/python/csrc/utils.cc +++ b/kaldifeat/python/csrc/utils.cc @@ -30,6 +30,7 @@ FrameExtractionOptions FrameExtractionOptionsFromDict(py::dict dict) { FROM_DICT(bool_, round_to_power_of_two); FROM_DICT(float_, blackman_coeff); FROM_DICT(bool_, snip_edges); + FROM_DICT(int_, max_feature_vectors); return opts; } @@ -47,6 +48,7 @@ py::dict AsDict(const FrameExtractionOptions &opts) { AS_DICT(round_to_power_of_two); AS_DICT(blackman_coeff); AS_DICT(snip_edges); + AS_DICT(max_feature_vectors); return dict; } diff --git a/kaldifeat/python/kaldifeat/__init__.py b/kaldifeat/python/kaldifeat/__init__.py index 6b2f088..60c6443 100644 --- a/kaldifeat/python/kaldifeat/__init__.py +++ b/kaldifeat/python/kaldifeat/__init__.py @@ -8,7 +8,7 @@ from _kaldifeat import ( SpectrogramOptions, ) -from .fbank import Fbank -from .mfcc import Mfcc -from .plp import Plp +from .fbank import Fbank, OnlineFbank +from .mfcc import Mfcc, OnlineMfcc +from .plp import OnlinePlp, Plp from .spectrogram import Spectrogram diff --git a/kaldifeat/python/kaldifeat/fbank.py b/kaldifeat/python/kaldifeat/fbank.py index 8f73911..45bc3ef 100644 --- a/kaldifeat/python/kaldifeat/fbank.py +++ b/kaldifeat/python/kaldifeat/fbank.py @@ -4,9 +4,20 @@ import _kaldifeat from .offline_feature import OfflineFeature +from .online_feature import OnlineFeature class Fbank(OfflineFeature): def __init__(self, opts: _kaldifeat.FbankOptions): super().__init__(opts) self.computer = _kaldifeat.Fbank(opts) + + +class OnlineFbank(OnlineFeature): + def __init__(self, opts: _kaldifeat.FbankOptions): + super().__init__(opts) + self.computer = _kaldifeat.OnlineFbank(opts) + + def __setstate__(self, state): + self.opts = _kaldifeat.FbankOptions.from_dict(state) + self.computer = _kaldifeat.OnlineFbank(self.opts) diff --git a/kaldifeat/python/kaldifeat/mfcc.py b/kaldifeat/python/kaldifeat/mfcc.py index fa1e225..f76f2f4 100644 --- a/kaldifeat/python/kaldifeat/mfcc.py +++ b/kaldifeat/python/kaldifeat/mfcc.py @@ -4,9 +4,20 @@ import _kaldifeat from .offline_feature import OfflineFeature +from .online_feature import OnlineFeature class Mfcc(OfflineFeature): def __init__(self, opts: _kaldifeat.MfccOptions): super().__init__(opts) self.computer = _kaldifeat.Mfcc(opts) + + +class OnlineMfcc(OnlineFeature): + def __init__(self, opts: _kaldifeat.MfccOptions): + super().__init__(opts) + self.computer = _kaldifeat.OnlineMfcc(opts) + + def __setstate__(self, state): + self.opts = _kaldifeat.MfccOptions.from_dict(state) + self.computer = _kaldifeat.OnlineMfcc(self.opts) diff --git a/kaldifeat/python/kaldifeat/online_feature.py b/kaldifeat/python/kaldifeat/online_feature.py new file mode 100644 index 0000000..cf687cf --- /dev/null +++ b/kaldifeat/python/kaldifeat/online_feature.py @@ -0,0 +1,95 @@ +# Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + +from typing import List + +import torch + + +class OnlineFeature(object): + """Offline feature is a base class of other feature computers, + e.g., Fbank, Mfcc. + + This class has two fields: + + (1) opts. It contains the options for the feature computer. + (2) computer. The actual feature computer. It should be + instantiated by subclasses. + + Caution: + It supports only CPU at present. + """ + + def __init__(self, opts): + assert opts.device.type == "cpu" + + self.opts = opts + + # self.computer is expected to be set by subclasses + self.computer = None + + @property + def num_frames_ready(self) -> int: + """Return the number of ready frames. + + It can be updated by :method:`accept_waveform`. + + Note: + If you set ``opts.frame_opts.max_feature_vectors``, then + the valid frame indexes are in the range. + ``[num_frames_ready - max_feature_vectors, num_frames_ready)`` + + If you leave ``opts.frame_opts.max_feature_vectors`` to its default + value, then the range is ``[0, num_frames_ready)`` + """ + return self.computer.num_frames_ready + + def is_last_frame(self, frame: int) -> bool: + """Return True if the given frame is the last frame.""" + return self.computer.is_last_frame(frame) + + def get_frame(self, frame: int) -> torch.Tensor: + """Get the frame by its index. + Args: + frame: + The frame index. If ``opts.frame_opts.max_feature_vectors`` is + -1, then its valid values are in the range + ``[0, num_frames_ready)``. Otherwise, the range is + ``[num_frames_ready - max_feature_vectors, num_frames_ready)``. + Returns: + Return a 2-D tensor with shape ``(1, feature_dim)`` + """ + return self.computer.get_frame(frame) + + def get_frames(self, frames: List[int]) -> List[torch.Tensor]: + """Get frames at the given frame indexes. + Args: + frames: + Frames whose indexes are in this list are returned. + Returns: + Return a list of feature frames at the given indexes. + """ + return self.computer.get_frames(frames) + + def accept_waveform( + self, sampling_rate: float, waveform: torch.Tensor + ) -> None: + """Send audio samples to the extractor. + Args: + sampling_rate: + The sampling rate of the given audio samples. It has to be equal + to ``opts.frame_opts.samp_freq``. + waveform: + A 1-D tensor of shape (num_samples,). Its dtype is torch.float32 + and has to be on CPU. + """ + self.computer.accept_waveform(sampling_rate, waveform) + + def input_finished(self) -> None: + """Tell the extractor that no more audio samples will be available. + After calling this function, you cannot invoke ``accept_waveform`` + again. + """ + self.computer.input_finished() + + def __getstate__(self): + return self.opts.as_dict() diff --git a/kaldifeat/python/kaldifeat/plp.py b/kaldifeat/python/kaldifeat/plp.py index 219e2d4..d99dbc2 100644 --- a/kaldifeat/python/kaldifeat/plp.py +++ b/kaldifeat/python/kaldifeat/plp.py @@ -4,9 +4,20 @@ import _kaldifeat from .offline_feature import OfflineFeature +from .online_feature import OnlineFeature class Plp(OfflineFeature): def __init__(self, opts: _kaldifeat.PlpOptions): super().__init__(opts) self.computer = _kaldifeat.Plp(opts) + + +class OnlinePlp(OnlineFeature): + def __init__(self, opts: _kaldifeat.PlpOptions): + super().__init__(opts) + self.computer = _kaldifeat.OnlinePlp(opts) + + def __setstate__(self, state): + self.opts = _kaldifeat.PlpOptions.from_dict(state) + self.computer = _kaldifeat.OnlinePlp(self.opts) diff --git a/kaldifeat/python/tests/test_fbank.py b/kaldifeat/python/tests/test_fbank.py index 57092b6..1c06438 100755 --- a/kaldifeat/python/tests/test_fbank.py +++ b/kaldifeat/python/tests/test_fbank.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang) +# Copyright 2021-2022 Xiaomi Corporation (authors: Fangjun Kuang) import pickle from pathlib import Path @@ -13,30 +13,88 @@ import kaldifeat cur_dir = Path(__file__).resolve().parent +def test_online_fbank( + opts: kaldifeat.FbankOptions, + wave: torch.Tensor, + cpu_features: torch.Tensor, +): + """ + Args: + opts: + The options to create the online fbank extractor. + wave: + The input 1-D waveform. + cpu_features: + The groud truth features that are computed offline + """ + online_fbank = kaldifeat.OnlineFbank(opts) + + num_processed_frames = 0 + i = 0 # current sample index to feed + while not online_fbank.is_last_frame(num_processed_frames - 1): + while num_processed_frames < online_fbank.num_frames_ready: + # There are new frames to be processed + frame = online_fbank.get_frame(num_processed_frames) + assert torch.allclose( + frame.squeeze(0), cpu_features[num_processed_frames] + ) + num_processed_frames += 1 + + # Simulate streaming . Send a random number of audio samples + # to the extractor + num_samples = torch.randint(300, 1000, (1,)).item() + + samples = wave[i : (i + num_samples)] # noqa + i += num_samples + if len(samples) == 0: + online_fbank.input_finished() + continue + + online_fbank.accept_waveform(16000, samples) + + assert num_processed_frames == online_fbank.num_frames_ready + assert num_processed_frames == cpu_features.size(0) + + def test_fbank_default(): print("=====test_fbank_default=====") + filename = cur_dir / "test_data/test.wav" + wave = read_wave(filename) + gt = read_ark_txt(cur_dir / "test_data/test.txt") + + cpu_features = None for device in get_devices(): print("device", device) opts = kaldifeat.FbankOptions() opts.device = device opts.frame_opts.dither = 0 fbank = kaldifeat.Fbank(opts) - filename = cur_dir / "test_data/test.wav" - wave = read_wave(filename) features = fbank(wave) assert features.device.type == "cpu" - gt = read_ark_txt(cur_dir / "test_data/test.txt") assert torch.allclose(features, gt, rtol=1e-1) + if cpu_features is None: + cpu_features = features - wave = wave.to(device) - features = fbank(wave) + features = fbank(wave.to(device)) assert features.device == device assert torch.allclose(features.cpu(), gt, rtol=1e-1) + # Now for online fbank + opts = kaldifeat.FbankOptions() + opts.frame_opts.dither = 0 + opts.frame_opts.max_feature_vectors = 100 + + test_online_fbank(opts, wave, cpu_features) + def test_fbank_htk(): print("=====test_fbank_htk=====") + filename = cur_dir / "test_data/test.wav" + wave = read_wave(filename) + gt = read_ark_txt(cur_dir / "test_data/test-htk.txt") + + cpu_features = None for device in get_devices(): print("device", device) opts = kaldifeat.FbankOptions() @@ -46,22 +104,32 @@ def test_fbank_htk(): opts.htk_compat = True fbank = kaldifeat.Fbank(opts) - filename = cur_dir / "test_data/test.wav" - wave = read_wave(filename) features = fbank(wave) assert features.device.type == "cpu" - gt = read_ark_txt(cur_dir / "test_data/test-htk.txt") assert torch.allclose(features, gt, rtol=1e-1) + if cpu_features is None: + cpu_features = features - wave = wave.to(device) - features = fbank(wave) + features = fbank(wave.to(device)) assert features.device == device assert torch.allclose(features.cpu(), gt, rtol=1e-1) + opts = kaldifeat.FbankOptions() + opts.frame_opts.dither = 0 + opts.use_energy = True + opts.htk_compat = True + + test_online_fbank(opts, wave, cpu_features) + def test_fbank_with_energy(): print("=====test_fbank_with_energy=====") + filename = cur_dir / "test_data/test.wav" + wave = read_wave(filename) + gt = read_ark_txt(cur_dir / "test_data/test-with-energy.txt") + + cpu_features = None for device in get_devices(): print("device", device) opts = kaldifeat.FbankOptions() @@ -70,22 +138,31 @@ def test_fbank_with_energy(): opts.use_energy = True fbank = kaldifeat.Fbank(opts) - filename = cur_dir / "test_data/test.wav" - wave = read_wave(filename) features = fbank(wave) - gt = read_ark_txt(cur_dir / "test_data/test-with-energy.txt") assert torch.allclose(features, gt, rtol=1e-1) assert features.device.type == "cpu" + if cpu_features is None: + cpu_features = features - wave = wave.to(device) - features = fbank(wave) + features = fbank(wave.to(device)) assert features.device == device assert torch.allclose(features.cpu(), gt, rtol=1e-1) + opts = kaldifeat.FbankOptions() + opts.frame_opts.dither = 0 + opts.use_energy = True + + test_online_fbank(opts, wave, cpu_features) + def test_fbank_40_bins(): print("=====test_fbank_40_bins=====") + filename = cur_dir / "test_data/test.wav" + wave = read_wave(filename) + gt = read_ark_txt(cur_dir / "test_data/test-40.txt") + + cpu_features = None for device in get_devices(): print("device", device) opts = kaldifeat.FbankOptions() @@ -94,22 +171,31 @@ def test_fbank_40_bins(): opts.mel_opts.num_bins = 40 fbank = kaldifeat.Fbank(opts) - filename = cur_dir / "test_data/test.wav" - wave = read_wave(filename) features = fbank(wave) assert features.device.type == "cpu" - gt = read_ark_txt(cur_dir / "test_data/test-40.txt") assert torch.allclose(features, gt, rtol=1e-1) + if cpu_features is None: + cpu_features = features - wave = wave.to(device) - features = fbank(wave) + features = fbank(wave.to(device)) assert features.device == device assert torch.allclose(features.cpu(), gt, rtol=1e-1) + opts = kaldifeat.FbankOptions() + opts.frame_opts.dither = 0 + opts.mel_opts.num_bins = 40 + + test_online_fbank(opts, wave, cpu_features) + def test_fbank_40_bins_no_snip_edges(): print("=====test_fbank_40_bins_no_snip_edges=====") + filename = cur_dir / "test_data/test.wav" + wave = read_wave(filename) + gt = read_ark_txt(cur_dir / "test_data/test-40-no-snip-edges.txt") + + cpu_features = None for device in get_devices(): print("device", device) opts = kaldifeat.FbankOptions() @@ -119,19 +205,24 @@ def test_fbank_40_bins_no_snip_edges(): opts.frame_opts.snip_edges = False fbank = kaldifeat.Fbank(opts) - filename = cur_dir / "test_data/test.wav" - wave = read_wave(filename) features = fbank(wave) assert features.device.type == "cpu" - gt = read_ark_txt(cur_dir / "test_data/test-40-no-snip-edges.txt") assert torch.allclose(features, gt, rtol=1e-1) + if cpu_features is None: + cpu_features = features - wave = wave.to(device) - features = fbank(wave) + features = fbank(wave.to(device)) assert features.device == device assert torch.allclose(features.cpu(), gt, rtol=1e-1) + opts = kaldifeat.FbankOptions() + opts.frame_opts.dither = 0 + opts.mel_opts.num_bins = 40 + opts.frame_opts.snip_edges = False + + test_online_fbank(opts, wave, cpu_features) + def test_fbank_chunk(): print("=====test_fbank_chunk=====") @@ -223,6 +314,16 @@ def test_pickle(): assert str(fbank.opts) == str(fbank2.opts) + opts = kaldifeat.FbankOptions() + opts.use_energy = True + opts.use_power = False + + fbank = kaldifeat.OnlineFbank(opts) + data = pickle.dumps(fbank) + fbank2 = pickle.loads(data) + + assert str(fbank.opts) == str(fbank2.opts) + if __name__ == "__main__": test_fbank_default() diff --git a/kaldifeat/python/tests/test_mfcc.py b/kaldifeat/python/tests/test_mfcc.py index 33407b5..5665da4 100755 --- a/kaldifeat/python/tests/test_mfcc.py +++ b/kaldifeat/python/tests/test_mfcc.py @@ -13,24 +13,82 @@ import kaldifeat cur_dir = Path(__file__).resolve().parent +def test_online_mfcc( + opts: kaldifeat.MfccOptions, + wave: torch.Tensor, + cpu_features: torch.Tensor, +): + """ + Args: + opts: + The options to create the online mfcc extractor. + wave: + The input 1-D waveform. + cpu_features: + The groud truth features that are computed offline + """ + online_mfcc = kaldifeat.OnlineMfcc(opts) + + num_processed_frames = 0 + i = 0 # current sample index to feed + while not online_mfcc.is_last_frame(num_processed_frames - 1): + while num_processed_frames < online_mfcc.num_frames_ready: + # There are new frames to be processed + frame = online_mfcc.get_frame(num_processed_frames) + assert torch.allclose( + frame.squeeze(0), cpu_features[num_processed_frames], atol=1e-3 + ) + num_processed_frames += 1 + + # Simulate streaming . Send a random number of audio samples + # to the extractor + num_samples = torch.randint(300, 1000, (1,)).item() + + samples = wave[i : (i + num_samples)] # noqa + i += num_samples + if len(samples) == 0: + online_mfcc.input_finished() + continue + + online_mfcc.accept_waveform(16000, samples) + + assert num_processed_frames == online_mfcc.num_frames_ready + assert num_processed_frames == cpu_features.size(0) + + def test_mfcc_default(): print("=====test_mfcc_default=====") + filename = cur_dir / "test_data/test.wav" + wave = read_wave(filename) + gt = read_ark_txt(cur_dir / "test_data/test-mfcc.txt") + + cpu_features = None for device in get_devices(): print("device", device) opts = kaldifeat.MfccOptions() opts.device = device opts.frame_opts.dither = 0 mfcc = kaldifeat.Mfcc(opts) - filename = cur_dir / "test_data/test.wav" - wave = read_wave(filename).to(device) - features = mfcc(wave) - gt = read_ark_txt(cur_dir / "test_data/test-mfcc.txt") + features = mfcc(wave.to(device)) + if device.type == "cpu": + cpu_features = features + assert torch.allclose(features.cpu(), gt, atol=1e-1) + opts = kaldifeat.MfccOptions() + opts.frame_opts.dither = 0 + + test_online_mfcc(opts, wave, cpu_features) + def test_mfcc_no_snip_edges(): print("=====test_mfcc_no_snip_edges=====") + filename = cur_dir / "test_data/test.wav" + wave = read_wave(filename) + gt = read_ark_txt(cur_dir / "test_data/test-mfcc-no-snip-edges.txt") + + cpu_features = None for device in get_devices(): print("device", device) opts = kaldifeat.MfccOptions() @@ -39,13 +97,19 @@ def test_mfcc_no_snip_edges(): opts.frame_opts.snip_edges = False mfcc = kaldifeat.Mfcc(opts) - filename = cur_dir / "test_data/test.wav" - wave = read_wave(filename).to(device) - features = mfcc(wave) - gt = read_ark_txt(cur_dir / "test_data/test-mfcc-no-snip-edges.txt") + features = mfcc(wave.to(device)) + if device.type == "cpu": + cpu_features = features + assert torch.allclose(features.cpu(), gt, rtol=1e-1) + opts = kaldifeat.MfccOptions() + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + + test_online_mfcc(opts, wave, cpu_features) + def test_pickle(): for device in get_devices(): @@ -60,6 +124,16 @@ def test_pickle(): assert str(mfcc.opts) == str(mfcc2.opts) + opts = kaldifeat.MfccOptions() + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + + mfcc = kaldifeat.OnlineMfcc(opts) + data = pickle.dumps(mfcc) + mfcc2 = pickle.loads(data) + + assert str(mfcc.opts) == str(mfcc2.opts) + if __name__ == "__main__": test_mfcc_default() diff --git a/kaldifeat/python/tests/test_plp.py b/kaldifeat/python/tests/test_plp.py index 4f20452..cf56d41 100755 --- a/kaldifeat/python/tests/test_plp.py +++ b/kaldifeat/python/tests/test_plp.py @@ -13,24 +13,82 @@ import kaldifeat cur_dir = Path(__file__).resolve().parent +def test_online_plp( + opts: kaldifeat.PlpOptions, + wave: torch.Tensor, + cpu_features: torch.Tensor, +): + """ + Args: + opts: + The options to create the online plp extractor. + wave: + The input 1-D waveform. + cpu_features: + The groud truth features that are computed offline + """ + online_plp = kaldifeat.OnlinePlp(opts) + + num_processed_frames = 0 + i = 0 # current sample index to feed + while not online_plp.is_last_frame(num_processed_frames - 1): + while num_processed_frames < online_plp.num_frames_ready: + # There are new frames to be processed + frame = online_plp.get_frame(num_processed_frames) + assert torch.allclose( + frame.squeeze(0), cpu_features[num_processed_frames], atol=1e-3 + ) + num_processed_frames += 1 + + # Simulate streaming . Send a random number of audio samples + # to the extractor + num_samples = torch.randint(300, 1000, (1,)).item() + + samples = wave[i : (i + num_samples)] # noqa + i += num_samples + if len(samples) == 0: + online_plp.input_finished() + continue + + online_plp.accept_waveform(16000, samples) + + assert num_processed_frames == online_plp.num_frames_ready + assert num_processed_frames == cpu_features.size(0) + + def test_plp_default(): print("=====test_plp_default=====") + filename = cur_dir / "test_data/test.wav" + wave = read_wave(filename) + gt = read_ark_txt(cur_dir / "test_data/test-plp.txt") + + cpu_features = None for device in get_devices(): print("device", device) opts = kaldifeat.PlpOptions() opts.frame_opts.dither = 0 opts.device = device plp = kaldifeat.Plp(opts) - filename = cur_dir / "test_data/test.wav" - wave = read_wave(filename).to(device) - features = plp(wave) - gt = read_ark_txt(cur_dir / "test_data/test-plp.txt") + features = plp(wave.to(device)) + if device.type == "cpu": + cpu_features = features + assert torch.allclose(features.cpu(), gt, rtol=1e-1) + opts = kaldifeat.PlpOptions() + opts.frame_opts.dither = 0 + + test_online_plp(opts, wave, cpu_features) + def test_plp_no_snip_edges(): print("=====test_plp_no_snip_edges=====") + filename = cur_dir / "test_data/test.wav" + wave = read_wave(filename) + gt = read_ark_txt(cur_dir / "test_data/test-plp-no-snip-edges.txt") + + cpu_features = None for device in get_devices(): print("device", device) opts = kaldifeat.PlpOptions() @@ -39,16 +97,26 @@ def test_plp_no_snip_edges(): opts.frame_opts.snip_edges = False plp = kaldifeat.Plp(opts) - filename = cur_dir / "test_data/test.wav" - wave = read_wave(filename).to(device) - features = plp(wave) - gt = read_ark_txt(cur_dir / "test_data/test-plp-no-snip-edges.txt") + features = plp(wave.to(device)) + if device.type == "cpu": + cpu_features = features assert torch.allclose(features.cpu(), gt, atol=1e-1) + opts = kaldifeat.PlpOptions() + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + + test_online_plp(opts, wave, cpu_features) + def test_plp_htk_10_ceps(): print("=====test_plp_htk_10_ceps=====") + filename = cur_dir / "test_data/test.wav" + wave = read_wave(filename) + gt = read_ark_txt(cur_dir / "test_data/test-plp-htk-10-ceps.txt") + + cpu_features = None for device in get_devices(): print("device", device) opts = kaldifeat.PlpOptions() @@ -58,13 +126,19 @@ def test_plp_htk_10_ceps(): opts.frame_opts.dither = 0 plp = kaldifeat.Plp(opts) - filename = cur_dir / "test_data/test.wav" - wave = read_wave(filename).to(device) - features = plp(wave) - gt = read_ark_txt(cur_dir / "test_data/test-plp-htk-10-ceps.txt") + features = plp(wave.to(device)) + if device.type == "cpu": + cpu_features = features assert torch.allclose(features.cpu(), gt, atol=1e-1) + opts = kaldifeat.PlpOptions() + opts.htk_compat = True + opts.num_ceps = 10 + opts.frame_opts.dither = 0 + + test_online_plp(opts, wave, cpu_features) + def test_pickle(): for device in get_devices(): @@ -79,6 +153,16 @@ def test_pickle(): assert str(plp.opts) == str(plp2.opts) + opts = kaldifeat.PlpOptions() + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + + plp = kaldifeat.OnlinePlp(opts) + data = pickle.dumps(plp) + plp2 = pickle.loads(data) + + assert str(plp.opts) == str(plp2.opts) + if __name__ == "__main__": test_plp_default()