From 2399cc8993e8967135306237c960bea2da8b7397 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Sat, 2 Apr 2022 12:07:59 +0800
Subject: [PATCH] Start to add streaming feature extractors.

---
 kaldifeat/csrc/online-feature-itf.h | 85 +++++++++++++++++++++++++++++
 1 file changed, 85 insertions(+)
 create mode 100644 kaldifeat/csrc/online-feature-itf.h

diff --git a/kaldifeat/csrc/online-feature-itf.h b/kaldifeat/csrc/online-feature-itf.h
new file mode 100644
index 0000000..240612a
--- /dev/null
+++ b/kaldifeat/csrc/online-feature-itf.h
@@ -0,0 +1,85 @@
+// kaldifeat/csrc/online-feature-itf.h
+//
+// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
+
+// This file is copied/modified from kaldi/src/itf/online-feature-itf.h
+
+#ifndef KALDIFEAT_CSRC_ONLINE_FEATURE_ITF_H_
+#define KALDIFEAT_CSRC_ONLINE_FEATURE_ITF_H_
+
+#include <vector>
+
+#include "torch/script.h"
+
+namespace kaldifeat {
+
+/// Abstract interface for streaming (online) feature extractors.
+class OnlineFeatureInterface {
+ public:
+  virtual ~OnlineFeatureInterface() = default;
+
+  virtual int32_t Dim() const = 0;  /// returns the feature dimension.
+
+  /// Returns the total number of frames, since the start of the utterance, that
+  /// are now available. In an online-decoding context, this will likely
+  /// increase with time as more data becomes available.
+  virtual int32_t NumFramesReady() const = 0;
+
+  /// Returns true if this is the last frame. Frame indices are zero-based, so
+  /// the first frame is zero. IsLastFrame(-1) will return false, unless the
+  /// file is empty (which is a case that I'm not sure all the code will handle,
+  /// so be careful). This function may return false for some frame if we
+  /// haven't yet decided to terminate decoding, but later true if we decide to
+  /// terminate decoding. This function exists mainly to correctly handle end
+  /// effects in feature extraction, and is not a mechanism to determine how
+  /// many frames are in the decodable object (as it used to be, and for
+  /// backward compatibility, still is, in the Decodable interface).
+  virtual bool IsLastFrame(int32_t frame) const = 0;
+
+  /// Gets the feature vector for this frame. Before calling this for a given
+  /// frame, it is assumed that you called NumFramesReady() and it returned a
+  /// number greater than "frame". Otherwise this call will likely crash with
+  /// an assert failure. This function is not declared const, in case there is
+  /// some kind of caching going on, but most of the time it shouldn't modify
+  /// the class.
+  ///
+  /// The returned tensor has shape (1, Dim()).
+  virtual torch::Tensor GetFrame(int32_t frame) = 0;
+
+  /// This is like GetFrame() but for a collection of frames. There is a
+  /// default implementation that just gets the frames one by one, but it
+  /// may be overridden for efficiency by child classes (since sometimes
+  /// it's more efficient to do things in a batch).
+  ///
+  /// The returned tensor has shape (frames.size(), Dim()).
+  virtual torch::Tensor GetFrames(const std::vector<int32_t> &frames) {
+    std::vector<torch::Tensor> features;
+    features.reserve(frames.size());
+
+    for (auto i : frames) {
+      torch::Tensor f = GetFrame(i);
+      features.push_back(std::move(f));
+    }
+
+    return torch::cat(features, /*dim*/ 0);
+  }
+
+  /// Returns frame shift in seconds; helps estimate duration from frame counts.
+  virtual float FrameShiftInSeconds() const = 0;
+
+  /// This would be called from the application, when you get more wave data.
+  /// Note: the sampling_rate is typically only provided so the code can assert
+  /// that it matches the sampling rate expected in the options.
+  virtual void AcceptWaveform(float sampling_rate,
+                              const torch::Tensor &waveform) = 0;
+
+  /// InputFinished() tells the class you won't be providing any
+  /// more waveform. This will help flush out the last few frames
+  /// of delta or LDA features (it will typically affect the return value
+  /// of IsLastFrame).
+  virtual void InputFinished() = 0;
+};
+
+}  // namespace kaldifeat
+
+#endif  // KALDIFEAT_CSRC_ONLINE_FEATURE_ITF_H_