diff --git a/kaldifeat/csrc/feature-common.h b/kaldifeat/csrc/feature-common.h index 5710c22..24e7ec1 100644 --- a/kaldifeat/csrc/feature-common.h +++ b/kaldifeat/csrc/feature-common.h @@ -62,6 +62,10 @@ class OfflineFeatureTpl { int32_t Dim() const { return computer_.Dim(); } const Options &GetOptions() const { return computer_.GetOptions(); } + const FrameExtractionOptions &GetFrameOptions() const { + return GetOptions().frame_opts; + } + // Copy constructor. OfflineFeatureTpl(const OfflineFeatureTpl &) = delete; OfflineFeatureTpl &operator=(const OfflineFeatureTpl &) = delete; diff --git a/kaldifeat/csrc/feature-window.cc b/kaldifeat/csrc/feature-window.cc index 4ec5ac2..5d25720 100644 --- a/kaldifeat/csrc/feature-window.cc +++ b/kaldifeat/csrc/feature-window.cc @@ -161,19 +161,20 @@ torch::Tensor Dither(const torch::Tensor &wave, float dither_value) { #if 1 return wave + rand_gauss * dither_value; #else - // use in-place version of wave and change its to pointer type + // use in-place version of wave and change it to pointer type wave_->add_(rand_gauss, dither_value); #endif } torch::Tensor Preemphasize(float preemph_coeff, const torch::Tensor &wave) { - using namespace torch::indexing; // It imports: Slice, None // NOLINT if (preemph_coeff == 0.0f) return wave; KALDIFEAT_ASSERT(preemph_coeff >= 0.0f && preemph_coeff <= 1.0f); torch::Tensor ans = torch::empty_like(wave); + using torch::indexing::None; + using torch::indexing::Slice; // right = wave[:, 1:] torch::Tensor right = wave.index({"...", Slice(1, None, None)}); @@ -188,4 +189,58 @@ torch::Tensor Preemphasize(float preemph_coeff, const torch::Tensor &wave) { return ans; } +torch::Tensor ExtractWindow(int64_t sample_offset, const torch::Tensor &wave, + int32_t f, const FrameExtractionOptions &opts) { + KALDIFEAT_ASSERT(sample_offset >= 0 && wave.numel() != 0); + + int32_t frame_length = opts.WindowSize(); + int64_t num_samples = sample_offset + wave.numel(); + int64_t start_sample = FirstSampleOfFrame(f, 
opts); + int64_t end_sample = start_sample + frame_length; + + if (opts.snip_edges) { + KALDIFEAT_ASSERT(start_sample >= sample_offset && + end_sample <= num_samples); + } else { + KALDIFEAT_ASSERT(sample_offset == 0 || start_sample >= sample_offset); + } + + // wave_start and wave_end are start and end indexes into 'wave', for the + // piece of wave that we're trying to extract. + int32_t wave_start = static_cast<int32_t>(start_sample - sample_offset); + int32_t wave_end = wave_start + frame_length; + + if (wave_start >= 0 && wave_end <= wave.numel()) { + // the normal case -- no edge effects to consider. + // return wave[wave_start:wave_end] + return wave.index({torch::indexing::Slice(wave_start, wave_end)}); + } else { + torch::Tensor window = torch::empty({frame_length}, torch::kFloat); + auto p_window = window.accessor<float, 1>(); + auto p_wave = wave.accessor<float, 1>(); + + // Deal with any end effects by reflection, if needed. This code will only + // be reached for about two frames per utterance, so we don't concern + // ourselves excessively with efficiency. + int32_t wave_dim = wave.numel(); + for (int32_t s = 0; s != frame_length; ++s) { + int32_t s_in_wave = s + wave_start; + while (s_in_wave < 0 || s_in_wave >= wave_dim) { + // reflect around the beginning or end of the wave. + // e.g. -1 -> 0, -2 -> 1. + // dim -> dim - 1, dim + 1 -> dim - 2. + // the code supports repeated reflections, although this + // would only be needed in pathological cases. 
+ if (s_in_wave < 0) { + s_in_wave = -s_in_wave - 1; + } else { + s_in_wave = 2 * wave_dim - 1 - s_in_wave; + } + } + + p_window[s] = p_wave[s_in_wave]; + } + return window; + } +} } // namespace kaldifeat diff --git a/kaldifeat/csrc/feature-window.h b/kaldifeat/csrc/feature-window.h index 255c274..26d743a 100644 --- a/kaldifeat/csrc/feature-window.h +++ b/kaldifeat/csrc/feature-window.h @@ -44,7 +44,11 @@ struct FrameExtractionOptions { bool snip_edges = true; // bool allow_downsample = false; // bool allow_upsample = false; - // int32_t max_feature_vectors = -1; + + // Used for streaming feature extraction. It indicates the number + // of feature frames to keep in the recycling vector. -1 means to + // keep all feature frames. + int32_t max_feature_vectors = -1; int32_t WindowShift() const { return static_cast<int32_t>(samp_freq * 0.001f * frame_shift_ms); @@ -71,7 +75,7 @@ struct FrameExtractionOptions { KALDIFEAT_PRINT(snip_edges); // KALDIFEAT_PRINT(allow_downsample); // KALDIFEAT_PRINT(allow_upsample); - // KALDIFEAT_PRINT(max_feature_vectors); + KALDIFEAT_PRINT(max_feature_vectors); #undef KALDIFEAT_PRINT return os.str(); } @@ -100,11 +104,11 @@ class FeatureWindowFunction { @param [in] flush True if we are asserting that this number of samples is 'all there is', false if we expecting more data to possibly come in. This - only makes a difference to the answer if opts.snips_edges - == false. For offline feature extraction you always want flush == - true. In an online-decoding context, once you know (or decide) - that no more data is coming in, you'd call it with flush == true at the end - to flush out any remaining data. + only makes a difference to the answer + if opts.snip_edges == false. For offline feature extraction you always want + flush == true. In an online-decoding context, once you know (or decide) that + no more data is coming in, you'd call it with flush == true at the end to + flush out any remaining data. 
*/ int32_t NumFrames(int64_t num_samples, const FrameExtractionOptions &opts, bool flush = true); @@ -133,6 +137,29 @@ torch::Tensor Dither(const torch::Tensor &wave, float dither_value); torch::Tensor Preemphasize(float preemph_coeff, const torch::Tensor &wave); +/* + ExtractWindow() extracts "frame_length" samples from the given waveform. + Note: This function only extracts "frame_length" samples + from the input waveform, without any further processing. + + @param [in] sample_offset If 'wave' is not the entire waveform, but + part of it to the left has been discarded, then the + number of samples prior to 'wave' that we have + already discarded. Set this to zero if you are + processing the entire waveform in one piece, or + if you get 'no matching function' compilation + errors when updating the code. + @param [in] wave The waveform + @param [in] f The frame index to be extracted, with + 0 <= f < NumFrames(sample_offset + wave.numel(), opts, true) + @param [in] opts The options class to be used + @return Return a tensor containing "frame_length" samples extracted from + `wave`, without any further processing. It is a 1-D tensor + of shape (frame_length,). +*/ +torch::Tensor ExtractWindow(int64_t sample_offset, const torch::Tensor &wave, + int32_t f, const FrameExtractionOptions &opts); + } // namespace kaldifeat #endif // KALDIFEAT_CSRC_FEATURE_WINDOW_H_ diff --git a/kaldifeat/csrc/online-feature-itf.h b/kaldifeat/csrc/online-feature-itf.h index 7599182..4dc779a 100644 --- a/kaldifeat/csrc/online-feature-itf.h +++ b/kaldifeat/csrc/online-feature-itf.h @@ -19,6 +19,10 @@ class OnlineFeatureInterface { virtual ~OnlineFeatureInterface() = default; virtual int32_t Dim() const = 0; /// returns the feature dimension. + // + // Returns frame shift in seconds. Helps to estimate duration from frame + // counts. + virtual float FrameShiftInSeconds() const = 0; /// Returns the total number of frames, since the start of the utterance, that /// are now available. 
In an online-decoding context, this will likely @@ -64,10 +68,6 @@ class OnlineFeatureInterface { return torch::cat(features, /*dim*/ 0); } - // Returns frame shift in seconds. Helps to estimate duration from frame - // counts. - virtual float FrameShiftInSeconds() const = 0; - /// This would be called from the application, when you get more wave data. /// Note: the sampling_rate is typically only provided so the code can assert /// that it matches the sampling rate expected in the options. diff --git a/kaldifeat/csrc/online-feature.cc b/kaldifeat/csrc/online-feature.cc index 0e2f250..42855f4 100644 --- a/kaldifeat/csrc/online-feature.cc +++ b/kaldifeat/csrc/online-feature.cc @@ -6,6 +6,7 @@ #include "kaldifeat/csrc/online-feature.h" +#include "kaldifeat/csrc/feature-window.h" #include "kaldifeat/csrc/log.h" namespace kaldifeat { @@ -40,4 +41,97 @@ int32_t RecyclingVector::Size() const { return first_available_index_ + static_cast<int32_t>(items_.size()); } +template <class C> +OnlineGenericBaseFeature<C>::OnlineGenericBaseFeature( + const typename C::Options &opts) + : computer_(opts), + window_function_(opts.frame_opts, opts.device), + features_(opts.frame_opts.max_feature_vectors), + input_finished_(false), + waveform_offset_(0) { + // Casting to uint32_t, an unsigned type, means that -1 would be treated + // as `very large`. + KALDIFEAT_ASSERT(static_cast<uint32_t>(opts.frame_opts.max_feature_vectors) > + 200); +} + +template <class C> +void OnlineGenericBaseFeature<C>::AcceptWaveform( + float sampling_rate, const torch::Tensor &original_waveform) { + if (original_waveform.numel() == 0) return; // Nothing to do. 
+ + KALDIFEAT_ASSERT(original_waveform.dim() == 1); + + if (input_finished_) + KALDIFEAT_ERR << "AcceptWaveform called after InputFinished() was called."; + + if (waveform_remainder_.numel() == 0) { + waveform_remainder_ = original_waveform; + } else { + waveform_remainder_ = + torch::cat({waveform_remainder_, original_waveform}, /*dim*/ 0); + } + + ComputeFeatures(); +} + +template <class C> +void OnlineGenericBaseFeature<C>::InputFinished() { + input_finished_ = true; + ComputeFeatures(); +} + +template <class C> +void OnlineGenericBaseFeature<C>::ComputeFeatures() { + const FrameExtractionOptions &frame_opts = computer_.GetFrameOptions(); + + int64_t num_samples_total = waveform_offset_ + waveform_remainder_.numel(); + int32_t num_frames_old = features_.Size(); + int32_t num_frames_new = + NumFrames(num_samples_total, frame_opts, input_finished_); + + KALDIFEAT_ASSERT(num_frames_new >= num_frames_old); + + // note: this online feature-extraction code does not support VTLN. + float vtln_warp = 1.0; + + for (int32_t frame = num_frames_old; frame < num_frames_new; ++frame) { + torch::Tensor window = + ExtractWindow(waveform_offset_, waveform_remainder_, frame, frame_opts); + + // TODO(fangjun): We can compute all feature frames at once + torch::Tensor this_feature = + computer_.ComputeFeatures(window.unsqueeze(0), vtln_warp); + features_.PushBack(this_feature); + } + + // OK, we will now discard any portion of the signal that will not be + // necessary to compute frames in the future. + int64_t first_sample_of_next_frame = + FirstSampleOfFrame(num_frames_new, frame_opts); + int32_t samples_to_discard = first_sample_of_next_frame - waveform_offset_; + if (samples_to_discard > 0) { + // discard the leftmost part of the waveform that we no longer need. + int32_t new_num_samples = waveform_remainder_.numel() - samples_to_discard; + if (new_num_samples <= 0) { + // odd, but we'll try to handle it. 
+ waveform_offset_ += waveform_remainder_.numel(); + waveform_remainder_.resize_({0}); + } else { + using torch::indexing::None; + using torch::indexing::Slice; + + waveform_remainder_ = + waveform_remainder_.index({Slice(samples_to_discard, None)}); + + waveform_offset_ += samples_to_discard; + } + } +} + +// instantiate the templates defined here for MFCC, PLP and filterbank classes. +template class OnlineGenericBaseFeature<MfccComputer>; +template class OnlineGenericBaseFeature<PlpComputer>; +template class OnlineGenericBaseFeature<FbankComputer>; + } // namespace kaldifeat diff --git a/kaldifeat/csrc/online-feature.h b/kaldifeat/csrc/online-feature.h index d64bb8f..f234b5c 100644 --- a/kaldifeat/csrc/online-feature.h +++ b/kaldifeat/csrc/online-feature.h @@ -9,6 +9,10 @@ #include <deque> +#include "kaldifeat/csrc/feature-fbank.h" +#include "kaldifeat/csrc/feature-mfcc.h" +#include "kaldifeat/csrc/feature-plp.h" +#include "kaldifeat/csrc/feature-window.h" #include "kaldifeat/csrc/online-feature-itf.h" namespace kaldifeat { @@ -44,6 +48,80 @@ class RecyclingVector { int32_t first_available_index_; }; +/// This is a templated class for online feature extraction; +/// it's templated on a class like MfccComputer or PlpComputer +/// that does the basic feature extraction. +template <class C> +class OnlineGenericBaseFeature : public OnlineFeatureInterface { + public: + // Constructor from options class + explicit OnlineGenericBaseFeature(const typename C::Options &opts); + + int32_t Dim() const override { return computer_.Dim(); } + + float FrameShiftInSeconds() const override { + return computer_.GetFrameOptions().frame_shift_ms / 1000.0f; + } + + int32_t NumFramesReady() const override { return features_.Size(); } + + // Note: IsLastFrame() will only ever return true if you have called + // InputFinished() (and this frame is the last frame). 
+ bool IsLastFrame(int32_t frame) const override { + return input_finished_ && frame == NumFramesReady() - 1; + } + + torch::Tensor GetFrame(int32_t frame) override { return features_.At(frame); } + + // This would be called from the application, when you get + // more wave data. Note: the sampling_rate is only provided so + // the code can assert that it matches the sampling rate + // expected in the options. + void AcceptWaveform(float sampling_rate, + const torch::Tensor &waveform) override; + + // InputFinished() tells the class you won't be providing any + // more waveform. This will help flush out the last frame or two + // of features, in the case where snip-edges == false; it also + // affects the return value of IsLastFrame(). + void InputFinished() override; + + private: + // This function computes any additional feature frames that it is possible to + // compute from 'waveform_remainder_', which at this point may contain more + // than just a remainder-sized quantity (because AcceptWaveform() appends to + // waveform_remainder_ before calling this function). It adds these feature + // frames to features_, and shifts off any now-unneeded samples of input from + // waveform_remainder_ while incrementing waveform_offset_ by the same amount. + void ComputeFeatures(); + + C computer_; // class that does the MFCC or PLP or filterbank computation + + FeatureWindowFunction window_function_; + + // features_ is the Mfcc or Plp or Fbank features that we have already + // computed. + + RecyclingVector features_; + + // True if the user has called "InputFinished()" + bool input_finished_; + + // waveform_offset_ is the number of samples of waveform that we have + // already discarded, i.e. that were prior to 'waveform_remainder_'. 
+ int64_t waveform_offset_; + + // waveform_remainder_ is a short piece of waveform that we may need to keep + // after extracting all the whole frames we can (whatever length of feature + // will be required for the next phase of computation). + // It is a 1-D tensor + torch::Tensor waveform_remainder_; +}; + +using OnlineMfcc = OnlineGenericBaseFeature<MfccComputer>; +using OnlinePlp = OnlineGenericBaseFeature<PlpComputer>; +using OnlineFbank = OnlineGenericBaseFeature<FbankComputer>; + } // namespace kaldifeat #endif // KALDIFEAT_CSRC_ONLINE_FEATURE_H_ diff --git a/kaldifeat/python/csrc/feature-window.cc b/kaldifeat/python/csrc/feature-window.cc index 0fd2ea8..5abaf36 100644 --- a/kaldifeat/python/csrc/feature-window.cc +++ b/kaldifeat/python/csrc/feature-window.cc @@ -25,6 +25,7 @@ static void PybindFrameExtractionOptions(py::module &m) { .def_readwrite("round_to_power_of_two", &PyClass::round_to_power_of_two) .def_readwrite("blackman_coeff", &PyClass::blackman_coeff) .def_readwrite("snip_edges", &PyClass::snip_edges) + .def_readwrite("max_feature_vectors", &PyClass::max_feature_vectors) .def("as_dict", [](const PyClass &self) -> py::dict { return AsDict(self); }) .def_static("from_dict", @@ -35,8 +36,6 @@ .def_readwrite("allow_downsample", &PyClass::allow_downsample) .def_readwrite("allow_upsample", &PyClass::allow_upsample) - .def_readwrite("max_feature_vectors", - &PyClass::max_feature_vectors) #endif .def("__str__", [](const PyClass &self) -> std::string { return self.ToString(); }) diff --git a/kaldifeat/python/csrc/utils.cc b/kaldifeat/python/csrc/utils.cc index 76f47aa..0259afc 100644 --- a/kaldifeat/python/csrc/utils.cc +++ b/kaldifeat/python/csrc/utils.cc @@ -30,6 +30,7 @@ FrameExtractionOptions FrameExtractionOptionsFromDict(py::dict dict) { FROM_DICT(bool_, round_to_power_of_two); FROM_DICT(float_, blackman_coeff); FROM_DICT(bool_, snip_edges); + FROM_DICT(int_, max_feature_vectors); return opts; } @@ -47,6 +48,7 @@ py::dict 
AsDict(const FrameExtractionOptions &opts) { AS_DICT(round_to_power_of_two); AS_DICT(blackman_coeff); AS_DICT(snip_edges); + AS_DICT(max_feature_vectors); return dict; }