mirror of
https://github.com/csukuangfj/kaldifeat.git
synced 2025-08-09 18:12:17 +00:00
Start to add streaming feature extractors.
This commit is contained in:
parent
4aab351344
commit
2399cc8993
85
kaldifeat/csrc/online-feature-itf.h
Normal file
85
kaldifeat/csrc/online-feature-itf.h
Normal file
@ -0,0 +1,85 @@
|
||||
// kaldifeat/csrc/online-feature-itf.h
|
||||
//
|
||||
// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
|
||||
// This file is copied/modified from kaldi/src/itf/online-feature-itf.h
|
||||
|
||||
#ifndef KALDIFEAT_CSRC_ONLINE_FEATURE_ITF_H_
|
||||
#define KALDIFEAT_CSRC_ONLINE_FEATURE_ITF_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "torch/script.h"
|
||||
|
||||
namespace kaldifeat {
|
||||
|
||||
class OnlineFeatureInterface {
|
||||
public:
|
||||
virtual ~OnlineFeatureInterface() = default;
|
||||
|
||||
virtual int32_t Dim() const = 0; /// returns the feature dimension.
|
||||
|
||||
/// Returns the total number of frames, since the start of the utterance, that
|
||||
/// are now available. In an online-decoding context, this will likely
|
||||
/// increase with time as more data becomes available.
|
||||
virtual int32_t NumFramesReady() const = 0;
|
||||
|
||||
/// Returns true if this is the last frame. Frame indices are zero-based, so
|
||||
/// the first frame is zero. IsLastFrame(-1) will return false, unless the
|
||||
/// file is empty (which is a case that I'm not sure all the code will handle,
|
||||
/// so be careful). This function may return false for some frame if we
|
||||
/// haven't yet decided to terminate decoding, but later true if we decide to
|
||||
/// terminate decoding. This function exists mainly to correctly handle end
|
||||
/// effects in feature extraction, and is not a mechanism to determine how
|
||||
/// many frames are in the decodable object (as it used to be, and for
|
||||
/// backward compatibility, still is, in the Decodable interface).
|
||||
virtual bool IsLastFrame(int32_t frame) const = 0;
|
||||
|
||||
/// Gets the feature vector for this frame. Before calling this for a given
|
||||
/// frame, it is assumed that you called NumFramesReady() and it returned a
|
||||
/// number greater than "frame". Otherwise this call will likely crash with
|
||||
/// an assert failure. This function is not declared const, in case there is
|
||||
/// some kind of caching going on, but most of the time it shouldn't modify
|
||||
/// the class.
|
||||
///
|
||||
/// The returned tensor has shape (1, Dim()).
|
||||
virtual torch::Tensor GetFrame(int32_t frame) = 0;
|
||||
|
||||
/// This is like GetFrame() but for a collection of frames. There is a
|
||||
/// default implementation that just gets the frames one by one, but it
|
||||
/// may be overridden for efficiency by child classes (since sometimes
|
||||
/// it's more efficient to do things in a batch).
|
||||
///
|
||||
/// The returned tensor has shape (frames.size(), Dim()).
|
||||
virtual torch::Tensor GetFrames(const std::vector<int32_t> &frames) {
|
||||
std::vector<torch::Tensor> features;
|
||||
features.reserve(frames.size());
|
||||
|
||||
for (auto i : frames) {
|
||||
torch::Tensor f = GetFrame(i);
|
||||
features.push_back(std::move(f));
|
||||
}
|
||||
|
||||
return torch::cat(features, /*dim*/ 0);
|
||||
}
|
||||
|
||||
// Returns frame shift in seconds. Helps to estimate duration from frame
|
||||
// counts.
|
||||
virtual float FrameShiftInSeconds() const = 0;
|
||||
|
||||
/// This would be called from the application, when you get more wave data.
|
||||
/// Note: the sampling_rate is typically only provided so the code can assert
|
||||
/// that it matches the sampling rate expected in the options.
|
||||
virtual void AcceptWaveform(float sampling_rate,
|
||||
const torch::Tensor &waveform) = 0;
|
||||
|
||||
/// InputFinished() tells the class you won't be providing any
|
||||
/// more waveform. This will help flush out the last few frames
|
||||
/// of delta or LDA features (it will typically affect the return value
|
||||
/// of IsLastFrame.
|
||||
virtual void InputFinished() = 0;
|
||||
};
|
||||
|
||||
} // namespace kaldifeat
|
||||
|
||||
#endif // KALDIFEAT_CSRC_ONLINE_FEATURE_ITF_H_
|
Loading…
x
Reference in New Issue
Block a user