mirror of
https://github.com/csukuangfj/kaldifeat.git
synced 2025-08-09 18:12:17 +00:00
Add OnlineGenericBaseFeature.
This commit is contained in:
parent
34ba30272d
commit
039e27dd32
@ -62,6 +62,10 @@ class OfflineFeatureTpl {
|
||||
int32_t Dim() const { return computer_.Dim(); }
|
||||
const Options &GetOptions() const { return computer_.GetOptions(); }
|
||||
|
||||
const FrameExtractionOptions &GetFrameOptions() const {
|
||||
return GetOptions().frame_opts;
|
||||
}
|
||||
|
||||
// Copy constructor.
|
||||
OfflineFeatureTpl(const OfflineFeatureTpl<F> &) = delete;
|
||||
OfflineFeatureTpl<F> &operator=(const OfflineFeatureTpl<F> &) = delete;
|
||||
|
@ -161,19 +161,20 @@ torch::Tensor Dither(const torch::Tensor &wave, float dither_value) {
|
||||
#if 1
|
||||
return wave + rand_gauss * dither_value;
|
||||
#else
|
||||
// use in-place version of wave and change its to pointer type
|
||||
// use in-place version of wave and change it to pointer type
|
||||
wave_->add_(rand_gauss, dither_value);
|
||||
#endif
|
||||
}
|
||||
|
||||
torch::Tensor Preemphasize(float preemph_coeff, const torch::Tensor &wave) {
|
||||
using namespace torch::indexing; // It imports: Slice, None // NOLINT
|
||||
if (preemph_coeff == 0.0f) return wave;
|
||||
|
||||
KALDIFEAT_ASSERT(preemph_coeff >= 0.0f && preemph_coeff <= 1.0f);
|
||||
|
||||
torch::Tensor ans = torch::empty_like(wave);
|
||||
|
||||
using torch::indexing::None;
|
||||
using torch::indexing::Slice;
|
||||
// right = wave[:, 1:]
|
||||
torch::Tensor right = wave.index({"...", Slice(1, None, None)});
|
||||
|
||||
@ -188,4 +189,58 @@ torch::Tensor Preemphasize(float preemph_coeff, const torch::Tensor &wave) {
|
||||
return ans;
|
||||
}
|
||||
|
||||
torch::Tensor ExtractWindow(int64_t sample_offset, const torch::Tensor &wave,
                            int32_t f, const FrameExtractionOptions &opts) {
  KALDIFEAT_ASSERT(sample_offset >= 0 && wave.numel() != 0);

  int32_t frame_length = opts.WindowSize();
  int64_t num_samples = sample_offset + wave.numel();
  int64_t start_sample = FirstSampleOfFrame(f, opts);
  int64_t end_sample = start_sample + frame_length;

  if (opts.snip_edges) {
    KALDIFEAT_ASSERT(start_sample >= sample_offset &&
                     end_sample <= num_samples);
  } else {
    KALDIFEAT_ASSERT(sample_offset == 0 || start_sample >= sample_offset);
  }

  // wave_start and wave_end are start and end indexes into 'wave', for the
  // piece of wave that we're trying to extract.
  int32_t wave_start = static_cast<int32_t>(start_sample - sample_offset);
  int32_t wave_end = wave_start + frame_length;

  if (wave_start >= 0 && wave_end <= wave.numel()) {
    // the normal case -- no edge effects to consider.
    // return wave[wave_start:wave_end]
    return wave.index({torch::indexing::Slice(wave_start, wave_end)});
  } else {
    torch::Tensor window = torch::empty({frame_length}, torch::kFloat);
    auto p_window = window.accessor<float, 1>();
    auto p_wave = wave.accessor<float, 1>();

    // Deal with any end effects by reflection, if needed. This code will only
    // be reached for about two frames per utterance, so we don't concern
    // ourselves excessively with efficiency.
    int32_t wave_dim = wave.numel();
    for (int32_t s = 0; s != frame_length; ++s) {
      int32_t s_in_wave = s + wave_start;
      while (s_in_wave < 0 || s_in_wave >= wave_dim) {
        // reflect around the beginning or end of the wave.
        // e.g. -1 -> 0, -2 -> 1.
        // dim -> dim - 1, dim + 1 -> dim - 2.
        // the code supports repeated reflections, although this
        // would only be needed in pathological cases.
        if (s_in_wave < 0) {
          s_in_wave = -s_in_wave - 1;
        } else {
          s_in_wave = 2 * wave_dim - 1 - s_in_wave;
        }
      }

      p_window[s] = p_wave[s_in_wave];
    }

    // BUG FIX: the original version fell off the end of this branch without
    // returning, which is undefined behavior for a non-void function. The
    // reflected frame assembled above must be returned to the caller.
    return window;
  }
}
|
||||
|
||||
} // namespace kaldifeat
|
||||
|
@ -44,7 +44,11 @@ struct FrameExtractionOptions {
|
||||
bool snip_edges = true;
|
||||
// bool allow_downsample = false;
|
||||
// bool allow_upsample = false;
|
||||
// int32_t max_feature_vectors = -1;
|
||||
|
||||
// Used for streaming feature extraction. It indicates the number
|
||||
// of feature frames to keep in the recycling vector. -1 means to
|
||||
// keep all feature frames.
|
||||
int32_t max_feature_vectors = -1;
|
||||
|
||||
int32_t WindowShift() const {
|
||||
return static_cast<int32_t>(samp_freq * 0.001f * frame_shift_ms);
|
||||
@ -71,7 +75,7 @@ struct FrameExtractionOptions {
|
||||
KALDIFEAT_PRINT(snip_edges);
|
||||
// KALDIFEAT_PRINT(allow_downsample);
|
||||
// KALDIFEAT_PRINT(allow_upsample);
|
||||
// KALDIFEAT_PRINT(max_feature_vectors);
|
||||
KALDIFEAT_PRINT(max_feature_vectors);
|
||||
#undef KALDIFEAT_PRINT
|
||||
return os.str();
|
||||
}
|
||||
@ -100,11 +104,11 @@ class FeatureWindowFunction {
|
||||
|
||||
@param [in] flush True if we are asserting that this number of samples
|
||||
is 'all there is', false if we expecting more data to possibly come in. This
|
||||
only makes a difference to the answer if opts.snips_edges
|
||||
== false. For offline feature extraction you always want flush ==
|
||||
true. In an online-decoding context, once you know (or decide)
|
||||
that no more data is coming in, you'd call it with flush == true at the end
|
||||
to flush out any remaining data.
|
||||
only makes a difference to the answer
|
||||
if opts.snips_edges== false. For offline feature extraction you always want
|
||||
flush == true. In an online-decoding context, once you know (or decide) that
|
||||
no more data is coming in, you'd call it with flush == true at the end to
|
||||
flush out any remaining data.
|
||||
*/
|
||||
int32_t NumFrames(int64_t num_samples, const FrameExtractionOptions &opts,
|
||||
bool flush = true);
|
||||
@ -133,6 +137,29 @@ torch::Tensor Dither(const torch::Tensor &wave, float dither_value);
|
||||
|
||||
torch::Tensor Preemphasize(float preemph_coeff, const torch::Tensor &wave);
|
||||
|
||||
/*
|
||||
ExtractWindow() extracts "frame_length" samples from the given waveform.
|
||||
Note: This function only extracts "frame_length" samples
|
||||
from the input waveform, without any further processing.
|
||||
|
||||
@param [in] sample_offset If 'wave' is not the entire waveform, but
|
||||
part of it to the left has been discarded, then the
|
||||
number of samples prior to 'wave' that we have
|
||||
already discarded. Set this to zero if you are
|
||||
processing the entire waveform in one piece, or
|
||||
if you get 'no matching function' compilation
|
||||
errors when updating the code.
|
||||
@param [in] wave The waveform
|
||||
@param [in] f The frame index to be extracted, with
|
||||
0 <= f < NumFrames(sample_offset + wave.numel(), opts, true)
|
||||
@param [in] opts The options class to be used
|
||||
@return Return a tensor containing "frame_length" samples extracted from
|
||||
`wave`, without any further processing. Its shape is
|
||||
(1, frame_length).
|
||||
*/
|
||||
torch::Tensor ExtractWindow(int64_t sample_offset, const torch::Tensor &wave,
|
||||
int32_t f, const FrameExtractionOptions &opts);
|
||||
|
||||
} // namespace kaldifeat
|
||||
|
||||
#endif // KALDIFEAT_CSRC_FEATURE_WINDOW_H_
|
||||
|
@ -19,6 +19,10 @@ class OnlineFeatureInterface {
|
||||
virtual ~OnlineFeatureInterface() = default;
|
||||
|
||||
virtual int32_t Dim() const = 0; /// returns the feature dimension.
|
||||
//
|
||||
// Returns frame shift in seconds. Helps to estimate duration from frame
|
||||
// counts.
|
||||
virtual float FrameShiftInSeconds() const = 0;
|
||||
|
||||
/// Returns the total number of frames, since the start of the utterance, that
|
||||
/// are now available. In an online-decoding context, this will likely
|
||||
@ -64,10 +68,6 @@ class OnlineFeatureInterface {
|
||||
return torch::cat(features, /*dim*/ 0);
|
||||
}
|
||||
|
||||
// Returns frame shift in seconds. Helps to estimate duration from frame
|
||||
// counts.
|
||||
virtual float FrameShiftInSeconds() const = 0;
|
||||
|
||||
/// This would be called from the application, when you get more wave data.
|
||||
/// Note: the sampling_rate is typically only provided so the code can assert
|
||||
/// that it matches the sampling rate expected in the options.
|
||||
|
@ -6,6 +6,7 @@
|
||||
|
||||
#include "kaldifeat/csrc/online-feature.h"
|
||||
|
||||
#include "kaldifeat/csrc/feature-window.h"
|
||||
#include "kaldifeat/csrc/log.h"
|
||||
|
||||
namespace kaldifeat {
|
||||
@ -40,4 +41,97 @@ int32_t RecyclingVector::Size() const {
|
||||
return first_available_index_ + static_cast<int32_t>(items_.size());
|
||||
}
|
||||
|
||||
// Constructs the streaming (online) feature extractor from the computer's
// option class (e.g. MfccOptions, PlpOptions, FbankOptions).
//
// The computer, the feature window function, and the recycling vector that
// caches already-computed frames are all initialized here; no waveform has
// been accepted yet, so waveform_offset_ starts at 0.
template <class C>
OnlineGenericBaseFeature<C>::OnlineGenericBaseFeature(
    const typename C::Options &opts)
    : computer_(opts),
      window_function_(opts.frame_opts, opts.device),
      features_(opts.frame_opts.max_feature_vectors),
      input_finished_(false),
      waveform_offset_(0) {
  // Casting to uint32_t, an unsigned type, means that -1 would be treated
  // as `very large`. So both "keep everything" (-1) and any sufficiently
  // large positive cache size pass this sanity check; only small positive
  // values (<= 200 frames), which would recycle frames too aggressively
  // for online decoding, are rejected.
  KALDIFEAT_ASSERT(static_cast<uint32_t>(opts.frame_opts.max_feature_vectors) >
                   200);
}
|
||||
|
||||
template <class C>
|
||||
void OnlineGenericBaseFeature<C>::AcceptWaveform(
|
||||
float sampling_rate, const torch::Tensor &original_waveform) {
|
||||
if (original_waveform.numel() == 0) return; // Nothing to do.
|
||||
|
||||
KALDIFEAT_ASSERT(original_waveform.dim() == 1);
|
||||
|
||||
if (input_finished_)
|
||||
KALDIFEAT_ERR << "AcceptWaveform called after InputFinished() was called.";
|
||||
|
||||
if (waveform_remainder_.numel() == 0) {
|
||||
waveform_remainder_ = original_waveform;
|
||||
} else {
|
||||
waveform_remainder_ =
|
||||
torch::cat({waveform_remainder_, original_waveform}, /*dim*/ 0);
|
||||
}
|
||||
|
||||
ComputeFeatures();
|
||||
}
|
||||
|
||||
// Signals that no more waveform will arrive. After this call,
// AcceptWaveform() must not be called again, and IsLastFrame() can start
// returning true for the final frame.
template <class C>
void OnlineGenericBaseFeature<C>::InputFinished() {
  input_finished_ = true;
  // Re-run feature computation with the 'flush' semantics implied by
  // input_finished_ == true; this emits the last frame or two when
  // snip_edges == false.
  ComputeFeatures();
}
|
||||
|
||||
// Computes all feature frames that are now computable from
// waveform_remainder_, appends them to features_, and then discards the
// leading samples that no future frame will need (advancing
// waveform_offset_ by the same amount).
template <class C>
void OnlineGenericBaseFeature<C>::ComputeFeatures() {
  const FrameExtractionOptions &frame_opts = computer_.GetFrameOptions();

  // Total samples seen so far = samples already discarded + samples pending.
  int64_t num_samples_total = waveform_offset_ + waveform_remainder_.numel();
  int32_t num_frames_old = features_.Size();
  // When input_finished_ is true this flushes partial trailing frames
  // (only relevant for snip_edges == false).
  int32_t num_frames_new =
      NumFrames(num_samples_total, frame_opts, input_finished_);

  KALDIFEAT_ASSERT(num_frames_new >= num_frames_old);

  // note: this online feature-extraction code does not support VTLN.
  float vtln_warp = 1.0;

  // Compute only the frames that were not available before this call.
  for (int32_t frame = num_frames_old; frame < num_frames_new; ++frame) {
    torch::Tensor window =
        ExtractWindow(waveform_offset_, waveform_remainder_, frame, frame_opts);

    // TODO(fangjun): We can compute all feature frames at once
    // unsqueeze(0) turns the 1-D window into the (1, frame_length) batch
    // shape the computer expects.
    torch::Tensor this_feature =
        computer_.ComputeFeatures(window.unsqueeze(0), vtln_warp);
    features_.PushBack(this_feature);
  }

  // OK, we will now discard any portion of the signal that will not be
  // necessary to compute frames in the future.
  int64_t first_sample_of_next_frame =
      FirstSampleOfFrame(num_frames_new, frame_opts);
  // NOTE(review): int64 -> int32 narrowing here; assumes the pending
  // remainder stays well under 2^31 samples, which holds for streaming use.
  int32_t samples_to_discard = first_sample_of_next_frame - waveform_offset_;
  if (samples_to_discard > 0) {
    // discard the leftmost part of the waveform that we no longer need.
    int32_t new_num_samples = waveform_remainder_.numel() - samples_to_discard;
    if (new_num_samples <= 0) {
      // odd, but we'll try to handle it.
      waveform_offset_ += waveform_remainder_.numel();
      waveform_remainder_.resize_({0});
    } else {
      using torch::indexing::None;
      using torch::indexing::Slice;

      // Keep only waveform_remainder_[samples_to_discard:].
      waveform_remainder_ =
          waveform_remainder_.index({Slice(samples_to_discard, None)});

      waveform_offset_ += samples_to_discard;
    }
  }
}
|
||||
|
||||
// instantiate the templates defined here for MFCC, PLP and filterbank classes.
|
||||
template class OnlineGenericBaseFeature<Mfcc>;
|
||||
template class OnlineGenericBaseFeature<Plp>;
|
||||
template class OnlineGenericBaseFeature<Fbank>;
|
||||
|
||||
} // namespace kaldifeat
|
||||
|
@ -9,6 +9,10 @@
|
||||
|
||||
#include <deque>
|
||||
|
||||
#include "kaldifeat/csrc/feature-fbank.h"
|
||||
#include "kaldifeat/csrc/feature-mfcc.h"
|
||||
#include "kaldifeat/csrc/feature-plp.h"
|
||||
#include "kaldifeat/csrc/feature-window.h"
|
||||
#include "kaldifeat/csrc/online-feature-itf.h"
|
||||
|
||||
namespace kaldifeat {
|
||||
@ -44,6 +48,80 @@ class RecyclingVector {
|
||||
int32_t first_available_index_;
|
||||
};
|
||||
|
||||
/// This is a templated class for online feature extraction;
/// it's templated on a class like MfccComputer or PlpComputer
/// that does the basic feature extraction.
template <class C>
class OnlineGenericBaseFeature : public OnlineFeatureInterface {
 public:
  // Constructor from options class
  explicit OnlineGenericBaseFeature(const typename C::Options &opts);

  /// Feature dimension, delegated to the underlying computer.
  int32_t Dim() const override { return computer_.Dim(); }

  /// Frame shift in seconds (frame_shift_ms is in milliseconds).
  float FrameShiftInSeconds() const override {
    return computer_.GetFrameOptions().frame_shift_ms / 1000.0f;
  }

  /// Number of frames computed so far and still retrievable via GetFrame().
  int32_t NumFramesReady() const override { return features_.Size(); }

  // Note: IsLastFrame() will only ever return true if you have called
  // InputFinished() (and this frame is the last frame).
  bool IsLastFrame(int32_t frame) const override {
    return input_finished_ && frame == NumFramesReady() - 1;
  }

  /// Returns the feature frame with the given index from the cache.
  torch::Tensor GetFrame(int32_t frame) override { return features_.At(frame); }

  // This would be called from the application, when you get
  // more wave data. Note: the sampling_rate is only provided so
  // the code can assert that it matches the sampling rate
  // expected in the options.
  void AcceptWaveform(float sampling_rate,
                      const torch::Tensor &waveform) override;

  // InputFinished() tells the class you won't be providing any
  // more waveform. This will help flush out the last frame or two
  // of features, in the case where snip-edges == false; it also
  // affects the return value of IsLastFrame().
  void InputFinished() override;

 private:
  // This function computes any additional feature frames that it is possible to
  // compute from 'waveform_remainder_', which at this point may contain more
  // than just a remainder-sized quantity (because AcceptWaveform() appends to
  // waveform_remainder_ before calling this function). It adds these feature
  // frames to features_, and shifts off any now-unneeded samples of input from
  // waveform_remainder_ while incrementing waveform_offset_ by the same amount.
  void ComputeFeatures();

  C computer_;  // class that does the MFCC or PLP or filterbank computation

  // Precomputed window (e.g. Hamming/Povey) for the configured frame length.
  FeatureWindowFunction window_function_;

  // features_ is the Mfcc or Plp or Fbank features that we have already
  // computed.

  RecyclingVector features_;

  // True if the user has called "InputFinished()"
  bool input_finished_;

  // waveform_offset_ is the number of samples of waveform that we have
  // already discarded, i.e. that were prior to 'waveform_remainder_'.
  int64_t waveform_offset_;

  // waveform_remainder_ is a short piece of waveform that we may need to keep
  // after extracting all the whole frames we can (whatever length of feature
  // will be required for the next phase of computation).
  // It is a 1-D tensor
  torch::Tensor waveform_remainder_;
};
|
||||
|
||||
using OnlineMfcc = OnlineGenericBaseFeature<Mfcc>;
|
||||
using OnlinePlp = OnlineGenericBaseFeature<Plp>;
|
||||
using OnlineFbank = OnlineGenericBaseFeature<Fbank>;
|
||||
|
||||
} // namespace kaldifeat
|
||||
|
||||
#endif // KALDIFEAT_CSRC_ONLINE_FEATURE_H_
|
||||
|
@ -25,6 +25,7 @@ static void PybindFrameExtractionOptions(py::module &m) {
|
||||
.def_readwrite("round_to_power_of_two", &PyClass::round_to_power_of_two)
|
||||
.def_readwrite("blackman_coeff", &PyClass::blackman_coeff)
|
||||
.def_readwrite("snip_edges", &PyClass::snip_edges)
|
||||
.def_readwrite("max_feature_vectors", &PyClass::max_feature_vectors)
|
||||
.def("as_dict",
|
||||
[](const PyClass &self) -> py::dict { return AsDict(self); })
|
||||
.def_static("from_dict",
|
||||
@ -35,8 +36,6 @@ static void PybindFrameExtractionOptions(py::module &m) {
|
||||
.def_readwrite("allow_downsample",
|
||||
&PyClass::allow_downsample)
|
||||
.def_readwrite("allow_upsample", &PyClass::allow_upsample)
|
||||
.def_readwrite("max_feature_vectors",
|
||||
&PyClass::max_feature_vectors)
|
||||
#endif
|
||||
.def("__str__",
|
||||
[](const PyClass &self) -> std::string { return self.ToString(); })
|
||||
|
@ -30,6 +30,7 @@ FrameExtractionOptions FrameExtractionOptionsFromDict(py::dict dict) {
|
||||
FROM_DICT(bool_, round_to_power_of_two);
|
||||
FROM_DICT(float_, blackman_coeff);
|
||||
FROM_DICT(bool_, snip_edges);
|
||||
FROM_DICT(int_, max_feature_vectors);
|
||||
|
||||
return opts;
|
||||
}
|
||||
@ -47,6 +48,7 @@ py::dict AsDict(const FrameExtractionOptions &opts) {
|
||||
AS_DICT(round_to_power_of_two);
|
||||
AS_DICT(blackman_coeff);
|
||||
AS_DICT(snip_edges);
|
||||
AS_DICT(max_feature_vectors);
|
||||
|
||||
return dict;
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user