Mirror of https://github.com/csukuangfj/kaldifeat.git (synced 2025-08-10 18:42:17 +00:00)
Add OnlineGenericBaseFeature.

commit 039e27dd32
parent 34ba30272d
@@ -62,6 +62,10 @@ class OfflineFeatureTpl {
   int32_t Dim() const { return computer_.Dim(); }
   const Options &GetOptions() const { return computer_.GetOptions(); }
 
+  const FrameExtractionOptions &GetFrameOptions() const {
+    return GetOptions().frame_opts;
+  }
+
   // Copy constructor.
   OfflineFeatureTpl(const OfflineFeatureTpl<F> &) = delete;
   OfflineFeatureTpl<F> &operator=(const OfflineFeatureTpl<F> &) = delete;
@@ -161,19 +161,20 @@ torch::Tensor Dither(const torch::Tensor &wave, float dither_value) {
 #if 1
   return wave + rand_gauss * dither_value;
 #else
-  // use in-place version of wave and change its to pointer type
+  // use in-place version of wave and change it to pointer type
   wave_->add_(rand_gauss, dither_value);
 #endif
 }
 
 torch::Tensor Preemphasize(float preemph_coeff, const torch::Tensor &wave) {
-  using namespace torch::indexing;  // It imports: Slice, None // NOLINT
   if (preemph_coeff == 0.0f) return wave;
 
   KALDIFEAT_ASSERT(preemph_coeff >= 0.0f && preemph_coeff <= 1.0f);
 
   torch::Tensor ans = torch::empty_like(wave);
 
+  using torch::indexing::None;
+  using torch::indexing::Slice;
   // right = wave[:, 1:]
   torch::Tensor right = wave.index({"...", Slice(1, None, None)});
 
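For orientation, Preemphasize() above computes y[n] = x[n] - preemph_coeff * x[n-1] per row with tensor slicing rather than a sample loop (Kaldi's convention treats the first sample as y[0] = x[0] - preemph_coeff * x[0]). Below is a minimal, hedged sketch of the same idea; PreemphasizeSketch is a made-up name and this is not the library's implementation:

    #include <torch/torch.h>

    // A hedged sketch of batched preemphasis; it mirrors the slicing approach
    // used in the hunk above but is not the library's code.
    torch::Tensor PreemphasizeSketch(float coeff, const torch::Tensor &wave) {
      using torch::indexing::None;
      using torch::indexing::Slice;

      torch::Tensor ans = torch::empty_like(wave);
      // ans[..., 1:] = wave[..., 1:] - coeff * wave[..., :-1]
      ans.index_put_({"...", Slice(1, None)},
                     wave.index({"...", Slice(1, None)}) -
                         coeff * wave.index({"...", Slice(None, -1)}));
      // ans[..., 0] = wave[..., 0] - coeff * wave[..., 0]  (Kaldi's convention)
      ans.index_put_({"...", 0}, wave.index({"...", 0}) * (1.0f - coeff));
      return ans;
    }

Doing the whole batch with two slice assignments keeps the work on whatever device the input tensor lives on, which is the point of the tensor-based port.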
@@ -188,4 +189,58 @@ torch::Tensor Preemphasize(float preemph_coeff, const torch::Tensor &wave) {
   return ans;
 }
 
+torch::Tensor ExtractWindow(int64_t sample_offset, const torch::Tensor &wave,
+                            int32_t f, const FrameExtractionOptions &opts) {
+  KALDIFEAT_ASSERT(sample_offset >= 0 && wave.numel() != 0);
+
+  int32_t frame_length = opts.WindowSize();
+  int64_t num_samples = sample_offset + wave.numel();
+  int64_t start_sample = FirstSampleOfFrame(f, opts);
+  int64_t end_sample = start_sample + frame_length;
+
+  if (opts.snip_edges) {
+    KALDIFEAT_ASSERT(start_sample >= sample_offset &&
+                     end_sample <= num_samples);
+  } else {
+    KALDIFEAT_ASSERT(sample_offset == 0 || start_sample >= sample_offset);
+  }
+
+  // wave_start and wave_end are start and end indexes into 'wave', for the
+  // piece of wave that we're trying to extract.
+  int32_t wave_start = static_cast<int32_t>(start_sample - sample_offset);
+  int32_t wave_end = wave_start + frame_length;
+
+  if (wave_start >= 0 && wave_end <= wave.numel()) {
+    // the normal case -- no edge effects to consider.
+    // return wave[wave_start:wave_end]
+    return wave.index({torch::indexing::Slice(wave_start, wave_end)});
+  } else {
+    torch::Tensor window = torch::empty({frame_length}, torch::kFloat);
+    auto p_window = window.accessor<float, 1>();
+    auto p_wave = wave.accessor<float, 1>();
+
+    // Deal with any end effects by reflection, if needed. This code will only
+    // be reached for about two frames per utterance, so we don't concern
+    // ourselves excessively with efficiency.
+    int32_t wave_dim = wave.numel();
+    for (int32_t s = 0; s != frame_length; ++s) {
+      int32_t s_in_wave = s + wave_start;
+      while (s_in_wave < 0 || s_in_wave >= wave_dim) {
+        // reflect around the beginning or end of the wave.
+        // e.g. -1 -> 0, -2 -> 1.
+        // dim -> dim - 1, dim + 1 -> dim - 2.
+        // the code supports repeated reflections, although this
+        // would only be needed in pathological cases.
+        if (s_in_wave < 0) {
+          s_in_wave = -s_in_wave - 1;
+        } else {
+          s_in_wave = 2 * wave_dim - 1 - s_in_wave;
+        }
+      }
+
+      p_window[s] = p_wave[s_in_wave];
+    }
+  }
+}
+
 }  // namespace kaldifeat
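The fallback branch above handles frames that poke past either end of 'wave' by reflecting the index back into range: -1 maps to 0, -2 to 1, wave_dim to wave_dim - 1, wave_dim + 1 to wave_dim - 2, and so on. A hedged, self-contained sketch of just that mapping (ReflectIndex is a hypothetical helper, not part of the commit):

    #include <cassert>
    #include <cstdint>

    // Sketch: fold an out-of-range sample index back into [0, dim), repeating
    // the reflection for pathological (very short wave) cases, as the loop
    // above does.
    int32_t ReflectIndex(int32_t s_in_wave, int32_t dim) {
      assert(dim > 0);
      while (s_in_wave < 0 || s_in_wave >= dim) {
        if (s_in_wave < 0) {
          s_in_wave = -s_in_wave - 1;           // -1 -> 0, -2 -> 1
        } else {
          s_in_wave = 2 * dim - 1 - s_in_wave;  // dim -> dim - 1, dim + 1 -> dim - 2
        }
      }
      return s_in_wave;
    }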
@@ -44,7 +44,11 @@ struct FrameExtractionOptions {
   bool snip_edges = true;
   // bool allow_downsample = false;
   // bool allow_upsample = false;
-  // int32_t max_feature_vectors = -1;
+
+  // Used for streaming feature extraction. It indicates the number
+  // of feature frames to keep in the recycling vector. -1 means to
+  // keep all feature frames.
+  int32_t max_feature_vectors = -1;
 
   int32_t WindowShift() const {
     return static_cast<int32_t>(samp_freq * 0.001f * frame_shift_ms);
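The new max_feature_vectors field bounds how many already-computed frames the streaming extractor keeps around; -1 keeps everything. A hedged configuration sketch, assuming FbankOptions and its frame_opts member as they are used elsewhere in this commit:

    #include "kaldifeat/csrc/feature-fbank.h"

    // Sketch: keep roughly 10 s of 10 ms frames instead of the unbounded
    // default. The constructor added in this commit asserts that the value,
    // cast to uint32_t, is > 200 (which the default -1 also satisfies).
    kaldifeat::FbankOptions MakeStreamingFbankOptions() {
      kaldifeat::FbankOptions opts;
      opts.frame_opts.samp_freq = 16000;
      opts.frame_opts.max_feature_vectors = 1000;  // ~10 s at a 10 ms frame shift
      return opts;
    }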
@@ -71,7 +75,7 @@ struct FrameExtractionOptions {
   KALDIFEAT_PRINT(snip_edges);
   // KALDIFEAT_PRINT(allow_downsample);
   // KALDIFEAT_PRINT(allow_upsample);
-  // KALDIFEAT_PRINT(max_feature_vectors);
+  KALDIFEAT_PRINT(max_feature_vectors);
 #undef KALDIFEAT_PRINT
   return os.str();
 }
@@ -100,11 +104,11 @@ class FeatureWindowFunction {
 
    @param [in] flush   True if we are asserting that this number of samples
              is 'all there is', false if we expecting more data to possibly come in. This
-             only makes a difference to the answer if opts.snips_edges
-             == false. For offline feature extraction you always want flush ==
-             true. In an online-decoding context, once you know (or decide)
-             that no more data is coming in, you'd call it with flush == true at the end
-             to flush out any remaining data.
+             only makes a difference to the answer
+             if opts.snips_edges== false. For offline feature extraction you always want
+             flush == true. In an online-decoding context, once you know (or decide) that
+             no more data is coming in, you'd call it with flush == true at the end to
+             flush out any remaining data.
  */
 int32_t NumFrames(int64_t num_samples, const FrameExtractionOptions &opts,
                   bool flush = true);
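A worked example of what flush changes, assuming Kaldi's usual frame-counting convention (not spelled out in this hunk): at 16 kHz with a 25 ms window (400 samples) and a 10 ms shift (160 samples), 1000 samples give

    snip_edges == true:                   1 + floor((1000 - 400) / 160) = 4 frames
    snip_edges == false, flush == true:   floor((1000 + 160/2) / 160)   = 6 frames
    snip_edges == false, flush == false:  5 frames (the 6th window would end at
                                          sample 1080 > 1000, so it is held back
                                          until more data, or a final flush, arrives)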
@@ -133,6 +137,29 @@ torch::Tensor Dither(const torch::Tensor &wave, float dither_value);
 
 torch::Tensor Preemphasize(float preemph_coeff, const torch::Tensor &wave);
 
+/*
+  ExtractWindow() extracts "frame_length" samples from the given waveform.
+  Note: This function only extracts "frame_length" samples
+  from the input waveform, without any further processing.
+
+  @param [in] sample_offset  If 'wave' is not the entire waveform, but
+                             part of it to the left has been discarded, then the
+                             number of samples prior to 'wave' that we have
+                             already discarded. Set this to zero if you are
+                             processing the entire waveform in one piece, or
+                             if you get 'no matching function' compilation
+                             errors when updating the code.
+  @param [in] wave  The waveform
+  @param [in] f     The frame index to be extracted, with
+                    0 <= f < NumFrames(sample_offset + wave.numel(), opts, true)
+  @param [in] opts  The options class to be used
+  @return Return a tensor containing "frame_length" samples extracted from
+          `wave`, without any further processing. Its shape is
+          (1, frame_length).
+ */
+torch::Tensor ExtractWindow(int64_t sample_offset, const torch::Tensor &wave,
+                            int32_t f, const FrameExtractionOptions &opts);
+
 }  // namespace kaldifeat
 
 #endif  // KALDIFEAT_CSRC_FEATURE_WINDOW_H_
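A hedged calling sketch for the declaration above (the helper name is made up; NumFrames() and FrameExtractionOptions come from this same header):

    #include <cstdint>

    #include <torch/torch.h>

    #include "kaldifeat/csrc/feature-window.h"

    // Sketch: iterate over every frame of a waveform that is entirely in
    // memory, so sample_offset is 0 and flush is true.
    void ForEachFrameSketch(const torch::Tensor &wave,
                            const kaldifeat::FrameExtractionOptions &opts) {
      int32_t num_frames = kaldifeat::NumFrames(wave.numel(), opts, /*flush*/ true);
      for (int32_t f = 0; f != num_frames; ++f) {
        torch::Tensor frame =
            kaldifeat::ExtractWindow(/*sample_offset*/ 0, wave, f, opts);
        // ... apply the window function and hand 'frame' to a feature computer ...
        (void)frame;
      }
    }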
@@ -19,6 +19,10 @@ class OnlineFeatureInterface {
   virtual ~OnlineFeatureInterface() = default;
 
   virtual int32_t Dim() const = 0;  /// returns the feature dimension.
+  //
+  // Returns frame shift in seconds. Helps to estimate duration from frame
+  // counts.
+  virtual float FrameShiftInSeconds() const = 0;
 
   /// Returns the total number of frames, since the start of the utterance, that
   /// are now available. In an online-decoding context, this will likely
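FrameShiftInSeconds() lets callers turn frame counts into time without knowing the concrete options. A hedged one-liner against the interface (the function name is made up):

    #include "kaldifeat/csrc/online-feature-itf.h"

    // Sketch: approximate duration, in seconds, of the features available so far.
    float AvailableDurationSketch(const kaldifeat::OnlineFeatureInterface &feats) {
      return feats.NumFramesReady() * feats.FrameShiftInSeconds();
    }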
@@ -64,10 +68,6 @@ class OnlineFeatureInterface {
     return torch::cat(features, /*dim*/ 0);
   }
 
-  // Returns frame shift in seconds. Helps to estimate duration from frame
-  // counts.
-  virtual float FrameShiftInSeconds() const = 0;
-
   /// This would be called from the application, when you get more wave data.
   /// Note: the sampling_rate is typically only provided so the code can assert
   /// that it matches the sampling rate expected in the options.
@@ -6,6 +6,7 @@
 
 #include "kaldifeat/csrc/online-feature.h"
 
+#include "kaldifeat/csrc/feature-window.h"
 #include "kaldifeat/csrc/log.h"
 
 namespace kaldifeat {
@@ -40,4 +41,97 @@ int32_t RecyclingVector::Size() const {
   return first_available_index_ + static_cast<int32_t>(items_.size());
 }
 
+template <class C>
+OnlineGenericBaseFeature<C>::OnlineGenericBaseFeature(
+    const typename C::Options &opts)
+    : computer_(opts),
+      window_function_(opts.frame_opts, opts.device),
+      features_(opts.frame_opts.max_feature_vectors),
+      input_finished_(false),
+      waveform_offset_(0) {
+  // Casting to uint32_t, an unsigned type, means that -1 would be treated
+  // as `very large`.
+  KALDIFEAT_ASSERT(static_cast<uint32_t>(opts.frame_opts.max_feature_vectors) >
+                   200);
+}
+
+template <class C>
+void OnlineGenericBaseFeature<C>::AcceptWaveform(
+    float sampling_rate, const torch::Tensor &original_waveform) {
+  if (original_waveform.numel() == 0) return;  // Nothing to do.
+
+  KALDIFEAT_ASSERT(original_waveform.dim() == 1);
+
+  if (input_finished_)
+    KALDIFEAT_ERR << "AcceptWaveform called after InputFinished() was called.";
+
+  if (waveform_remainder_.numel() == 0) {
+    waveform_remainder_ = original_waveform;
+  } else {
+    waveform_remainder_ =
+        torch::cat({waveform_remainder_, original_waveform}, /*dim*/ 0);
+  }
+
+  ComputeFeatures();
+}
+
+template <class C>
+void OnlineGenericBaseFeature<C>::InputFinished() {
+  input_finished_ = true;
+  ComputeFeatures();
+}
+
+template <class C>
+void OnlineGenericBaseFeature<C>::ComputeFeatures() {
+  const FrameExtractionOptions &frame_opts = computer_.GetFrameOptions();
+
+  int64_t num_samples_total = waveform_offset_ + waveform_remainder_.numel();
+  int32_t num_frames_old = features_.Size();
+  int32_t num_frames_new =
+      NumFrames(num_samples_total, frame_opts, input_finished_);
+
+  KALDIFEAT_ASSERT(num_frames_new >= num_frames_old);
+
+  // note: this online feature-extraction code does not support VTLN.
+  float vtln_warp = 1.0;
+
+  for (int32_t frame = num_frames_old; frame < num_frames_new; ++frame) {
+    torch::Tensor window =
+        ExtractWindow(waveform_offset_, waveform_remainder_, frame, frame_opts);
+
+    // TODO(fangjun): We can compute all feature frames at once
+    torch::Tensor this_feature =
+        computer_.ComputeFeatures(window.unsqueeze(0), vtln_warp);
+    features_.PushBack(this_feature);
+  }
+
+  // OK, we will now discard any portion of the signal that will not be
+  // necessary to compute frames in the future.
+  int64_t first_sample_of_next_frame =
+      FirstSampleOfFrame(num_frames_new, frame_opts);
+  int32_t samples_to_discard = first_sample_of_next_frame - waveform_offset_;
+  if (samples_to_discard > 0) {
+    // discard the leftmost part of the waveform that we no longer need.
+    int32_t new_num_samples = waveform_remainder_.numel() - samples_to_discard;
+    if (new_num_samples <= 0) {
+      // odd, but we'll try to handle it.
+      waveform_offset_ += waveform_remainder_.numel();
+      waveform_remainder_.resize_({0});
+    } else {
+      using torch::indexing::None;
+      using torch::indexing::Slice;
+
+      waveform_remainder_ =
+          waveform_remainder_.index({Slice(samples_to_discard, None)});
+
+      waveform_offset_ += samples_to_discard;
+    }
+  }
+}
+
+// instantiate the templates defined here for MFCC, PLP and filterbank classes.
+template class OnlineGenericBaseFeature<Mfcc>;
+template class OnlineGenericBaseFeature<Plp>;
+template class OnlineGenericBaseFeature<Fbank>;
+
 }  // namespace kaldifeat
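To make the bookkeeping at the end of ComputeFeatures() concrete, a worked example under the default snip_edges == true (so FirstSampleOfFrame(f) is simply f * WindowShift(); the sample counts below are made up):

    16 kHz, 25 ms window (400 samples), 10 ms shift (160 samples),
    waveform_offset_ == 0, and 1000 samples accepted so far:
      num_frames_new = 1 + floor((1000 - 400) / 160) = 4        (frames 0..3)
      first_sample_of_next_frame = FirstSampleOfFrame(4) = 640
      samples_to_discard = 640 - 0 = 640
    so waveform_offset_ becomes 640 and only the last 360 samples stay in
    waveform_remainder_ for future frames.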
@@ -9,6 +9,10 @@
 
 #include <deque>
 
+#include "kaldifeat/csrc/feature-fbank.h"
+#include "kaldifeat/csrc/feature-mfcc.h"
+#include "kaldifeat/csrc/feature-plp.h"
+#include "kaldifeat/csrc/feature-window.h"
 #include "kaldifeat/csrc/online-feature-itf.h"
 
 namespace kaldifeat {
@@ -44,6 +48,80 @@ class RecyclingVector {
   int32_t first_available_index_;
 };
 
+/// This is a templated class for online feature extraction;
+/// it's templated on a class like MfccComputer or PlpComputer
+/// that does the basic feature extraction.
+template <class C>
+class OnlineGenericBaseFeature : public OnlineFeatureInterface {
+ public:
+  // Constructor from options class
+  explicit OnlineGenericBaseFeature(const typename C::Options &opts);
+
+  int32_t Dim() const override { return computer_.Dim(); }
+
+  float FrameShiftInSeconds() const override {
+    return computer_.GetFrameOptions().frame_shift_ms / 1000.0f;
+  }
+
+  int32_t NumFramesReady() const override { return features_.Size(); }
+
+  // Note: IsLastFrame() will only ever return true if you have called
+  // InputFinished() (and this frame is the last frame).
+  bool IsLastFrame(int32_t frame) const override {
+    return input_finished_ && frame == NumFramesReady() - 1;
+  }
+
+  torch::Tensor GetFrame(int32_t frame) override { return features_.At(frame); }
+
+  // This would be called from the application, when you get
+  // more wave data. Note: the sampling_rate is only provided so
+  // the code can assert that it matches the sampling rate
+  // expected in the options.
+  void AcceptWaveform(float sampling_rate,
+                      const torch::Tensor &waveform) override;
+
+  // InputFinished() tells the class you won't be providing any
+  // more waveform. This will help flush out the last frame or two
+  // of features, in the case where snip-edges == false; it also
+  // affects the return value of IsLastFrame().
+  void InputFinished() override;
+
+ private:
+  // This function computes any additional feature frames that it is possible to
+  // compute from 'waveform_remainder_', which at this point may contain more
+  // than just a remainder-sized quantity (because AcceptWaveform() appends to
+  // waveform_remainder_ before calling this function). It adds these feature
+  // frames to features_, and shifts off any now-unneeded samples of input from
+  // waveform_remainder_ while incrementing waveform_offset_ by the same amount.
+  void ComputeFeatures();
+
+  C computer_;  // class that does the MFCC or PLP or filterbank computation
+
+  FeatureWindowFunction window_function_;
+
+  // features_ is the Mfcc or Plp or Fbank features that we have already
+  // computed.
+
+  RecyclingVector features_;
+
+  // True if the user has called "InputFinished()"
+  bool input_finished_;
+
+  // waveform_offset_ is the number of samples of waveform that we have
+  // already discarded, i.e. that were prior to 'waveform_remainder_'.
+  int64_t waveform_offset_;
+
+  // waveform_remainder_ is a short piece of waveform that we may need to keep
+  // after extracting all the whole frames we can (whatever length of feature
+  // will be required for the next phase of computation).
+  // It is a 1-D tensor
+  torch::Tensor waveform_remainder_;
+};
+
+using OnlineMfcc = OnlineGenericBaseFeature<Mfcc>;
+using OnlinePlp = OnlineGenericBaseFeature<Plp>;
+using OnlineFbank = OnlineGenericBaseFeature<Fbank>;
+
 }  // namespace kaldifeat
 
 #endif  // KALDIFEAT_CSRC_ONLINE_FEATURE_H_
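Putting the pieces together, a hedged end-to-end sketch of the intended streaming use (the chunk size and sample rate are made up; OnlineFbank is the alias declared just above):

    #include <algorithm>
    #include <cstdint>

    #include <torch/torch.h>

    #include "kaldifeat/csrc/feature-fbank.h"
    #include "kaldifeat/csrc/online-feature.h"

    // Sketch: feed a 1-D float waveform to OnlineFbank in 100 ms chunks and
    // read out feature frames as they become available.
    void StreamingFbankSketch(const torch::Tensor &wave /* 1-D, 16 kHz */) {
      kaldifeat::FbankOptions opts;
      opts.frame_opts.samp_freq = 16000;

      kaldifeat::OnlineFbank fbank(opts);

      const int64_t chunk = 1600;  // 100 ms at 16 kHz
      for (int64_t start = 0; start < wave.numel(); start += chunk) {
        int64_t end = std::min(start + chunk, wave.numel());
        fbank.AcceptWaveform(16000,
                             wave.index({torch::indexing::Slice(start, end)}));
        // fbank.NumFramesReady() grows incrementally as chunks arrive.
      }

      fbank.InputFinished();  // flushes the last frame or two when snip_edges == false

      for (int32_t f = 0; f != fbank.NumFramesReady(); ++f) {
        torch::Tensor frame = fbank.GetFrame(f);  // one feature frame
        (void)frame;
      }
    }

With the default max_feature_vectors of -1, every computed frame stays in the RecyclingVector, so GetFrame(f) is valid for all f here; a bounded value would recycle the oldest frames.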
@@ -25,6 +25,7 @@ static void PybindFrameExtractionOptions(py::module &m) {
       .def_readwrite("round_to_power_of_two", &PyClass::round_to_power_of_two)
       .def_readwrite("blackman_coeff", &PyClass::blackman_coeff)
       .def_readwrite("snip_edges", &PyClass::snip_edges)
+      .def_readwrite("max_feature_vectors", &PyClass::max_feature_vectors)
       .def("as_dict",
            [](const PyClass &self) -> py::dict { return AsDict(self); })
       .def_static("from_dict",
@@ -35,8 +36,6 @@ static void PybindFrameExtractionOptions(py::module &m) {
       .def_readwrite("allow_downsample",
                      &PyClass::allow_downsample)
       .def_readwrite("allow_upsample", &PyClass::allow_upsample)
-      .def_readwrite("max_feature_vectors",
-                     &PyClass::max_feature_vectors)
 #endif
       .def("__str__",
            [](const PyClass &self) -> std::string { return self.ToString(); })
@@ -30,6 +30,7 @@ FrameExtractionOptions FrameExtractionOptionsFromDict(py::dict dict) {
   FROM_DICT(bool_, round_to_power_of_two);
   FROM_DICT(float_, blackman_coeff);
   FROM_DICT(bool_, snip_edges);
+  FROM_DICT(int_, max_feature_vectors);
 
   return opts;
 }
@@ -47,6 +48,7 @@ py::dict AsDict(const FrameExtractionOptions &opts) {
   AS_DICT(round_to_power_of_two);
   AS_DICT(blackman_coeff);
   AS_DICT(snip_edges);
+  AS_DICT(max_feature_vectors);
 
   return dict;
 }