mirror of
https://github.com/csukuangfj/kaldifeat.git
synced 2025-08-09 10:02:20 +00:00
Merge pull request #28 from csukuangfj/streaming-feature-extractor
Start to add streaming feature extractors.
This commit is contained in:
commit
b72fc599fd
4
.github/workflows/style_check.yml
vendored
4
.github/workflows/style_check.yml
vendored
@ -45,7 +45,9 @@ jobs:
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: |
|
||||
python3 -m pip install --upgrade pip black flake8
|
||||
python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2 click==8.0.4
|
||||
# See https://github.com/psf/black/issues/2964
|
||||
# The version of click should be selected from 8.0.0, 8.0.1, 8.0.2, 8.0.3, and 8.0.4
|
||||
|
||||
- name: Run flake8
|
||||
shell: bash
|
||||
|
35
README.md
35
README.md
@ -31,6 +31,17 @@ features = fbank(wave)
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td>Streaming FBANK</td>
|
||||
<td><code>kaldifeat.FbankOptions</code></td>
|
||||
<td><code>kaldifeat.OnlineFbank</code></td>
|
||||
<td>
|
||||
See <a href="./kaldifeat/python/tests/test_fbank.py">
|
||||
./kaldifeat/python/tests/test_fbank.py
|
||||
</a>
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td>MFCC</td>
|
||||
<td><code>kaldifeat.MfccOptions</code></td>
|
||||
@ -45,6 +56,17 @@ features = mfcc(wave)
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td>Streaming MFCC</td>
|
||||
<td><code>kaldifeat.MfccOptions</code></td>
|
||||
<td><code>kaldifeat.OnlineMfcc</code></td>
|
||||
<td>
|
||||
See <a href="./kaldifeat/python/tests/test_mfcc.py">
|
||||
./kaldifeat/python/tests/test_mfcc.py
|
||||
</a>
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td>PLP</td>
|
||||
<td><code>kaldifeat.PlpOptions</code></td>
|
||||
@ -59,6 +81,17 @@ features = plp(wave)
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td>Streaming PLP</td>
|
||||
<td><code>kaldifeat.PlpOptions</code></td>
|
||||
<td><code>kaldifeat.OnlinePlp</code></td>
|
||||
<td>
|
||||
See <a href="./kaldifeat/python/tests/test_plp.py">
|
||||
./kaldifeat/python/tests/test_plp.py
|
||||
</a>
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td>Spectorgram</td>
|
||||
<td><code>kaldifeat.SpectrogramOptions</code></td>
|
||||
@ -88,6 +121,8 @@ The following kaldi-compatible commandline tools are implemented:
|
||||
|
||||
(**NOTE**: We will implement other types of features, e.g., Pitch, ivector, etc, soon.)
|
||||
|
||||
**HINT**: It supports also streaming feature extractors for Fbank, MFCC, and Plp.
|
||||
|
||||
# Usage
|
||||
|
||||
Let us first generate a test wave using sox:
|
||||
|
@ -9,6 +9,7 @@ set(kaldifeat_srcs
|
||||
feature-window.cc
|
||||
matrix-functions.cc
|
||||
mel-computations.cc
|
||||
online-feature.cc
|
||||
)
|
||||
|
||||
add_library(kaldifeat_core SHARED ${kaldifeat_srcs})
|
||||
@ -40,6 +41,7 @@ if(kaldifeat_BUILD_TESTS)
|
||||
# please sort the source files alphabetically
|
||||
set(test_srcs
|
||||
feature-window-test.cc
|
||||
online-feature-test.cc
|
||||
)
|
||||
|
||||
foreach(source IN LISTS test_srcs)
|
||||
|
@ -62,6 +62,10 @@ class OfflineFeatureTpl {
|
||||
int32_t Dim() const { return computer_.Dim(); }
|
||||
const Options &GetOptions() const { return computer_.GetOptions(); }
|
||||
|
||||
const FrameExtractionOptions &GetFrameOptions() const {
|
||||
return GetOptions().frame_opts;
|
||||
}
|
||||
|
||||
// Copy constructor.
|
||||
OfflineFeatureTpl(const OfflineFeatureTpl<F> &) = delete;
|
||||
OfflineFeatureTpl<F> &operator=(const OfflineFeatureTpl<F> &) = delete;
|
||||
|
@ -161,19 +161,20 @@ torch::Tensor Dither(const torch::Tensor &wave, float dither_value) {
|
||||
#if 1
|
||||
return wave + rand_gauss * dither_value;
|
||||
#else
|
||||
// use in-place version of wave and change its to pointer type
|
||||
// use in-place version of wave and change it to pointer type
|
||||
wave_->add_(rand_gauss, dither_value);
|
||||
#endif
|
||||
}
|
||||
|
||||
torch::Tensor Preemphasize(float preemph_coeff, const torch::Tensor &wave) {
|
||||
using namespace torch::indexing; // It imports: Slice, None // NOLINT
|
||||
if (preemph_coeff == 0.0f) return wave;
|
||||
|
||||
KALDIFEAT_ASSERT(preemph_coeff >= 0.0f && preemph_coeff <= 1.0f);
|
||||
|
||||
torch::Tensor ans = torch::empty_like(wave);
|
||||
|
||||
using torch::indexing::None;
|
||||
using torch::indexing::Slice;
|
||||
// right = wave[:, 1:]
|
||||
torch::Tensor right = wave.index({"...", Slice(1, None, None)});
|
||||
|
||||
@ -188,4 +189,59 @@ torch::Tensor Preemphasize(float preemph_coeff, const torch::Tensor &wave) {
|
||||
return ans;
|
||||
}
|
||||
|
||||
torch::Tensor ExtractWindow(int64_t sample_offset, const torch::Tensor &wave,
|
||||
int32_t f, const FrameExtractionOptions &opts) {
|
||||
KALDIFEAT_ASSERT(sample_offset >= 0 && wave.numel() != 0);
|
||||
|
||||
int32_t frame_length = opts.WindowSize();
|
||||
int64_t num_samples = sample_offset + wave.numel();
|
||||
int64_t start_sample = FirstSampleOfFrame(f, opts);
|
||||
int64_t end_sample = start_sample + frame_length;
|
||||
|
||||
if (opts.snip_edges) {
|
||||
KALDIFEAT_ASSERT(start_sample >= sample_offset &&
|
||||
end_sample <= num_samples);
|
||||
} else {
|
||||
KALDIFEAT_ASSERT(sample_offset == 0 || start_sample >= sample_offset);
|
||||
}
|
||||
|
||||
// wave_start and wave_end are start and end indexes into 'wave', for the
|
||||
// piece of wave that we're trying to extract.
|
||||
int32_t wave_start = static_cast<int32_t>(start_sample - sample_offset);
|
||||
int32_t wave_end = wave_start + frame_length;
|
||||
|
||||
if (wave_start >= 0 && wave_end <= wave.numel()) {
|
||||
// the normal case -- no edge effects to consider.
|
||||
// return wave[wave_start:wave_end]
|
||||
return wave.index({torch::indexing::Slice(wave_start, wave_end)});
|
||||
} else {
|
||||
torch::Tensor window = torch::empty({frame_length}, torch::kFloat);
|
||||
auto p_window = window.accessor<float, 1>();
|
||||
auto p_wave = wave.accessor<float, 1>();
|
||||
|
||||
// Deal with any end effects by reflection, if needed. This code will only
|
||||
// be reached for about two frames per utterance, so we don't concern
|
||||
// ourselves excessively with efficiency.
|
||||
int32_t wave_dim = wave.numel();
|
||||
for (int32_t s = 0; s != frame_length; ++s) {
|
||||
int32_t s_in_wave = s + wave_start;
|
||||
while (s_in_wave < 0 || s_in_wave >= wave_dim) {
|
||||
// reflect around the beginning or end of the wave.
|
||||
// e.g. -1 -> 0, -2 -> 1.
|
||||
// dim -> dim - 1, dim + 1 -> dim - 2.
|
||||
// the code supports repeated reflections, although this
|
||||
// would only be needed in pathological cases.
|
||||
if (s_in_wave < 0) {
|
||||
s_in_wave = -s_in_wave - 1;
|
||||
} else {
|
||||
s_in_wave = 2 * wave_dim - 1 - s_in_wave;
|
||||
}
|
||||
}
|
||||
|
||||
p_window[s] = p_wave[s_in_wave];
|
||||
}
|
||||
return window;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace kaldifeat
|
||||
|
@ -44,7 +44,11 @@ struct FrameExtractionOptions {
|
||||
bool snip_edges = true;
|
||||
// bool allow_downsample = false;
|
||||
// bool allow_upsample = false;
|
||||
// int32_t max_feature_vectors = -1;
|
||||
|
||||
// Used for streaming feature extraction. It indicates the number
|
||||
// of feature frames to keep in the recycling vector. -1 means to
|
||||
// keep all feature frames.
|
||||
int32_t max_feature_vectors = -1;
|
||||
|
||||
int32_t WindowShift() const {
|
||||
return static_cast<int32_t>(samp_freq * 0.001f * frame_shift_ms);
|
||||
@ -71,7 +75,7 @@ struct FrameExtractionOptions {
|
||||
KALDIFEAT_PRINT(snip_edges);
|
||||
// KALDIFEAT_PRINT(allow_downsample);
|
||||
// KALDIFEAT_PRINT(allow_upsample);
|
||||
// KALDIFEAT_PRINT(max_feature_vectors);
|
||||
KALDIFEAT_PRINT(max_feature_vectors);
|
||||
#undef KALDIFEAT_PRINT
|
||||
return os.str();
|
||||
}
|
||||
@ -100,11 +104,11 @@ class FeatureWindowFunction {
|
||||
|
||||
@param [in] flush True if we are asserting that this number of samples
|
||||
is 'all there is', false if we expecting more data to possibly come in. This
|
||||
only makes a difference to the answer if opts.snips_edges
|
||||
== false. For offline feature extraction you always want flush ==
|
||||
true. In an online-decoding context, once you know (or decide)
|
||||
that no more data is coming in, you'd call it with flush == true at the end
|
||||
to flush out any remaining data.
|
||||
only makes a difference to the answer
|
||||
if opts.snips_edges== false. For offline feature extraction you always want
|
||||
flush == true. In an online-decoding context, once you know (or decide) that
|
||||
no more data is coming in, you'd call it with flush == true at the end to
|
||||
flush out any remaining data.
|
||||
*/
|
||||
int32_t NumFrames(int64_t num_samples, const FrameExtractionOptions &opts,
|
||||
bool flush = true);
|
||||
@ -133,6 +137,29 @@ torch::Tensor Dither(const torch::Tensor &wave, float dither_value);
|
||||
|
||||
torch::Tensor Preemphasize(float preemph_coeff, const torch::Tensor &wave);
|
||||
|
||||
/*
|
||||
ExtractWindow() extracts "frame_length" samples from the given waveform.
|
||||
Note: This function only extracts "frame_length" samples
|
||||
from the input waveform, without any further processing.
|
||||
|
||||
@param [in] sample_offset If 'wave' is not the entire waveform, but
|
||||
part of it to the left has been discarded, then the
|
||||
number of samples prior to 'wave' that we have
|
||||
already discarded. Set this to zero if you are
|
||||
processing the entire waveform in one piece, or
|
||||
if you get 'no matching function' compilation
|
||||
errors when updating the code.
|
||||
@param [in] wave The waveform
|
||||
@param [in] f The frame index to be extracted, with
|
||||
0 <= f < NumFrames(sample_offset + wave.numel(), opts, true)
|
||||
@param [in] opts The options class to be used
|
||||
@return Return a tensor containing "frame_length" samples extracted from
|
||||
`wave`, without any further processing. Its shape is
|
||||
(1, frame_length).
|
||||
*/
|
||||
torch::Tensor ExtractWindow(int64_t sample_offset, const torch::Tensor &wave,
|
||||
int32_t f, const FrameExtractionOptions &opts);
|
||||
|
||||
} // namespace kaldifeat
|
||||
|
||||
#endif // KALDIFEAT_CSRC_FEATURE_WINDOW_H_
|
||||
|
89
kaldifeat/csrc/online-feature-itf.h
Normal file
89
kaldifeat/csrc/online-feature-itf.h
Normal file
@ -0,0 +1,89 @@
|
||||
// kaldifeat/csrc/online-feature-itf.h
|
||||
//
|
||||
// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
|
||||
// This file is copied/modified from kaldi/src/itf/online-feature-itf.h
|
||||
|
||||
#ifndef KALDIFEAT_CSRC_ONLINE_FEATURE_ITF_H_
|
||||
#define KALDIFEAT_CSRC_ONLINE_FEATURE_ITF_H_
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "torch/script.h"
|
||||
|
||||
namespace kaldifeat {
|
||||
|
||||
class OnlineFeatureInterface {
|
||||
public:
|
||||
virtual ~OnlineFeatureInterface() = default;
|
||||
|
||||
virtual int32_t Dim() const = 0; /// returns the feature dimension.
|
||||
//
|
||||
// Returns frame shift in seconds. Helps to estimate duration from frame
|
||||
// counts.
|
||||
virtual float FrameShiftInSeconds() const = 0;
|
||||
|
||||
/// Returns the total number of frames, since the start of the utterance, that
|
||||
/// are now available. In an online-decoding context, this will likely
|
||||
/// increase with time as more data becomes available.
|
||||
virtual int32_t NumFramesReady() const = 0;
|
||||
|
||||
/// Returns true if this is the last frame. Frame indices are zero-based, so
|
||||
/// the first frame is zero. IsLastFrame(-1) will return false, unless the
|
||||
/// file is empty (which is a case that I'm not sure all the code will handle,
|
||||
/// so be careful). This function may return false for some frame if we
|
||||
/// haven't yet decided to terminate decoding, but later true if we decide to
|
||||
/// terminate decoding. This function exists mainly to correctly handle end
|
||||
/// effects in feature extraction, and is not a mechanism to determine how
|
||||
/// many frames are in the decodable object (as it used to be, and for
|
||||
/// backward compatibility, still is, in the Decodable interface).
|
||||
virtual bool IsLastFrame(int32_t frame) const = 0;
|
||||
|
||||
/// Gets the feature vector for this frame. Before calling this for a given
|
||||
/// frame, it is assumed that you called NumFramesReady() and it returned a
|
||||
/// number greater than "frame". Otherwise this call will likely crash with
|
||||
/// an assert failure. This function is not declared const, in case there is
|
||||
/// some kind of caching going on, but most of the time it shouldn't modify
|
||||
/// the class.
|
||||
///
|
||||
/// The returned tensor has shape (1, Dim()).
|
||||
virtual torch::Tensor GetFrame(int32_t frame) = 0;
|
||||
|
||||
/// This is like GetFrame() but for a collection of frames. There is a
|
||||
/// default implementation that just gets the frames one by one, but it
|
||||
/// may be overridden for efficiency by child classes (since sometimes
|
||||
/// it's more efficient to do things in a batch).
|
||||
///
|
||||
/// The returned tensor has shape (frames.size(), Dim()).
|
||||
virtual std::vector<torch::Tensor> GetFrames(
|
||||
const std::vector<int32_t> &frames) {
|
||||
std::vector<torch::Tensor> features;
|
||||
features.reserve(frames.size());
|
||||
|
||||
for (auto i : frames) {
|
||||
torch::Tensor f = GetFrame(i);
|
||||
features.push_back(std::move(f));
|
||||
}
|
||||
return features;
|
||||
#if 0
|
||||
return torch::cat(features, /*dim*/ 0);
|
||||
#endif
|
||||
}
|
||||
|
||||
/// This would be called from the application, when you get more wave data.
|
||||
/// Note: the sampling_rate is typically only provided so the code can assert
|
||||
/// that it matches the sampling rate expected in the options.
|
||||
virtual void AcceptWaveform(float sampling_rate,
|
||||
const torch::Tensor &waveform) = 0;
|
||||
|
||||
/// InputFinished() tells the class you won't be providing any
|
||||
/// more waveform. This will help flush out the last few frames
|
||||
/// of delta or LDA features (it will typically affect the return value
|
||||
/// of IsLastFrame.
|
||||
virtual void InputFinished() = 0;
|
||||
};
|
||||
|
||||
} // namespace kaldifeat
|
||||
|
||||
#endif // KALDIFEAT_CSRC_ONLINE_FEATURE_ITF_H_
|
49
kaldifeat/csrc/online-feature-test.cc
Normal file
49
kaldifeat/csrc/online-feature-test.cc
Normal file
@ -0,0 +1,49 @@
|
||||
// kaldifeat/csrc/online-feature-test.h
|
||||
//
|
||||
// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
|
||||
#include "kaldifeat/csrc/online-feature.h"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
namespace kaldifeat {
|
||||
|
||||
TEST(RecyclingVector, TestUnlimited) {
|
||||
RecyclingVector v(-1);
|
||||
constexpr int32_t N = 100;
|
||||
for (int32_t i = 0; i != N; ++i) {
|
||||
torch::Tensor t = torch::tensor({i, i + 1, i + 2});
|
||||
v.PushBack(t);
|
||||
}
|
||||
ASSERT_EQ(v.Size(), N);
|
||||
|
||||
for (int32_t i = 0; i != N; ++i) {
|
||||
torch::Tensor t = v.At(i);
|
||||
torch::Tensor expected = torch::tensor({i, i + 1, i + 2});
|
||||
EXPECT_TRUE(t.equal(expected));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(RecyclingVector, Testlimited) {
|
||||
constexpr int32_t K = 3;
|
||||
constexpr int32_t N = 10;
|
||||
RecyclingVector v(K);
|
||||
for (int32_t i = 0; i != N; ++i) {
|
||||
torch::Tensor t = torch::tensor({i, i + 1, i + 2});
|
||||
v.PushBack(t);
|
||||
}
|
||||
|
||||
ASSERT_EQ(v.Size(), N);
|
||||
|
||||
for (int32_t i = 0; i < N - K; ++i) {
|
||||
ASSERT_DEATH(v.At(i), "");
|
||||
}
|
||||
|
||||
for (int32_t i = N - K; i != N; ++i) {
|
||||
torch::Tensor t = v.At(i);
|
||||
torch::Tensor expected = torch::tensor({i, i + 1, i + 2});
|
||||
EXPECT_TRUE(t.equal(expected));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace kaldifeat
|
133
kaldifeat/csrc/online-feature.cc
Normal file
133
kaldifeat/csrc/online-feature.cc
Normal file
@ -0,0 +1,133 @@
|
||||
// kaldifeat/csrc/online-feature.cc
|
||||
//
|
||||
// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
|
||||
// This file is copied/modified from kaldi/src/feat/online-feature.cc
|
||||
|
||||
#include "kaldifeat/csrc/online-feature.h"
|
||||
|
||||
#include "kaldifeat/csrc/feature-window.h"
|
||||
#include "kaldifeat/csrc/log.h"
|
||||
|
||||
namespace kaldifeat {
|
||||
|
||||
RecyclingVector::RecyclingVector(int32_t items_to_hold)
|
||||
: items_to_hold_(items_to_hold == 0 ? -1 : items_to_hold),
|
||||
first_available_index_(0) {}
|
||||
|
||||
torch::Tensor RecyclingVector::At(int32_t index) const {
|
||||
if (index < first_available_index_) {
|
||||
KALDIFEAT_ERR << "Attempted to retrieve feature vector that was "
|
||||
"already removed by the RecyclingVector (index = "
|
||||
<< index << "; "
|
||||
<< "first_available_index = " << first_available_index_
|
||||
<< "; "
|
||||
<< "size = " << Size() << ")";
|
||||
}
|
||||
// 'at' does size checking.
|
||||
return items_.at(index - first_available_index_);
|
||||
}
|
||||
|
||||
void RecyclingVector::PushBack(torch::Tensor item) {
|
||||
// Note: -1 is a larger number when treated as unsigned
|
||||
if (items_.size() == static_cast<size_t>(items_to_hold_)) {
|
||||
items_.pop_front();
|
||||
++first_available_index_;
|
||||
}
|
||||
items_.push_back(item);
|
||||
}
|
||||
|
||||
int32_t RecyclingVector::Size() const {
|
||||
return first_available_index_ + static_cast<int32_t>(items_.size());
|
||||
}
|
||||
|
||||
template <class C>
|
||||
OnlineGenericBaseFeature<C>::OnlineGenericBaseFeature(
|
||||
const typename C::Options &opts)
|
||||
: computer_(opts),
|
||||
window_function_(opts.frame_opts, opts.device),
|
||||
features_(opts.frame_opts.max_feature_vectors),
|
||||
input_finished_(false),
|
||||
waveform_offset_(0) {}
|
||||
|
||||
template <class C>
|
||||
void OnlineGenericBaseFeature<C>::AcceptWaveform(
|
||||
float sampling_rate, const torch::Tensor &original_waveform) {
|
||||
if (original_waveform.numel() == 0) return; // Nothing to do.
|
||||
|
||||
KALDIFEAT_ASSERT(original_waveform.dim() == 1);
|
||||
KALDIFEAT_ASSERT(sampling_rate == computer_.GetFrameOptions().samp_freq);
|
||||
|
||||
if (input_finished_)
|
||||
KALDIFEAT_ERR << "AcceptWaveform called after InputFinished() was called.";
|
||||
|
||||
if (waveform_remainder_.numel() == 0) {
|
||||
waveform_remainder_ = original_waveform;
|
||||
} else {
|
||||
waveform_remainder_ =
|
||||
torch::cat({waveform_remainder_, original_waveform}, /*dim*/ 0);
|
||||
}
|
||||
|
||||
ComputeFeatures();
|
||||
}
|
||||
|
||||
template <class C>
|
||||
void OnlineGenericBaseFeature<C>::InputFinished() {
|
||||
input_finished_ = true;
|
||||
ComputeFeatures();
|
||||
}
|
||||
|
||||
template <class C>
|
||||
void OnlineGenericBaseFeature<C>::ComputeFeatures() {
|
||||
const FrameExtractionOptions &frame_opts = computer_.GetFrameOptions();
|
||||
|
||||
int64_t num_samples_total = waveform_offset_ + waveform_remainder_.numel();
|
||||
int32_t num_frames_old = features_.Size();
|
||||
int32_t num_frames_new =
|
||||
NumFrames(num_samples_total, frame_opts, input_finished_);
|
||||
|
||||
KALDIFEAT_ASSERT(num_frames_new >= num_frames_old);
|
||||
|
||||
// note: this online feature-extraction code does not support VTLN.
|
||||
float vtln_warp = 1.0;
|
||||
|
||||
for (int32_t frame = num_frames_old; frame < num_frames_new; ++frame) {
|
||||
torch::Tensor window =
|
||||
ExtractWindow(waveform_offset_, waveform_remainder_, frame, frame_opts);
|
||||
|
||||
// TODO(fangjun): We can compute all feature frames at once
|
||||
torch::Tensor this_feature =
|
||||
computer_.ComputeFeatures(window.unsqueeze(0), vtln_warp);
|
||||
features_.PushBack(this_feature);
|
||||
}
|
||||
|
||||
// OK, we will now discard any portion of the signal that will not be
|
||||
// necessary to compute frames in the future.
|
||||
int64_t first_sample_of_next_frame =
|
||||
FirstSampleOfFrame(num_frames_new, frame_opts);
|
||||
int32_t samples_to_discard = first_sample_of_next_frame - waveform_offset_;
|
||||
if (samples_to_discard > 0) {
|
||||
// discard the leftmost part of the waveform that we no longer need.
|
||||
int32_t new_num_samples = waveform_remainder_.numel() - samples_to_discard;
|
||||
if (new_num_samples <= 0) {
|
||||
// odd, but we'll try to handle it.
|
||||
waveform_offset_ += waveform_remainder_.numel();
|
||||
waveform_remainder_.resize_({0});
|
||||
} else {
|
||||
using torch::indexing::None;
|
||||
using torch::indexing::Slice;
|
||||
|
||||
waveform_remainder_ =
|
||||
waveform_remainder_.index({Slice(samples_to_discard, None)});
|
||||
|
||||
waveform_offset_ += samples_to_discard;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// instantiate the templates defined here for MFCC, PLP and filterbank classes.
|
||||
template class OnlineGenericBaseFeature<Mfcc>;
|
||||
template class OnlineGenericBaseFeature<Plp>;
|
||||
template class OnlineGenericBaseFeature<Fbank>;
|
||||
|
||||
} // namespace kaldifeat
|
127
kaldifeat/csrc/online-feature.h
Normal file
127
kaldifeat/csrc/online-feature.h
Normal file
@ -0,0 +1,127 @@
|
||||
// kaldifeat/csrc/online-feature.h
|
||||
//
|
||||
// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
|
||||
// This file is copied/modified from kaldi/src/feat/online-feature.h
|
||||
|
||||
#ifndef KALDIFEAT_CSRC_ONLINE_FEATURE_H_
|
||||
#define KALDIFEAT_CSRC_ONLINE_FEATURE_H_
|
||||
|
||||
#include <deque>
|
||||
|
||||
#include "kaldifeat/csrc/feature-fbank.h"
|
||||
#include "kaldifeat/csrc/feature-mfcc.h"
|
||||
#include "kaldifeat/csrc/feature-plp.h"
|
||||
#include "kaldifeat/csrc/feature-window.h"
|
||||
#include "kaldifeat/csrc/online-feature-itf.h"
|
||||
|
||||
namespace kaldifeat {
|
||||
|
||||
/// This class serves as a storage for feature vectors with an option to limit
|
||||
/// the memory usage by removing old elements. The deleted frames indices are
|
||||
/// "remembered" so that regardless of the MAX_ITEMS setting, the user always
|
||||
/// provides the indices as if no deletion was being performed.
|
||||
/// This is useful when processing very long recordings which would otherwise
|
||||
/// cause the memory to eventually blow up when the features are not being
|
||||
/// removed.
|
||||
class RecyclingVector {
|
||||
public:
|
||||
/// By default it does not remove any elements.
|
||||
explicit RecyclingVector(int32_t items_to_hold = -1);
|
||||
|
||||
~RecyclingVector() = default;
|
||||
RecyclingVector(const RecyclingVector &) = delete;
|
||||
RecyclingVector &operator=(const RecyclingVector &) = delete;
|
||||
|
||||
torch::Tensor At(int32_t index) const;
|
||||
|
||||
void PushBack(torch::Tensor item);
|
||||
|
||||
/// This method returns the size as if no "recycling" had happened,
|
||||
/// i.e. equivalent to the number of times the PushBack method has been
|
||||
/// called.
|
||||
int32_t Size() const;
|
||||
|
||||
private:
|
||||
std::deque<torch::Tensor> items_;
|
||||
int32_t items_to_hold_;
|
||||
int32_t first_available_index_;
|
||||
};
|
||||
|
||||
/// This is a templated class for online feature extraction;
|
||||
/// it's templated on a class like MfccComputer or PlpComputer
|
||||
/// that does the basic feature extraction.
|
||||
template <class C>
|
||||
class OnlineGenericBaseFeature : public OnlineFeatureInterface {
|
||||
public:
|
||||
// Constructor from options class
|
||||
explicit OnlineGenericBaseFeature(const typename C::Options &opts);
|
||||
|
||||
int32_t Dim() const override { return computer_.Dim(); }
|
||||
|
||||
float FrameShiftInSeconds() const override {
|
||||
return computer_.GetFrameOptions().frame_shift_ms / 1000.0f;
|
||||
}
|
||||
|
||||
int32_t NumFramesReady() const override { return features_.Size(); }
|
||||
|
||||
// Note: IsLastFrame() will only ever return true if you have called
|
||||
// InputFinished() (and this frame is the last frame).
|
||||
bool IsLastFrame(int32_t frame) const override {
|
||||
return input_finished_ && frame == NumFramesReady() - 1;
|
||||
}
|
||||
|
||||
torch::Tensor GetFrame(int32_t frame) override { return features_.At(frame); }
|
||||
|
||||
// This would be called from the application, when you get
|
||||
// more wave data. Note: the sampling_rate is only provided so
|
||||
// the code can assert that it matches the sampling rate
|
||||
// expected in the options.
|
||||
void AcceptWaveform(float sampling_rate,
|
||||
const torch::Tensor &waveform) override;
|
||||
|
||||
// InputFinished() tells the class you won't be providing any
|
||||
// more waveform. This will help flush out the last frame or two
|
||||
// of features, in the case where snip-edges == false; it also
|
||||
// affects the return value of IsLastFrame().
|
||||
void InputFinished() override;
|
||||
|
||||
private:
|
||||
// This function computes any additional feature frames that it is possible to
|
||||
// compute from 'waveform_remainder_', which at this point may contain more
|
||||
// than just a remainder-sized quantity (because AcceptWaveform() appends to
|
||||
// waveform_remainder_ before calling this function). It adds these feature
|
||||
// frames to features_, and shifts off any now-unneeded samples of input from
|
||||
// waveform_remainder_ while incrementing waveform_offset_ by the same amount.
|
||||
void ComputeFeatures();
|
||||
|
||||
C computer_; // class that does the MFCC or PLP or filterbank computation
|
||||
|
||||
FeatureWindowFunction window_function_;
|
||||
|
||||
// features_ is the Mfcc or Plp or Fbank features that we have already
|
||||
// computed.
|
||||
|
||||
RecyclingVector features_;
|
||||
|
||||
// True if the user has called "InputFinished()"
|
||||
bool input_finished_;
|
||||
|
||||
// waveform_offset_ is the number of samples of waveform that we have
|
||||
// already discarded, i.e. that were prior to 'waveform_remainder_'.
|
||||
int64_t waveform_offset_;
|
||||
|
||||
// waveform_remainder_ is a short piece of waveform that we may need to keep
|
||||
// after extracting all the whole frames we can (whatever length of feature
|
||||
// will be required for the next phase of computation).
|
||||
// It is a 1-D tensor
|
||||
torch::Tensor waveform_remainder_;
|
||||
};
|
||||
|
||||
using OnlineMfcc = OnlineGenericBaseFeature<Mfcc>;
|
||||
using OnlinePlp = OnlineGenericBaseFeature<Plp>;
|
||||
using OnlineFbank = OnlineGenericBaseFeature<Fbank>;
|
||||
|
||||
} // namespace kaldifeat
|
||||
|
||||
#endif // KALDIFEAT_CSRC_ONLINE_FEATURE_H_
|
@ -7,6 +7,7 @@ pybind11_add_module(_kaldifeat
|
||||
feature-window.cc
|
||||
kaldifeat.cc
|
||||
mel-computations.cc
|
||||
online-feature.cc
|
||||
utils.cc
|
||||
)
|
||||
target_link_libraries(_kaldifeat PRIVATE kaldifeat_core)
|
||||
|
@ -25,6 +25,7 @@ static void PybindFrameExtractionOptions(py::module &m) {
|
||||
.def_readwrite("round_to_power_of_two", &PyClass::round_to_power_of_two)
|
||||
.def_readwrite("blackman_coeff", &PyClass::blackman_coeff)
|
||||
.def_readwrite("snip_edges", &PyClass::snip_edges)
|
||||
.def_readwrite("max_feature_vectors", &PyClass::max_feature_vectors)
|
||||
.def("as_dict",
|
||||
[](const PyClass &self) -> py::dict { return AsDict(self); })
|
||||
.def_static("from_dict",
|
||||
@ -35,8 +36,6 @@ static void PybindFrameExtractionOptions(py::module &m) {
|
||||
.def_readwrite("allow_downsample",
|
||||
&PyClass::allow_downsample)
|
||||
.def_readwrite("allow_upsample", &PyClass::allow_upsample)
|
||||
.def_readwrite("max_feature_vectors",
|
||||
&PyClass::max_feature_vectors)
|
||||
#endif
|
||||
.def("__str__",
|
||||
[](const PyClass &self) -> std::string { return self.ToString(); })
|
||||
|
@ -11,6 +11,7 @@
|
||||
#include "kaldifeat/python/csrc/feature-spectrogram.h"
|
||||
#include "kaldifeat/python/csrc/feature-window.h"
|
||||
#include "kaldifeat/python/csrc/mel-computations.h"
|
||||
#include "kaldifeat/python/csrc/online-feature.h"
|
||||
#include "torch/torch.h"
|
||||
|
||||
namespace kaldifeat {
|
||||
@ -24,6 +25,7 @@ PYBIND11_MODULE(_kaldifeat, m) {
|
||||
PybindFeatureMfcc(m);
|
||||
PybindFeaturePlp(m);
|
||||
PybindFeatureSpectrogram(m);
|
||||
PybindOnlineFeature(m);
|
||||
}
|
||||
|
||||
} // namespace kaldifeat
|
||||
|
37
kaldifeat/python/csrc/online-feature.cc
Normal file
37
kaldifeat/python/csrc/online-feature.cc
Normal file
@ -0,0 +1,37 @@
|
||||
// kaldifeat/python/csrc/online-feature.cc
|
||||
//
|
||||
// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
|
||||
#include "kaldifeat/python/csrc/online-feature.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "kaldifeat/csrc/online-feature.h"
|
||||
namespace kaldifeat {
|
||||
|
||||
template <typename C>
|
||||
void PybindOnlineFeatureTpl(py::module &m, const std::string &class_name,
|
||||
const std::string &class_help_doc = "") {
|
||||
using PyClass = OnlineGenericBaseFeature<C>;
|
||||
using Options = typename C::Options;
|
||||
py::class_<PyClass>(m, class_name.c_str(), class_help_doc.c_str())
|
||||
.def(py::init<const Options &>(), py::arg("opts"))
|
||||
.def_property_readonly("dim", &PyClass::Dim)
|
||||
.def_property_readonly("frame_shift_in_seconds",
|
||||
&PyClass::FrameShiftInSeconds)
|
||||
.def_property_readonly("num_frames_ready", &PyClass::NumFramesReady)
|
||||
.def("is_last_frame", &PyClass::IsLastFrame, py::arg("frame"))
|
||||
.def("get_frame", &PyClass::GetFrame, py::arg("frame"))
|
||||
.def("get_frames", &PyClass::GetFrames, py::arg("frames"))
|
||||
.def("accept_waveform", &PyClass::AcceptWaveform,
|
||||
py::arg("sampling_rate"), py::arg("waveform"))
|
||||
.def("input_finished", &PyClass::InputFinished);
|
||||
}
|
||||
|
||||
void PybindOnlineFeature(py::module &m) {
|
||||
PybindOnlineFeatureTpl<Mfcc>(m, "OnlineMfcc");
|
||||
PybindOnlineFeatureTpl<Fbank>(m, "OnlineFbank");
|
||||
PybindOnlineFeatureTpl<Plp>(m, "OnlinePlp");
|
||||
}
|
||||
|
||||
} // namespace kaldifeat
|
16
kaldifeat/python/csrc/online-feature.h
Normal file
16
kaldifeat/python/csrc/online-feature.h
Normal file
@ -0,0 +1,16 @@
|
||||
// kaldifeat/python/csrc/online-feature.h
|
||||
//
|
||||
// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
|
||||
#ifndef KALDIFEAT_PYTHON_CSRC_ONLINE_FEATURE_H_
|
||||
#define KALDIFEAT_PYTHON_CSRC_ONLINE_FEATURE_H_
|
||||
|
||||
#include "kaldifeat/python/csrc/kaldifeat.h"
|
||||
|
||||
namespace kaldifeat {
|
||||
|
||||
void PybindOnlineFeature(py::module &m);
|
||||
|
||||
} // namespace kaldifeat
|
||||
|
||||
#endif // KALDIFEAT_PYTHON_CSRC_ONLINE_FEATURE_H_
|
@ -30,6 +30,7 @@ FrameExtractionOptions FrameExtractionOptionsFromDict(py::dict dict) {
|
||||
FROM_DICT(bool_, round_to_power_of_two);
|
||||
FROM_DICT(float_, blackman_coeff);
|
||||
FROM_DICT(bool_, snip_edges);
|
||||
FROM_DICT(int_, max_feature_vectors);
|
||||
|
||||
return opts;
|
||||
}
|
||||
@ -47,6 +48,7 @@ py::dict AsDict(const FrameExtractionOptions &opts) {
|
||||
AS_DICT(round_to_power_of_two);
|
||||
AS_DICT(blackman_coeff);
|
||||
AS_DICT(snip_edges);
|
||||
AS_DICT(max_feature_vectors);
|
||||
|
||||
return dict;
|
||||
}
|
||||
|
@ -8,7 +8,7 @@ from _kaldifeat import (
|
||||
SpectrogramOptions,
|
||||
)
|
||||
|
||||
from .fbank import Fbank
|
||||
from .mfcc import Mfcc
|
||||
from .plp import Plp
|
||||
from .fbank import Fbank, OnlineFbank
|
||||
from .mfcc import Mfcc, OnlineMfcc
|
||||
from .plp import OnlinePlp, Plp
|
||||
from .spectrogram import Spectrogram
|
||||
|
@ -4,9 +4,20 @@
|
||||
import _kaldifeat
|
||||
|
||||
from .offline_feature import OfflineFeature
|
||||
from .online_feature import OnlineFeature
|
||||
|
||||
|
||||
class Fbank(OfflineFeature):
|
||||
def __init__(self, opts: _kaldifeat.FbankOptions):
|
||||
super().__init__(opts)
|
||||
self.computer = _kaldifeat.Fbank(opts)
|
||||
|
||||
|
||||
class OnlineFbank(OnlineFeature):
|
||||
def __init__(self, opts: _kaldifeat.FbankOptions):
|
||||
super().__init__(opts)
|
||||
self.computer = _kaldifeat.OnlineFbank(opts)
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.opts = _kaldifeat.FbankOptions.from_dict(state)
|
||||
self.computer = _kaldifeat.OnlineFbank(self.opts)
|
||||
|
@ -4,9 +4,20 @@
|
||||
import _kaldifeat
|
||||
|
||||
from .offline_feature import OfflineFeature
|
||||
from .online_feature import OnlineFeature
|
||||
|
||||
|
||||
class Mfcc(OfflineFeature):
|
||||
def __init__(self, opts: _kaldifeat.MfccOptions):
|
||||
super().__init__(opts)
|
||||
self.computer = _kaldifeat.Mfcc(opts)
|
||||
|
||||
|
||||
class OnlineMfcc(OnlineFeature):
|
||||
def __init__(self, opts: _kaldifeat.MfccOptions):
|
||||
super().__init__(opts)
|
||||
self.computer = _kaldifeat.OnlineMfcc(opts)
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.opts = _kaldifeat.MfccOptions.from_dict(state)
|
||||
self.computer = _kaldifeat.OnlineMfcc(self.opts)
|
||||
|
95
kaldifeat/python/kaldifeat/online_feature.py
Normal file
95
kaldifeat/python/kaldifeat/online_feature.py
Normal file
@ -0,0 +1,95 @@
|
||||
# Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class OnlineFeature(object):
|
||||
"""Offline feature is a base class of other feature computers,
|
||||
e.g., Fbank, Mfcc.
|
||||
|
||||
This class has two fields:
|
||||
|
||||
(1) opts. It contains the options for the feature computer.
|
||||
(2) computer. The actual feature computer. It should be
|
||||
instantiated by subclasses.
|
||||
|
||||
Caution:
|
||||
It supports only CPU at present.
|
||||
"""
|
||||
|
||||
def __init__(self, opts):
|
||||
assert opts.device.type == "cpu"
|
||||
|
||||
self.opts = opts
|
||||
|
||||
# self.computer is expected to be set by subclasses
|
||||
self.computer = None
|
||||
|
||||
@property
|
||||
def num_frames_ready(self) -> int:
|
||||
"""Return the number of ready frames.
|
||||
|
||||
It can be updated by :method:`accept_waveform`.
|
||||
|
||||
Note:
|
||||
If you set ``opts.frame_opts.max_feature_vectors``, then
|
||||
the valid frame indexes are in the range.
|
||||
``[num_frames_ready - max_feature_vectors, num_frames_ready)``
|
||||
|
||||
If you leave ``opts.frame_opts.max_feature_vectors`` to its default
|
||||
value, then the range is ``[0, num_frames_ready)``
|
||||
"""
|
||||
return self.computer.num_frames_ready
|
||||
|
||||
def is_last_frame(self, frame: int) -> bool:
|
||||
"""Return True if the given frame is the last frame."""
|
||||
return self.computer.is_last_frame(frame)
|
||||
|
||||
def get_frame(self, frame: int) -> torch.Tensor:
|
||||
"""Get the frame by its index.
|
||||
Args:
|
||||
frame:
|
||||
The frame index. If ``opts.frame_opts.max_feature_vectors`` is
|
||||
-1, then its valid values are in the range
|
||||
``[0, num_frames_ready)``. Otherwise, the range is
|
||||
``[num_frames_ready - max_feature_vectors, num_frames_ready)``.
|
||||
Returns:
|
||||
Return a 2-D tensor with shape ``(1, feature_dim)``
|
||||
"""
|
||||
return self.computer.get_frame(frame)
|
||||
|
||||
def get_frames(self, frames: List[int]) -> List[torch.Tensor]:
|
||||
"""Get frames at the given frame indexes.
|
||||
Args:
|
||||
frames:
|
||||
Frames whose indexes are in this list are returned.
|
||||
Returns:
|
||||
Return a list of feature frames at the given indexes.
|
||||
"""
|
||||
return self.computer.get_frames(frames)
|
||||
|
||||
def accept_waveform(
|
||||
self, sampling_rate: float, waveform: torch.Tensor
|
||||
) -> None:
|
||||
"""Send audio samples to the extractor.
|
||||
Args:
|
||||
sampling_rate:
|
||||
The sampling rate of the given audio samples. It has to be equal
|
||||
to ``opts.frame_opts.samp_freq``.
|
||||
waveform:
|
||||
A 1-D tensor of shape (num_samples,). Its dtype is torch.float32
|
||||
and has to be on CPU.
|
||||
"""
|
||||
self.computer.accept_waveform(sampling_rate, waveform)
|
||||
|
||||
def input_finished(self) -> None:
|
||||
"""Tell the extractor that no more audio samples will be available.
|
||||
After calling this function, you cannot invoke ``accept_waveform``
|
||||
again.
|
||||
"""
|
||||
self.computer.input_finished()
|
||||
|
||||
def __getstate__(self):
|
||||
return self.opts.as_dict()
|
@ -4,9 +4,20 @@
|
||||
import _kaldifeat
|
||||
|
||||
from .offline_feature import OfflineFeature
|
||||
from .online_feature import OnlineFeature
|
||||
|
||||
|
||||
class Plp(OfflineFeature):
|
||||
def __init__(self, opts: _kaldifeat.PlpOptions):
|
||||
super().__init__(opts)
|
||||
self.computer = _kaldifeat.Plp(opts)
|
||||
|
||||
|
||||
class OnlinePlp(OnlineFeature):
|
||||
def __init__(self, opts: _kaldifeat.PlpOptions):
|
||||
super().__init__(opts)
|
||||
self.computer = _kaldifeat.OnlinePlp(opts)
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.opts = _kaldifeat.PlpOptions.from_dict(state)
|
||||
self.computer = _kaldifeat.OnlinePlp(self.opts)
|
||||
|
@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
# Copyright 2021-2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
@ -13,30 +13,88 @@ import kaldifeat
|
||||
cur_dir = Path(__file__).resolve().parent
|
||||
|
||||
|
||||
def test_online_fbank(
|
||||
opts: kaldifeat.FbankOptions,
|
||||
wave: torch.Tensor,
|
||||
cpu_features: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
opts:
|
||||
The options to create the online fbank extractor.
|
||||
wave:
|
||||
The input 1-D waveform.
|
||||
cpu_features:
|
||||
The groud truth features that are computed offline
|
||||
"""
|
||||
online_fbank = kaldifeat.OnlineFbank(opts)
|
||||
|
||||
num_processed_frames = 0
|
||||
i = 0 # current sample index to feed
|
||||
while not online_fbank.is_last_frame(num_processed_frames - 1):
|
||||
while num_processed_frames < online_fbank.num_frames_ready:
|
||||
# There are new frames to be processed
|
||||
frame = online_fbank.get_frame(num_processed_frames)
|
||||
assert torch.allclose(
|
||||
frame.squeeze(0), cpu_features[num_processed_frames]
|
||||
)
|
||||
num_processed_frames += 1
|
||||
|
||||
# Simulate streaming . Send a random number of audio samples
|
||||
# to the extractor
|
||||
num_samples = torch.randint(300, 1000, (1,)).item()
|
||||
|
||||
samples = wave[i : (i + num_samples)] # noqa
|
||||
i += num_samples
|
||||
if len(samples) == 0:
|
||||
online_fbank.input_finished()
|
||||
continue
|
||||
|
||||
online_fbank.accept_waveform(16000, samples)
|
||||
|
||||
assert num_processed_frames == online_fbank.num_frames_ready
|
||||
assert num_processed_frames == cpu_features.size(0)
|
||||
|
||||
|
||||
def test_fbank_default():
|
||||
print("=====test_fbank_default=====")
|
||||
filename = cur_dir / "test_data/test.wav"
|
||||
wave = read_wave(filename)
|
||||
gt = read_ark_txt(cur_dir / "test_data/test.txt")
|
||||
|
||||
cpu_features = None
|
||||
for device in get_devices():
|
||||
print("device", device)
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = device
|
||||
opts.frame_opts.dither = 0
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
filename = cur_dir / "test_data/test.wav"
|
||||
wave = read_wave(filename)
|
||||
|
||||
features = fbank(wave)
|
||||
assert features.device.type == "cpu"
|
||||
gt = read_ark_txt(cur_dir / "test_data/test.txt")
|
||||
assert torch.allclose(features, gt, rtol=1e-1)
|
||||
if cpu_features is None:
|
||||
cpu_features = features
|
||||
|
||||
wave = wave.to(device)
|
||||
features = fbank(wave)
|
||||
features = fbank(wave.to(device))
|
||||
assert features.device == device
|
||||
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
||||
|
||||
# Now for online fbank
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.max_feature_vectors = 100
|
||||
|
||||
test_online_fbank(opts, wave, cpu_features)
|
||||
|
||||
|
||||
def test_fbank_htk():
|
||||
print("=====test_fbank_htk=====")
|
||||
filename = cur_dir / "test_data/test.wav"
|
||||
wave = read_wave(filename)
|
||||
gt = read_ark_txt(cur_dir / "test_data/test-htk.txt")
|
||||
|
||||
cpu_features = None
|
||||
for device in get_devices():
|
||||
print("device", device)
|
||||
opts = kaldifeat.FbankOptions()
|
||||
@ -46,22 +104,32 @@ def test_fbank_htk():
|
||||
opts.htk_compat = True
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
filename = cur_dir / "test_data/test.wav"
|
||||
wave = read_wave(filename)
|
||||
|
||||
features = fbank(wave)
|
||||
assert features.device.type == "cpu"
|
||||
gt = read_ark_txt(cur_dir / "test_data/test-htk.txt")
|
||||
assert torch.allclose(features, gt, rtol=1e-1)
|
||||
if cpu_features is None:
|
||||
cpu_features = features
|
||||
|
||||
wave = wave.to(device)
|
||||
features = fbank(wave)
|
||||
features = fbank(wave.to(device))
|
||||
assert features.device == device
|
||||
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
||||
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.frame_opts.dither = 0
|
||||
opts.use_energy = True
|
||||
opts.htk_compat = True
|
||||
|
||||
test_online_fbank(opts, wave, cpu_features)
|
||||
|
||||
|
||||
def test_fbank_with_energy():
|
||||
print("=====test_fbank_with_energy=====")
|
||||
filename = cur_dir / "test_data/test.wav"
|
||||
wave = read_wave(filename)
|
||||
gt = read_ark_txt(cur_dir / "test_data/test-with-energy.txt")
|
||||
|
||||
cpu_features = None
|
||||
for device in get_devices():
|
||||
print("device", device)
|
||||
opts = kaldifeat.FbankOptions()
|
||||
@ -70,22 +138,31 @@ def test_fbank_with_energy():
|
||||
opts.use_energy = True
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
filename = cur_dir / "test_data/test.wav"
|
||||
wave = read_wave(filename)
|
||||
|
||||
features = fbank(wave)
|
||||
gt = read_ark_txt(cur_dir / "test_data/test-with-energy.txt")
|
||||
assert torch.allclose(features, gt, rtol=1e-1)
|
||||
assert features.device.type == "cpu"
|
||||
if cpu_features is None:
|
||||
cpu_features = features
|
||||
|
||||
wave = wave.to(device)
|
||||
features = fbank(wave)
|
||||
features = fbank(wave.to(device))
|
||||
assert features.device == device
|
||||
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
||||
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.frame_opts.dither = 0
|
||||
opts.use_energy = True
|
||||
|
||||
test_online_fbank(opts, wave, cpu_features)
|
||||
|
||||
|
||||
def test_fbank_40_bins():
|
||||
print("=====test_fbank_40_bins=====")
|
||||
filename = cur_dir / "test_data/test.wav"
|
||||
wave = read_wave(filename)
|
||||
gt = read_ark_txt(cur_dir / "test_data/test-40.txt")
|
||||
|
||||
cpu_features = None
|
||||
for device in get_devices():
|
||||
print("device", device)
|
||||
opts = kaldifeat.FbankOptions()
|
||||
@ -94,22 +171,31 @@ def test_fbank_40_bins():
|
||||
opts.mel_opts.num_bins = 40
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
filename = cur_dir / "test_data/test.wav"
|
||||
wave = read_wave(filename)
|
||||
|
||||
features = fbank(wave)
|
||||
assert features.device.type == "cpu"
|
||||
gt = read_ark_txt(cur_dir / "test_data/test-40.txt")
|
||||
assert torch.allclose(features, gt, rtol=1e-1)
|
||||
if cpu_features is None:
|
||||
cpu_features = features
|
||||
|
||||
wave = wave.to(device)
|
||||
features = fbank(wave)
|
||||
features = fbank(wave.to(device))
|
||||
assert features.device == device
|
||||
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
||||
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.frame_opts.dither = 0
|
||||
opts.mel_opts.num_bins = 40
|
||||
|
||||
test_online_fbank(opts, wave, cpu_features)
|
||||
|
||||
|
||||
def test_fbank_40_bins_no_snip_edges():
|
||||
print("=====test_fbank_40_bins_no_snip_edges=====")
|
||||
filename = cur_dir / "test_data/test.wav"
|
||||
wave = read_wave(filename)
|
||||
gt = read_ark_txt(cur_dir / "test_data/test-40-no-snip-edges.txt")
|
||||
|
||||
cpu_features = None
|
||||
for device in get_devices():
|
||||
print("device", device)
|
||||
opts = kaldifeat.FbankOptions()
|
||||
@ -119,19 +205,24 @@ def test_fbank_40_bins_no_snip_edges():
|
||||
opts.frame_opts.snip_edges = False
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
filename = cur_dir / "test_data/test.wav"
|
||||
wave = read_wave(filename)
|
||||
|
||||
features = fbank(wave)
|
||||
assert features.device.type == "cpu"
|
||||
gt = read_ark_txt(cur_dir / "test_data/test-40-no-snip-edges.txt")
|
||||
assert torch.allclose(features, gt, rtol=1e-1)
|
||||
if cpu_features is None:
|
||||
cpu_features = features
|
||||
|
||||
wave = wave.to(device)
|
||||
features = fbank(wave)
|
||||
features = fbank(wave.to(device))
|
||||
assert features.device == device
|
||||
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
||||
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.frame_opts.dither = 0
|
||||
opts.mel_opts.num_bins = 40
|
||||
opts.frame_opts.snip_edges = False
|
||||
|
||||
test_online_fbank(opts, wave, cpu_features)
|
||||
|
||||
|
||||
def test_fbank_chunk():
|
||||
print("=====test_fbank_chunk=====")
|
||||
@ -223,6 +314,16 @@ def test_pickle():
|
||||
|
||||
assert str(fbank.opts) == str(fbank2.opts)
|
||||
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.use_energy = True
|
||||
opts.use_power = False
|
||||
|
||||
fbank = kaldifeat.OnlineFbank(opts)
|
||||
data = pickle.dumps(fbank)
|
||||
fbank2 = pickle.loads(data)
|
||||
|
||||
assert str(fbank.opts) == str(fbank2.opts)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_fbank_default()
|
||||
|
@ -13,24 +13,82 @@ import kaldifeat
|
||||
cur_dir = Path(__file__).resolve().parent
|
||||
|
||||
|
||||
def test_online_mfcc(
|
||||
opts: kaldifeat.MfccOptions,
|
||||
wave: torch.Tensor,
|
||||
cpu_features: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
opts:
|
||||
The options to create the online mfcc extractor.
|
||||
wave:
|
||||
The input 1-D waveform.
|
||||
cpu_features:
|
||||
The groud truth features that are computed offline
|
||||
"""
|
||||
online_mfcc = kaldifeat.OnlineMfcc(opts)
|
||||
|
||||
num_processed_frames = 0
|
||||
i = 0 # current sample index to feed
|
||||
while not online_mfcc.is_last_frame(num_processed_frames - 1):
|
||||
while num_processed_frames < online_mfcc.num_frames_ready:
|
||||
# There are new frames to be processed
|
||||
frame = online_mfcc.get_frame(num_processed_frames)
|
||||
assert torch.allclose(
|
||||
frame.squeeze(0), cpu_features[num_processed_frames], atol=1e-3
|
||||
)
|
||||
num_processed_frames += 1
|
||||
|
||||
# Simulate streaming . Send a random number of audio samples
|
||||
# to the extractor
|
||||
num_samples = torch.randint(300, 1000, (1,)).item()
|
||||
|
||||
samples = wave[i : (i + num_samples)] # noqa
|
||||
i += num_samples
|
||||
if len(samples) == 0:
|
||||
online_mfcc.input_finished()
|
||||
continue
|
||||
|
||||
online_mfcc.accept_waveform(16000, samples)
|
||||
|
||||
assert num_processed_frames == online_mfcc.num_frames_ready
|
||||
assert num_processed_frames == cpu_features.size(0)
|
||||
|
||||
|
||||
def test_mfcc_default():
|
||||
print("=====test_mfcc_default=====")
|
||||
filename = cur_dir / "test_data/test.wav"
|
||||
wave = read_wave(filename)
|
||||
gt = read_ark_txt(cur_dir / "test_data/test-mfcc.txt")
|
||||
|
||||
cpu_features = None
|
||||
for device in get_devices():
|
||||
print("device", device)
|
||||
opts = kaldifeat.MfccOptions()
|
||||
opts.device = device
|
||||
opts.frame_opts.dither = 0
|
||||
mfcc = kaldifeat.Mfcc(opts)
|
||||
filename = cur_dir / "test_data/test.wav"
|
||||
wave = read_wave(filename).to(device)
|
||||
|
||||
features = mfcc(wave)
|
||||
gt = read_ark_txt(cur_dir / "test_data/test-mfcc.txt")
|
||||
features = mfcc(wave.to(device))
|
||||
if device.type == "cpu":
|
||||
cpu_features = features
|
||||
|
||||
assert torch.allclose(features.cpu(), gt, atol=1e-1)
|
||||
|
||||
opts = kaldifeat.MfccOptions()
|
||||
opts.frame_opts.dither = 0
|
||||
|
||||
test_online_mfcc(opts, wave, cpu_features)
|
||||
|
||||
|
||||
def test_mfcc_no_snip_edges():
|
||||
print("=====test_mfcc_no_snip_edges=====")
|
||||
filename = cur_dir / "test_data/test.wav"
|
||||
wave = read_wave(filename)
|
||||
gt = read_ark_txt(cur_dir / "test_data/test-mfcc-no-snip-edges.txt")
|
||||
|
||||
cpu_features = None
|
||||
for device in get_devices():
|
||||
print("device", device)
|
||||
opts = kaldifeat.MfccOptions()
|
||||
@ -39,13 +97,19 @@ def test_mfcc_no_snip_edges():
|
||||
opts.frame_opts.snip_edges = False
|
||||
|
||||
mfcc = kaldifeat.Mfcc(opts)
|
||||
filename = cur_dir / "test_data/test.wav"
|
||||
wave = read_wave(filename).to(device)
|
||||
|
||||
features = mfcc(wave)
|
||||
gt = read_ark_txt(cur_dir / "test_data/test-mfcc-no-snip-edges.txt")
|
||||
features = mfcc(wave.to(device))
|
||||
if device.type == "cpu":
|
||||
cpu_features = features
|
||||
|
||||
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
||||
|
||||
opts = kaldifeat.MfccOptions()
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
|
||||
test_online_mfcc(opts, wave, cpu_features)
|
||||
|
||||
|
||||
def test_pickle():
|
||||
for device in get_devices():
|
||||
@ -60,6 +124,16 @@ def test_pickle():
|
||||
|
||||
assert str(mfcc.opts) == str(mfcc2.opts)
|
||||
|
||||
opts = kaldifeat.MfccOptions()
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
|
||||
mfcc = kaldifeat.OnlineMfcc(opts)
|
||||
data = pickle.dumps(mfcc)
|
||||
mfcc2 = pickle.loads(data)
|
||||
|
||||
assert str(mfcc.opts) == str(mfcc2.opts)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_mfcc_default()
|
||||
|
@ -13,24 +13,82 @@ import kaldifeat
|
||||
cur_dir = Path(__file__).resolve().parent
|
||||
|
||||
|
||||
def test_online_plp(
|
||||
opts: kaldifeat.PlpOptions,
|
||||
wave: torch.Tensor,
|
||||
cpu_features: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
opts:
|
||||
The options to create the online plp extractor.
|
||||
wave:
|
||||
The input 1-D waveform.
|
||||
cpu_features:
|
||||
The groud truth features that are computed offline
|
||||
"""
|
||||
online_plp = kaldifeat.OnlinePlp(opts)
|
||||
|
||||
num_processed_frames = 0
|
||||
i = 0 # current sample index to feed
|
||||
while not online_plp.is_last_frame(num_processed_frames - 1):
|
||||
while num_processed_frames < online_plp.num_frames_ready:
|
||||
# There are new frames to be processed
|
||||
frame = online_plp.get_frame(num_processed_frames)
|
||||
assert torch.allclose(
|
||||
frame.squeeze(0), cpu_features[num_processed_frames], atol=1e-3
|
||||
)
|
||||
num_processed_frames += 1
|
||||
|
||||
# Simulate streaming . Send a random number of audio samples
|
||||
# to the extractor
|
||||
num_samples = torch.randint(300, 1000, (1,)).item()
|
||||
|
||||
samples = wave[i : (i + num_samples)] # noqa
|
||||
i += num_samples
|
||||
if len(samples) == 0:
|
||||
online_plp.input_finished()
|
||||
continue
|
||||
|
||||
online_plp.accept_waveform(16000, samples)
|
||||
|
||||
assert num_processed_frames == online_plp.num_frames_ready
|
||||
assert num_processed_frames == cpu_features.size(0)
|
||||
|
||||
|
||||
def test_plp_default():
|
||||
print("=====test_plp_default=====")
|
||||
filename = cur_dir / "test_data/test.wav"
|
||||
wave = read_wave(filename)
|
||||
gt = read_ark_txt(cur_dir / "test_data/test-plp.txt")
|
||||
|
||||
cpu_features = None
|
||||
for device in get_devices():
|
||||
print("device", device)
|
||||
opts = kaldifeat.PlpOptions()
|
||||
opts.frame_opts.dither = 0
|
||||
opts.device = device
|
||||
plp = kaldifeat.Plp(opts)
|
||||
filename = cur_dir / "test_data/test.wav"
|
||||
wave = read_wave(filename).to(device)
|
||||
|
||||
features = plp(wave)
|
||||
gt = read_ark_txt(cur_dir / "test_data/test-plp.txt")
|
||||
features = plp(wave.to(device))
|
||||
if device.type == "cpu":
|
||||
cpu_features = features
|
||||
|
||||
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
||||
|
||||
opts = kaldifeat.PlpOptions()
|
||||
opts.frame_opts.dither = 0
|
||||
|
||||
test_online_plp(opts, wave, cpu_features)
|
||||
|
||||
|
||||
def test_plp_no_snip_edges():
|
||||
print("=====test_plp_no_snip_edges=====")
|
||||
filename = cur_dir / "test_data/test.wav"
|
||||
wave = read_wave(filename)
|
||||
gt = read_ark_txt(cur_dir / "test_data/test-plp-no-snip-edges.txt")
|
||||
|
||||
cpu_features = None
|
||||
for device in get_devices():
|
||||
print("device", device)
|
||||
opts = kaldifeat.PlpOptions()
|
||||
@ -39,16 +97,26 @@ def test_plp_no_snip_edges():
|
||||
opts.frame_opts.snip_edges = False
|
||||
|
||||
plp = kaldifeat.Plp(opts)
|
||||
filename = cur_dir / "test_data/test.wav"
|
||||
wave = read_wave(filename).to(device)
|
||||
|
||||
features = plp(wave)
|
||||
gt = read_ark_txt(cur_dir / "test_data/test-plp-no-snip-edges.txt")
|
||||
features = plp(wave.to(device))
|
||||
if device.type == "cpu":
|
||||
cpu_features = features
|
||||
assert torch.allclose(features.cpu(), gt, atol=1e-1)
|
||||
|
||||
opts = kaldifeat.PlpOptions()
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
|
||||
test_online_plp(opts, wave, cpu_features)
|
||||
|
||||
|
||||
def test_plp_htk_10_ceps():
|
||||
print("=====test_plp_htk_10_ceps=====")
|
||||
filename = cur_dir / "test_data/test.wav"
|
||||
wave = read_wave(filename)
|
||||
gt = read_ark_txt(cur_dir / "test_data/test-plp-htk-10-ceps.txt")
|
||||
|
||||
cpu_features = None
|
||||
for device in get_devices():
|
||||
print("device", device)
|
||||
opts = kaldifeat.PlpOptions()
|
||||
@ -58,13 +126,19 @@ def test_plp_htk_10_ceps():
|
||||
opts.frame_opts.dither = 0
|
||||
|
||||
plp = kaldifeat.Plp(opts)
|
||||
filename = cur_dir / "test_data/test.wav"
|
||||
wave = read_wave(filename).to(device)
|
||||
|
||||
features = plp(wave)
|
||||
gt = read_ark_txt(cur_dir / "test_data/test-plp-htk-10-ceps.txt")
|
||||
features = plp(wave.to(device))
|
||||
if device.type == "cpu":
|
||||
cpu_features = features
|
||||
assert torch.allclose(features.cpu(), gt, atol=1e-1)
|
||||
|
||||
opts = kaldifeat.PlpOptions()
|
||||
opts.htk_compat = True
|
||||
opts.num_ceps = 10
|
||||
opts.frame_opts.dither = 0
|
||||
|
||||
test_online_plp(opts, wave, cpu_features)
|
||||
|
||||
|
||||
def test_pickle():
|
||||
for device in get_devices():
|
||||
@ -79,6 +153,16 @@ def test_pickle():
|
||||
|
||||
assert str(plp.opts) == str(plp2.opts)
|
||||
|
||||
opts = kaldifeat.PlpOptions()
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
|
||||
plp = kaldifeat.OnlinePlp(opts)
|
||||
data = pickle.dumps(plp)
|
||||
plp2 = pickle.loads(data)
|
||||
|
||||
assert str(plp.opts) == str(plp2.opts)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_plp_default()
|
||||
|
Loading…
x
Reference in New Issue
Block a user