Add OnlineGenericBaseFeature.

This commit is contained in:
Fangjun Kuang 2022-04-02 17:16:10 +08:00
parent 34ba30272d
commit 039e27dd32
8 changed files with 274 additions and 15 deletions

View File

@ -62,6 +62,10 @@ class OfflineFeatureTpl {
// Feature dimension of the underlying computer (e.g., num mel bins).
int32_t Dim() const { return computer_.Dim(); }
// Full options of the underlying computer.
const Options &GetOptions() const { return computer_.GetOptions(); }
// Only the frame-extraction part of the options.
const FrameExtractionOptions &GetFrameOptions() const {
return GetOptions().frame_opts;
}
// Non-copyable: copy construction and copy assignment are deleted.
OfflineFeatureTpl(const OfflineFeatureTpl<F> &) = delete;
OfflineFeatureTpl<F> &operator=(const OfflineFeatureTpl<F> &) = delete;

View File

@ -161,19 +161,20 @@ torch::Tensor Dither(const torch::Tensor &wave, float dither_value) {
#if 1
return wave + rand_gauss * dither_value;
#else
// use in-place version of wave and change its to pointer type
// use in-place version of wave and change it to pointer type
wave_->add_(rand_gauss, dither_value);
#endif
}
torch::Tensor Preemphasize(float preemph_coeff, const torch::Tensor &wave) {
using namespace torch::indexing; // It imports: Slice, None // NOLINT
if (preemph_coeff == 0.0f) return wave;
KALDIFEAT_ASSERT(preemph_coeff >= 0.0f && preemph_coeff <= 1.0f);
torch::Tensor ans = torch::empty_like(wave);
using torch::indexing::None;
using torch::indexing::Slice;
// right = wave[:, 1:]
torch::Tensor right = wave.index({"...", Slice(1, None, None)});
@ -188,4 +189,58 @@ torch::Tensor Preemphasize(float preemph_coeff, const torch::Tensor &wave) {
return ans;
}
torch::Tensor ExtractWindow(int64_t sample_offset, const torch::Tensor &wave,
                            int32_t f, const FrameExtractionOptions &opts) {
  // Extracts "frame_length" samples for frame index `f` from `wave`,
  // where `sample_offset` samples prior to `wave` have already been
  // discarded. Returns a 1-D tensor of size frame_length.
  KALDIFEAT_ASSERT(sample_offset >= 0 && wave.numel() != 0);

  int32_t frame_length = opts.WindowSize();
  int64_t num_samples = sample_offset + wave.numel();
  int64_t start_sample = FirstSampleOfFrame(f, opts);
  int64_t end_sample = start_sample + frame_length;

  if (opts.snip_edges) {
    KALDIFEAT_ASSERT(start_sample >= sample_offset &&
                     end_sample <= num_samples);
  } else {
    KALDIFEAT_ASSERT(sample_offset == 0 || start_sample >= sample_offset);
  }

  // wave_start and wave_end are start and end indexes into 'wave', for the
  // piece of wave that we're trying to extract.
  int32_t wave_start = static_cast<int32_t>(start_sample - sample_offset);
  int32_t wave_end = wave_start + frame_length;

  if (wave_start >= 0 && wave_end <= wave.numel()) {
    // the normal case -- no edge effects to consider.
    // return wave[wave_start:wave_end]
    return wave.index({torch::indexing::Slice(wave_start, wave_end)});
  } else {
    torch::Tensor window = torch::empty({frame_length}, torch::kFloat);
    auto p_window = window.accessor<float, 1>();
    auto p_wave = wave.accessor<float, 1>();

    // Deal with any end effects by reflection, if needed. This code will only
    // be reached for about two frames per utterance, so we don't concern
    // ourselves excessively with efficiency.
    int32_t wave_dim = wave.numel();
    for (int32_t s = 0; s != frame_length; ++s) {
      int32_t s_in_wave = s + wave_start;
      while (s_in_wave < 0 || s_in_wave >= wave_dim) {
        // reflect around the beginning or end of the wave.
        // e.g. -1 -> 0, -2 -> 1.
        // dim -> dim - 1, dim + 1 -> dim - 2.
        // the code supports repeated reflections, although this
        // would only be needed in pathological cases.
        if (s_in_wave < 0) {
          s_in_wave = -s_in_wave - 1;
        } else {
          s_in_wave = 2 * wave_dim - 1 - s_in_wave;
        }
      }
      p_window[s] = p_wave[s_in_wave];
    }

    // BUG FIX: the original fell off the end of this non-void function
    // (undefined behavior); the reflected window was built but never
    // returned.
    return window;
  }
}
} // namespace kaldifeat

View File

@ -44,7 +44,11 @@ struct FrameExtractionOptions {
bool snip_edges = true;
// bool allow_downsample = false;
// bool allow_upsample = false;
// int32_t max_feature_vectors = -1;
// Used for streaming feature extraction. It indicates the number
// of feature frames to keep in the recycling vector. -1 means to
// keep all feature frames.
int32_t max_feature_vectors = -1;
int32_t WindowShift() const {
return static_cast<int32_t>(samp_freq * 0.001f * frame_shift_ms);
@ -71,7 +75,7 @@ struct FrameExtractionOptions {
KALDIFEAT_PRINT(snip_edges);
// KALDIFEAT_PRINT(allow_downsample);
// KALDIFEAT_PRINT(allow_upsample);
// KALDIFEAT_PRINT(max_feature_vectors);
KALDIFEAT_PRINT(max_feature_vectors);
#undef KALDIFEAT_PRINT
return os.str();
}
@ -100,11 +104,11 @@ class FeatureWindowFunction {
@param [in] flush True if we are asserting that this number of samples
is 'all there is', false if we expecting more data to possibly come in. This
only makes a difference to the answer if opts.snips_edges
== false. For offline feature extraction you always want flush ==
true. In an online-decoding context, once you know (or decide)
that no more data is coming in, you'd call it with flush == true at the end
to flush out any remaining data.
only makes a difference to the answer
if opts.snip_edges == false. For offline feature extraction you always want
flush == true. In an online-decoding context, once you know (or decide) that
no more data is coming in, you'd call it with flush == true at the end to
flush out any remaining data.
*/
int32_t NumFrames(int64_t num_samples, const FrameExtractionOptions &opts,
bool flush = true);
@ -133,6 +137,29 @@ torch::Tensor Dither(const torch::Tensor &wave, float dither_value);
torch::Tensor Preemphasize(float preemph_coeff, const torch::Tensor &wave);
/*
ExtractWindow() extracts "frame_length" samples from the given waveform.
Note: This function only extracts "frame_length" samples
from the input waveform, without any further processing.
@param [in] sample_offset If 'wave' is not the entire waveform, but
part of it to the left has been discarded, then the
number of samples prior to 'wave' that we have
already discarded. Set this to zero if you are
processing the entire waveform in one piece, or
if you get 'no matching function' compilation
errors when updating the code.
@param [in] wave The waveform
@param [in] f The frame index to be extracted, with
0 <= f < NumFrames(sample_offset + wave.numel(), opts, true)
@param [in] opts The options class to be used
@return Return a tensor containing "frame_length" samples extracted from
`wave`, without any further processing. It is a 1-D tensor with
frame_length elements (callers unsqueeze it to (1, frame_length) when
needed).
*/
torch::Tensor ExtractWindow(int64_t sample_offset, const torch::Tensor &wave,
int32_t f, const FrameExtractionOptions &opts);
} // namespace kaldifeat
#endif // KALDIFEAT_CSRC_FEATURE_WINDOW_H_

View File

@ -19,6 +19,10 @@ class OnlineFeatureInterface {
virtual ~OnlineFeatureInterface() = default;
virtual int32_t Dim() const = 0; /// returns the feature dimension.
//
// Returns frame shift in seconds. Helps to estimate duration from frame
// counts.
virtual float FrameShiftInSeconds() const = 0;
/// Returns the total number of frames, since the start of the utterance, that
/// are now available. In an online-decoding context, this will likely
@ -64,10 +68,6 @@ class OnlineFeatureInterface {
return torch::cat(features, /*dim*/ 0);
}
// Returns frame shift in seconds. Helps to estimate duration from frame
// counts.
virtual float FrameShiftInSeconds() const = 0;
/// This would be called from the application, when you get more wave data.
/// Note: the sampling_rate is typically only provided so the code can assert
/// that it matches the sampling rate expected in the options.

View File

@ -6,6 +6,7 @@
#include "kaldifeat/csrc/online-feature.h"
#include "kaldifeat/csrc/feature-window.h"
#include "kaldifeat/csrc/log.h"
namespace kaldifeat {
@ -40,4 +41,97 @@ int32_t RecyclingVector::Size() const {
return first_available_index_ + static_cast<int32_t>(items_.size());
}
// Construct a streaming feature extractor from the computer's options.
//
// @param opts  Options of the underlying computer C (e.g., Mfcc::Options).
//              opts.frame_opts.max_feature_vectors controls how many computed
//              feature frames the recycling vector retains; -1 keeps all.
template <class C>
OnlineGenericBaseFeature<C>::OnlineGenericBaseFeature(
const typename C::Options &opts)
: computer_(opts),
window_function_(opts.frame_opts, opts.device),
features_(opts.frame_opts.max_feature_vectors),
input_finished_(false),
waveform_offset_(0) {
// Casting to uint32_t, an unsigned type, means that -1 would be treated
// as `very large`, so -1 (keep everything) passes this sanity check while
// small positive caches (<= 200 frames) are rejected.
KALDIFEAT_ASSERT(static_cast<uint32_t>(opts.frame_opts.max_feature_vectors) >
200);
}
// Append newly arrived waveform samples and compute whatever additional
// feature frames become available.
//
// @param sampling_rate      Sampling rate of `original_waveform`; presumably
//                           meant to be checked against the options — not
//                           verified here (TODO confirm with callers).
// @param original_waveform  1-D tensor holding the new samples.
template <class C>
void OnlineGenericBaseFeature<C>::AcceptWaveform(
    float sampling_rate, const torch::Tensor &original_waveform) {
  if (original_waveform.numel() == 0) {
    return;  // empty chunk; nothing to do
  }

  KALDIFEAT_ASSERT(original_waveform.dim() == 1);

  if (input_finished_) {
    KALDIFEAT_ERR << "AcceptWaveform called after InputFinished() was called.";
  }

  // Accumulate the pending samples; avoid torch::cat when there is no
  // remainder yet.
  waveform_remainder_ =
      (waveform_remainder_.numel() == 0)
          ? original_waveform
          : torch::cat({waveform_remainder_, original_waveform}, /*dim*/ 0);

  ComputeFeatures();
}
// Tell the extractor that no more waveform will be provided. This flushes
// out any remaining frames (NumFrames is called with flush == true from now
// on) and enables IsLastFrame() to return true for the final frame.
template <class C>
void OnlineGenericBaseFeature<C>::InputFinished() {
input_finished_ = true;
ComputeFeatures();
}
// Compute every feature frame that is currently computable from
// waveform_offset_ + waveform_remainder_, append them to features_, and
// discard the prefix of waveform_remainder_ that no future frame needs
// (advancing waveform_offset_ by the same amount). The statement order here
// is load-bearing: frames must be extracted before samples are discarded.
template <class C>
void OnlineGenericBaseFeature<C>::ComputeFeatures() {
const FrameExtractionOptions &frame_opts = computer_.GetFrameOptions();
// Total samples seen so far = already-discarded + still-buffered.
int64_t num_samples_total = waveform_offset_ + waveform_remainder_.numel();
int32_t num_frames_old = features_.Size();
int32_t num_frames_new =
NumFrames(num_samples_total, frame_opts, input_finished_);
KALDIFEAT_ASSERT(num_frames_new >= num_frames_old);
// note: this online feature-extraction code does not support VTLN.
float vtln_warp = 1.0;
for (int32_t frame = num_frames_old; frame < num_frames_new; ++frame) {
torch::Tensor window =
ExtractWindow(waveform_offset_, waveform_remainder_, frame, frame_opts);
// TODO(fangjun): We can compute all feature frames at once
// unsqueeze(0): the computer expects a batch dimension, (1, frame_length).
torch::Tensor this_feature =
computer_.ComputeFeatures(window.unsqueeze(0), vtln_warp);
features_.PushBack(this_feature);
}
// OK, we will now discard any portion of the signal that will not be
// necessary to compute frames in the future.
int64_t first_sample_of_next_frame =
FirstSampleOfFrame(num_frames_new, frame_opts);
int32_t samples_to_discard = first_sample_of_next_frame - waveform_offset_;
if (samples_to_discard > 0) {
// discard the leftmost part of the waveform that we no longer need.
int32_t new_num_samples = waveform_remainder_.numel() - samples_to_discard;
if (new_num_samples <= 0) {
// odd, but we'll try to handle it.
waveform_offset_ += waveform_remainder_.numel();
waveform_remainder_.resize_({0});
} else {
using torch::indexing::None;
using torch::indexing::Slice;
// waveform_remainder_ = waveform_remainder_[samples_to_discard:]
waveform_remainder_ =
waveform_remainder_.index({Slice(samples_to_discard, None)});
waveform_offset_ += samples_to_discard;
}
}
}
// instantiate the templates defined here for MFCC, PLP and filterbank classes.
template class OnlineGenericBaseFeature<Mfcc>;
template class OnlineGenericBaseFeature<Plp>;
template class OnlineGenericBaseFeature<Fbank>;
} // namespace kaldifeat

View File

@ -9,6 +9,10 @@
#include <deque>
#include "kaldifeat/csrc/feature-fbank.h"
#include "kaldifeat/csrc/feature-mfcc.h"
#include "kaldifeat/csrc/feature-plp.h"
#include "kaldifeat/csrc/feature-window.h"
#include "kaldifeat/csrc/online-feature-itf.h"
namespace kaldifeat {
@ -44,6 +48,80 @@ class RecyclingVector {
int32_t first_available_index_;
};
/// This is a templated class for online feature extraction;
/// it's templated on a class like MfccComputer or PlpComputer
/// that does the basic feature extraction.
template <class C>
class OnlineGenericBaseFeature : public OnlineFeatureInterface {
public:
// Constructor from options class
explicit OnlineGenericBaseFeature(const typename C::Options &opts);
// Feature dimension of the underlying computer.
int32_t Dim() const override { return computer_.Dim(); }
// Frame shift in seconds, derived from frame_shift_ms in the options.
float FrameShiftInSeconds() const override {
return computer_.GetFrameOptions().frame_shift_ms / 1000.0f;
}
// Number of feature frames computed so far (grows as waveform arrives).
int32_t NumFramesReady() const override { return features_.Size(); }
// Note: IsLastFrame() will only ever return true if you have called
// InputFinished() (and this frame is the last frame).
bool IsLastFrame(int32_t frame) const override {
return input_finished_ && frame == NumFramesReady() - 1;
}
// Return the feature frame with the given index from the recycling vector.
torch::Tensor GetFrame(int32_t frame) override { return features_.At(frame); }
// This would be called from the application, when you get
// more wave data. Note: the sampling_rate is only provided so
// the code can assert that it matches the sampling rate
// expected in the options.
void AcceptWaveform(float sampling_rate,
const torch::Tensor &waveform) override;
// InputFinished() tells the class you won't be providing any
// more waveform. This will help flush out the last frame or two
// of features, in the case where snip-edges == false; it also
// affects the return value of IsLastFrame().
void InputFinished() override;
private:
// This function computes any additional feature frames that it is possible to
// compute from 'waveform_remainder_', which at this point may contain more
// than just a remainder-sized quantity (because AcceptWaveform() appends to
// waveform_remainder_ before calling this function). It adds these feature
// frames to features_, and shifts off any now-unneeded samples of input from
// waveform_remainder_ while incrementing waveform_offset_ by the same amount.
void ComputeFeatures();
C computer_; // class that does the MFCC or PLP or filterbank computation
FeatureWindowFunction window_function_;
// features_ is the Mfcc or Plp or Fbank features that we have already
// computed.
RecyclingVector features_;
// True if the user has called "InputFinished()"
bool input_finished_;
// waveform_offset_ is the number of samples of waveform that we have
// already discarded, i.e. that were prior to 'waveform_remainder_'.
int64_t waveform_offset_;
// waveform_remainder_ is a short piece of waveform that we may need to keep
// after extracting all the whole frames we can (whatever length of feature
// will be required for the next phase of computation).
// It is a 1-D tensor
torch::Tensor waveform_remainder_;
};
// Convenience aliases for the three concrete streaming extractors.
using OnlineMfcc = OnlineGenericBaseFeature<Mfcc>;
using OnlinePlp = OnlineGenericBaseFeature<Plp>;
using OnlineFbank = OnlineGenericBaseFeature<Fbank>;
} // namespace kaldifeat
#endif // KALDIFEAT_CSRC_ONLINE_FEATURE_H_

View File

@ -25,6 +25,7 @@ static void PybindFrameExtractionOptions(py::module &m) {
.def_readwrite("round_to_power_of_two", &PyClass::round_to_power_of_two)
.def_readwrite("blackman_coeff", &PyClass::blackman_coeff)
.def_readwrite("snip_edges", &PyClass::snip_edges)
.def_readwrite("max_feature_vectors", &PyClass::max_feature_vectors)
.def("as_dict",
[](const PyClass &self) -> py::dict { return AsDict(self); })
.def_static("from_dict",
@ -35,8 +36,6 @@ static void PybindFrameExtractionOptions(py::module &m) {
.def_readwrite("allow_downsample",
&PyClass::allow_downsample)
.def_readwrite("allow_upsample", &PyClass::allow_upsample)
.def_readwrite("max_feature_vectors",
&PyClass::max_feature_vectors)
#endif
.def("__str__",
[](const PyClass &self) -> std::string { return self.ToString(); })

View File

@ -30,6 +30,7 @@ FrameExtractionOptions FrameExtractionOptionsFromDict(py::dict dict) {
FROM_DICT(bool_, round_to_power_of_two);
FROM_DICT(float_, blackman_coeff);
FROM_DICT(bool_, snip_edges);
FROM_DICT(int_, max_feature_vectors);
return opts;
}
@ -47,6 +48,7 @@ py::dict AsDict(const FrameExtractionOptions &opts) {
AS_DICT(round_to_power_of_two);
AS_DICT(blackman_coeff);
AS_DICT(snip_edges);
AS_DICT(max_feature_vectors);
return dict;
}