mirror of
https://github.com/csukuangfj/kaldifeat.git
synced 2025-08-10 18:42:17 +00:00
Merge pull request #28 from csukuangfj/streaming-feature-extractor
Start to add streaming feature extractors.
This commit is contained in:
commit
b72fc599fd
4
.github/workflows/style_check.yml
vendored
4
.github/workflows/style_check.yml
vendored
@ -45,7 +45,9 @@ jobs:
|
|||||||
|
|
||||||
- name: Install Python dependencies
|
- name: Install Python dependencies
|
||||||
run: |
|
run: |
|
||||||
python3 -m pip install --upgrade pip black flake8
|
python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2 click==8.0.4
|
||||||
|
# See https://github.com/psf/black/issues/2964
|
||||||
|
# The version of click should be selected from 8.0.0, 8.0.1, 8.0.2, 8.0.3, and 8.0.4
|
||||||
|
|
||||||
- name: Run flake8
|
- name: Run flake8
|
||||||
shell: bash
|
shell: bash
|
||||||
|
35
README.md
35
README.md
@ -31,6 +31,17 @@ features = fbank(wave)
|
|||||||
</td>
|
</td>
|
||||||
</tr>
|
</tr>
|
||||||
|
|
||||||
|
<tr>
|
||||||
|
<td>Streaming FBANK</td>
|
||||||
|
<td><code>kaldifeat.FbankOptions</code></td>
|
||||||
|
<td><code>kaldifeat.OnlineFbank</code></td>
|
||||||
|
<td>
|
||||||
|
See <a href="./kaldifeat/python/tests/test_fbank.py">
|
||||||
|
./kaldifeat/python/tests/test_fbank.py
|
||||||
|
</a>
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
|
||||||
<tr>
|
<tr>
|
||||||
<td>MFCC</td>
|
<td>MFCC</td>
|
||||||
<td><code>kaldifeat.MfccOptions</code></td>
|
<td><code>kaldifeat.MfccOptions</code></td>
|
||||||
@ -45,6 +56,17 @@ features = mfcc(wave)
|
|||||||
</td>
|
</td>
|
||||||
</tr>
|
</tr>
|
||||||
|
|
||||||
|
<tr>
|
||||||
|
<td>Streaming MFCC</td>
|
||||||
|
<td><code>kaldifeat.MfccOptions</code></td>
|
||||||
|
<td><code>kaldifeat.OnlineMfcc</code></td>
|
||||||
|
<td>
|
||||||
|
See <a href="./kaldifeat/python/tests/test_mfcc.py">
|
||||||
|
./kaldifeat/python/tests/test_mfcc.py
|
||||||
|
</a>
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
|
||||||
<tr>
|
<tr>
|
||||||
<td>PLP</td>
|
<td>PLP</td>
|
||||||
<td><code>kaldifeat.PlpOptions</code></td>
|
<td><code>kaldifeat.PlpOptions</code></td>
|
||||||
@ -59,6 +81,17 @@ features = plp(wave)
|
|||||||
</td>
|
</td>
|
||||||
</tr>
|
</tr>
|
||||||
|
|
||||||
|
<tr>
|
||||||
|
<td>Streaming PLP</td>
|
||||||
|
<td><code>kaldifeat.PlpOptions</code></td>
|
||||||
|
<td><code>kaldifeat.OnlinePlp</code></td>
|
||||||
|
<td>
|
||||||
|
See <a href="./kaldifeat/python/tests/test_plp.py">
|
||||||
|
./kaldifeat/python/tests/test_plp.py
|
||||||
|
</a>
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
|
||||||
<tr>
|
<tr>
|
||||||
<td>Spectrogram</td>
|
<td>Spectrogram</td>
|
||||||
<td><code>kaldifeat.SpectrogramOptions</code></td>
|
<td><code>kaldifeat.SpectrogramOptions</code></td>
|
||||||
@ -88,6 +121,8 @@ The following kaldi-compatible commandline tools are implemented:
|
|||||||
|
|
||||||
(**NOTE**: We will implement other types of features, e.g., Pitch, ivector, etc, soon.)
|
(**NOTE**: We will implement other types of features, e.g., Pitch, ivector, etc, soon.)
|
||||||
|
|
||||||
|
**HINT**: It also supports streaming feature extractors for Fbank, MFCC, and PLP.
|
||||||
|
|
||||||
# Usage
|
# Usage
|
||||||
|
|
||||||
Let us first generate a test wave using sox:
|
Let us first generate a test wave using sox:
|
||||||
|
@ -9,6 +9,7 @@ set(kaldifeat_srcs
|
|||||||
feature-window.cc
|
feature-window.cc
|
||||||
matrix-functions.cc
|
matrix-functions.cc
|
||||||
mel-computations.cc
|
mel-computations.cc
|
||||||
|
online-feature.cc
|
||||||
)
|
)
|
||||||
|
|
||||||
add_library(kaldifeat_core SHARED ${kaldifeat_srcs})
|
add_library(kaldifeat_core SHARED ${kaldifeat_srcs})
|
||||||
@ -40,6 +41,7 @@ if(kaldifeat_BUILD_TESTS)
|
|||||||
# please sort the source files alphabetically
|
# please sort the source files alphabetically
|
||||||
set(test_srcs
|
set(test_srcs
|
||||||
feature-window-test.cc
|
feature-window-test.cc
|
||||||
|
online-feature-test.cc
|
||||||
)
|
)
|
||||||
|
|
||||||
foreach(source IN LISTS test_srcs)
|
foreach(source IN LISTS test_srcs)
|
||||||
|
@ -62,6 +62,10 @@ class OfflineFeatureTpl {
|
|||||||
int32_t Dim() const { return computer_.Dim(); }
|
int32_t Dim() const { return computer_.Dim(); }
|
||||||
const Options &GetOptions() const { return computer_.GetOptions(); }
|
const Options &GetOptions() const { return computer_.GetOptions(); }
|
||||||
|
|
||||||
|
// Shortcut to the frame-extraction settings embedded in the options
// returned by GetOptions().
const FrameExtractionOptions &GetFrameOptions() const {
  const Options &options = GetOptions();
  return options.frame_opts;
}
|
||||||
|
|
||||||
// Copy constructor.
|
// Copy constructor.
|
||||||
OfflineFeatureTpl(const OfflineFeatureTpl<F> &) = delete;
|
OfflineFeatureTpl(const OfflineFeatureTpl<F> &) = delete;
|
||||||
OfflineFeatureTpl<F> &operator=(const OfflineFeatureTpl<F> &) = delete;
|
OfflineFeatureTpl<F> &operator=(const OfflineFeatureTpl<F> &) = delete;
|
||||||
|
@ -161,19 +161,20 @@ torch::Tensor Dither(const torch::Tensor &wave, float dither_value) {
|
|||||||
#if 1
|
#if 1
|
||||||
return wave + rand_gauss * dither_value;
|
return wave + rand_gauss * dither_value;
|
||||||
#else
|
#else
|
||||||
// use in-place version of wave and change its to pointer type
|
// use in-place version of wave and change it to pointer type
|
||||||
wave_->add_(rand_gauss, dither_value);
|
wave_->add_(rand_gauss, dither_value);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::Tensor Preemphasize(float preemph_coeff, const torch::Tensor &wave) {
|
torch::Tensor Preemphasize(float preemph_coeff, const torch::Tensor &wave) {
|
||||||
using namespace torch::indexing; // It imports: Slice, None // NOLINT
|
|
||||||
if (preemph_coeff == 0.0f) return wave;
|
if (preemph_coeff == 0.0f) return wave;
|
||||||
|
|
||||||
KALDIFEAT_ASSERT(preemph_coeff >= 0.0f && preemph_coeff <= 1.0f);
|
KALDIFEAT_ASSERT(preemph_coeff >= 0.0f && preemph_coeff <= 1.0f);
|
||||||
|
|
||||||
torch::Tensor ans = torch::empty_like(wave);
|
torch::Tensor ans = torch::empty_like(wave);
|
||||||
|
|
||||||
|
using torch::indexing::None;
|
||||||
|
using torch::indexing::Slice;
|
||||||
// right = wave[:, 1:]
|
// right = wave[:, 1:]
|
||||||
torch::Tensor right = wave.index({"...", Slice(1, None, None)});
|
torch::Tensor right = wave.index({"...", Slice(1, None, None)});
|
||||||
|
|
||||||
@ -188,4 +189,59 @@ torch::Tensor Preemphasize(float preemph_coeff, const torch::Tensor &wave) {
|
|||||||
return ans;
|
return ans;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns the `opts.WindowSize()` raw samples that make up frame `f`.
//
// `wave` holds the samples that are still available; `sample_offset` is the
// number of samples that preceded `wave` and are no longer available.
// When the frame lies completely inside `wave`, the result is a slice of
// `wave`. Otherwise a new float tensor is filled sample by sample, mirroring
// indexes that fall before the start or after the end of `wave`.
//
// NOTE(review): the mirroring path reads `wave` through a
// `accessor<float, 1>()`, so it presumably expects a 1-D float CPU tensor --
// confirm against callers.
torch::Tensor ExtractWindow(int64_t sample_offset, const torch::Tensor &wave,
                            int32_t f, const FrameExtractionOptions &opts) {
  KALDIFEAT_ASSERT(sample_offset >= 0 && wave.numel() != 0);

  const int32_t frame_length = opts.WindowSize();
  const int64_t num_samples = sample_offset + wave.numel();
  const int64_t start_sample = FirstSampleOfFrame(f, opts);
  const int64_t end_sample = start_sample + frame_length;

  if (!opts.snip_edges) {
    KALDIFEAT_ASSERT(sample_offset == 0 || start_sample >= sample_offset);
  } else {
    KALDIFEAT_ASSERT(start_sample >= sample_offset &&
                     end_sample <= num_samples);
  }

  // Position of the frame expressed as indexes into `wave`.
  const int32_t wave_start =
      static_cast<int32_t>(start_sample - sample_offset);
  const int32_t wave_end = wave_start + frame_length;

  const bool fully_inside = wave_start >= 0 && wave_end <= wave.numel();
  if (fully_inside) {
    // Common case: equivalent to wave[wave_start:wave_end].
    return wave.index({torch::indexing::Slice(wave_start, wave_end)});
  }

  // Edge case: part of the frame lies outside `wave`, so build the window
  // one sample at a time and mirror the out-of-range positions.
  torch::Tensor window = torch::empty({frame_length}, torch::kFloat);
  auto dst = window.accessor<float, 1>();
  auto src = wave.accessor<float, 1>();

  const int32_t wave_dim = wave.numel();
  for (int32_t s = 0; s < frame_length; ++s) {
    int32_t pos = wave_start + s;
    // Mirror around the first or the last sample until the position is
    // valid: -1 -> 0, -2 -> 1, dim -> dim - 1, dim + 1 -> dim - 2.
    // The loop allows repeated mirroring for very short waves.
    while (pos < 0 || pos >= wave_dim) {
      pos = (pos < 0) ? (-pos - 1) : (2 * wave_dim - 1 - pos);
    }
    dst[s] = src[pos];
  }
  return window;
}
|
||||||
|
|
||||||
} // namespace kaldifeat
|
} // namespace kaldifeat
|
||||||
|
@ -44,7 +44,11 @@ struct FrameExtractionOptions {
|
|||||||
bool snip_edges = true;
|
bool snip_edges = true;
|
||||||
// bool allow_downsample = false;
|
// bool allow_downsample = false;
|
||||||
// bool allow_upsample = false;
|
// bool allow_upsample = false;
|
||||||
// int32_t max_feature_vectors = -1;
|
|
||||||
|
// Used for streaming feature extraction. It indicates the number
|
||||||
|
// of feature frames to keep in the recycling vector. -1 means to
|
||||||
|
// keep all feature frames.
|
||||||
|
int32_t max_feature_vectors = -1;
|
||||||
|
|
||||||
int32_t WindowShift() const {
|
int32_t WindowShift() const {
|
||||||
return static_cast<int32_t>(samp_freq * 0.001f * frame_shift_ms);
|
return static_cast<int32_t>(samp_freq * 0.001f * frame_shift_ms);
|
||||||
@ -71,7 +75,7 @@ struct FrameExtractionOptions {
|
|||||||
KALDIFEAT_PRINT(snip_edges);
|
KALDIFEAT_PRINT(snip_edges);
|
||||||
// KALDIFEAT_PRINT(allow_downsample);
|
// KALDIFEAT_PRINT(allow_downsample);
|
||||||
// KALDIFEAT_PRINT(allow_upsample);
|
// KALDIFEAT_PRINT(allow_upsample);
|
||||||
// KALDIFEAT_PRINT(max_feature_vectors);
|
KALDIFEAT_PRINT(max_feature_vectors);
|
||||||
#undef KALDIFEAT_PRINT
|
#undef KALDIFEAT_PRINT
|
||||||
return os.str();
|
return os.str();
|
||||||
}
|
}
|
||||||
@ -100,11 +104,11 @@ class FeatureWindowFunction {
|
|||||||
|
|
||||||
@param [in] flush True if we are asserting that this number of samples
|
@param [in] flush True if we are asserting that this number of samples
|
||||||
is 'all there is', false if we expecting more data to possibly come in. This
|
is 'all there is', false if we are expecting more data to possibly come in. This
|
||||||
only makes a difference to the answer if opts.snips_edges
|
only makes a difference to the answer
|
||||||
== false. For offline feature extraction you always want flush ==
|
if opts.snip_edges == false. For offline feature extraction you always want
|
||||||
true. In an online-decoding context, once you know (or decide)
|
flush == true. In an online-decoding context, once you know (or decide) that
|
||||||
that no more data is coming in, you'd call it with flush == true at the end
|
no more data is coming in, you'd call it with flush == true at the end to
|
||||||
to flush out any remaining data.
|
flush out any remaining data.
|
||||||
*/
|
*/
|
||||||
int32_t NumFrames(int64_t num_samples, const FrameExtractionOptions &opts,
|
int32_t NumFrames(int64_t num_samples, const FrameExtractionOptions &opts,
|
||||||
bool flush = true);
|
bool flush = true);
|
||||||
@ -133,6 +137,29 @@ torch::Tensor Dither(const torch::Tensor &wave, float dither_value);
|
|||||||
|
|
||||||
torch::Tensor Preemphasize(float preemph_coeff, const torch::Tensor &wave);
|
torch::Tensor Preemphasize(float preemph_coeff, const torch::Tensor &wave);
|
||||||
|
|
||||||
|
/*
|
||||||
|
ExtractWindow() extracts "frame_length" samples from the given waveform.
|
||||||
|
Note: This function only extracts "frame_length" samples
|
||||||
|
from the input waveform, without any further processing.
|
||||||
|
|
||||||
|
@param [in] sample_offset If 'wave' is not the entire waveform, but
|
||||||
|
part of it to the left has been discarded, then the
|
||||||
|
number of samples prior to 'wave' that we have
|
||||||
|
already discarded. Set this to zero if you are
|
||||||
|
processing the entire waveform in one piece, or
|
||||||
|
if you get 'no matching function' compilation
|
||||||
|
errors when updating the code.
|
||||||
|
@param [in] wave The waveform
|
||||||
|
@param [in] f The frame index to be extracted, with
|
||||||
|
0 <= f < NumFrames(sample_offset + wave.numel(), opts, true)
|
||||||
|
@param [in] opts The options class to be used
|
||||||
|
@return Return a tensor containing "frame_length" samples extracted from
|
||||||
|
`wave`, without any further processing. Its shape is
|
||||||
|
(frame_length,), i.e., it is a 1-D tensor.
|
||||||
|
*/
|
||||||
|
torch::Tensor ExtractWindow(int64_t sample_offset, const torch::Tensor &wave,
|
||||||
|
int32_t f, const FrameExtractionOptions &opts);
|
||||||
|
|
||||||
} // namespace kaldifeat
|
} // namespace kaldifeat
|
||||||
|
|
||||||
#endif // KALDIFEAT_CSRC_FEATURE_WINDOW_H_
|
#endif // KALDIFEAT_CSRC_FEATURE_WINDOW_H_
|
||||||
|
89
kaldifeat/csrc/online-feature-itf.h
Normal file
89
kaldifeat/csrc/online-feature-itf.h
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
// kaldifeat/csrc/online-feature-itf.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
// This file is copied/modified from kaldi/src/itf/online-feature-itf.h
|
||||||
|
|
||||||
|
#ifndef KALDIFEAT_CSRC_ONLINE_FEATURE_ITF_H_
|
||||||
|
#define KALDIFEAT_CSRC_ONLINE_FEATURE_ITF_H_
|
||||||
|
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "torch/script.h"
|
||||||
|
|
||||||
|
namespace kaldifeat {
|
||||||
|
|
||||||
|
class OnlineFeatureInterface {
|
||||||
|
public:
|
||||||
|
virtual ~OnlineFeatureInterface() = default;
|
||||||
|
|
||||||
|
virtual int32_t Dim() const = 0; /// returns the feature dimension.
|
||||||
|
//
|
||||||
|
// Returns frame shift in seconds. Helps to estimate duration from frame
|
||||||
|
// counts.
|
||||||
|
virtual float FrameShiftInSeconds() const = 0;
|
||||||
|
|
||||||
|
/// Returns the total number of frames, since the start of the utterance, that
|
||||||
|
/// are now available. In an online-decoding context, this will likely
|
||||||
|
/// increase with time as more data becomes available.
|
||||||
|
virtual int32_t NumFramesReady() const = 0;
|
||||||
|
|
||||||
|
/// Returns true if this is the last frame. Frame indices are zero-based, so
|
||||||
|
/// the first frame is zero. IsLastFrame(-1) will return false, unless the
|
||||||
|
/// file is empty (which is a case that I'm not sure all the code will handle,
|
||||||
|
/// so be careful). This function may return false for some frame if we
|
||||||
|
/// haven't yet decided to terminate decoding, but later true if we decide to
|
||||||
|
/// terminate decoding. This function exists mainly to correctly handle end
|
||||||
|
/// effects in feature extraction, and is not a mechanism to determine how
|
||||||
|
/// many frames are in the decodable object (as it used to be, and for
|
||||||
|
/// backward compatibility, still is, in the Decodable interface).
|
||||||
|
virtual bool IsLastFrame(int32_t frame) const = 0;
|
||||||
|
|
||||||
|
/// Gets the feature vector for this frame. Before calling this for a given
|
||||||
|
/// frame, it is assumed that you called NumFramesReady() and it returned a
|
||||||
|
/// number greater than "frame". Otherwise this call will likely crash with
|
||||||
|
/// an assert failure. This function is not declared const, in case there is
|
||||||
|
/// some kind of caching going on, but most of the time it shouldn't modify
|
||||||
|
/// the class.
|
||||||
|
///
|
||||||
|
/// The returned tensor has shape (1, Dim()).
|
||||||
|
virtual torch::Tensor GetFrame(int32_t frame) = 0;
|
||||||
|
|
||||||
|
/// This is like GetFrame() but for a collection of frames. There is a
|
||||||
|
/// default implementation that just gets the frames one by one, but it
|
||||||
|
/// may be overridden for efficiency by child classes (since sometimes
|
||||||
|
/// it's more efficient to do things in a batch).
|
||||||
|
///
|
||||||
|
/// The returned tensor has shape (frames.size(), Dim()).
|
||||||
|
virtual std::vector<torch::Tensor> GetFrames(
|
||||||
|
const std::vector<int32_t> &frames) {
|
||||||
|
std::vector<torch::Tensor> features;
|
||||||
|
features.reserve(frames.size());
|
||||||
|
|
||||||
|
for (auto i : frames) {
|
||||||
|
torch::Tensor f = GetFrame(i);
|
||||||
|
features.push_back(std::move(f));
|
||||||
|
}
|
||||||
|
return features;
|
||||||
|
#if 0
|
||||||
|
return torch::cat(features, /*dim*/ 0);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
/// This would be called from the application, when you get more wave data.
|
||||||
|
/// Note: the sampling_rate is typically only provided so the code can assert
|
||||||
|
/// that it matches the sampling rate expected in the options.
|
||||||
|
virtual void AcceptWaveform(float sampling_rate,
|
||||||
|
const torch::Tensor &waveform) = 0;
|
||||||
|
|
||||||
|
/// InputFinished() tells the class you won't be providing any
|
||||||
|
/// more waveform. This will help flush out the last few frames
|
||||||
|
/// of delta or LDA features (it will typically affect the return value
|
||||||
|
/// of IsLastFrame.
|
||||||
|
virtual void InputFinished() = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace kaldifeat
|
||||||
|
|
||||||
|
#endif // KALDIFEAT_CSRC_ONLINE_FEATURE_ITF_H_
|
49
kaldifeat/csrc/online-feature-test.cc
Normal file
49
kaldifeat/csrc/online-feature-test.cc
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
// kaldifeat/csrc/online-feature-test.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
#include "kaldifeat/csrc/online-feature.h"
|
||||||
|
|
||||||
|
#include "gtest/gtest.h"
|
||||||
|
|
||||||
|
namespace kaldifeat {
|
||||||
|
|
||||||
|
// With items_to_hold == -1 every pushed item must stay retrievable.
TEST(RecyclingVector, TestUnlimited) {
  constexpr int32_t N = 100;
  RecyclingVector v(-1);

  for (int32_t k = 0; k < N; ++k) {
    v.PushBack(torch::tensor({k, k + 1, k + 2}));
  }
  ASSERT_EQ(v.Size(), N);

  for (int32_t k = 0; k < N; ++k) {
    const torch::Tensor expected = torch::tensor({k, k + 1, k + 2});
    const torch::Tensor actual = v.At(k);
    EXPECT_TRUE(actual.equal(expected));
  }
}
|
||||||
|
|
||||||
|
// With a capacity of K only the K most recent items stay retrievable;
// Size() still counts every push.
TEST(RecyclingVector, Testlimited) {
  constexpr int32_t K = 3;
  constexpr int32_t N = 10;
  constexpr int32_t kFirstKept = N - K;

  RecyclingVector v(K);
  for (int32_t k = 0; k < N; ++k) {
    v.PushBack(torch::tensor({k, k + 1, k + 2}));
  }

  ASSERT_EQ(v.Size(), N);

  // Indexes that have already been recycled must be rejected.
  for (int32_t k = 0; k < kFirstKept; ++k) {
    ASSERT_DEATH(v.At(k), "");
  }

  // The remaining indexes must return what was pushed.
  for (int32_t k = kFirstKept; k < N; ++k) {
    const torch::Tensor expected = torch::tensor({k, k + 1, k + 2});
    const torch::Tensor actual = v.At(k);
    EXPECT_TRUE(actual.equal(expected));
  }
}
|
||||||
|
|
||||||
|
} // namespace kaldifeat
|
133
kaldifeat/csrc/online-feature.cc
Normal file
133
kaldifeat/csrc/online-feature.cc
Normal file
@ -0,0 +1,133 @@
|
|||||||
|
// kaldifeat/csrc/online-feature.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
// This file is copied/modified from kaldi/src/feat/online-feature.cc
|
||||||
|
|
||||||
|
#include "kaldifeat/csrc/online-feature.h"
|
||||||
|
|
||||||
|
#include "kaldifeat/csrc/feature-window.h"
|
||||||
|
#include "kaldifeat/csrc/log.h"
|
||||||
|
|
||||||
|
namespace kaldifeat {
|
||||||
|
|
||||||
|
// A capacity of 0 is stored as -1, the same value the default argument of
// the constructor uses; every other value is stored as given.
RecyclingVector::RecyclingVector(int32_t items_to_hold)
    : items_to_hold_(items_to_hold != 0 ? items_to_hold : -1),
      first_available_index_(0) {}
|
||||||
|
|
||||||
|
// Returns the item that was pushed as number `index` (counting from 0 and
// including items that have been dropped since). Asking for an item that
// was already dropped is reported through KALDIFEAT_ERR.
torch::Tensor RecyclingVector::At(int32_t index) const {
  const bool already_recycled = index < first_available_index_;
  if (already_recycled) {
    KALDIFEAT_ERR << "Attempted to retrieve feature vector that was "
                     "already removed by the RecyclingVector (index = "
                  << index << "; "
                  << "first_available_index = " << first_available_index_
                  << "; "
                  << "size = " << Size() << ")";
  }

  // Translate the global index into a position inside the deque;
  // std::deque::at() performs the upper-bound check.
  const int32_t position = index - first_available_index_;
  return items_.at(position);
}
|
||||||
|
|
||||||
|
// Appends `item`. If a positive capacity was configured and the container is
// already full, the oldest item is dropped first and
// first_available_index_ is advanced so that At() keeps using global
// indexes.
void RecyclingVector::PushBack(torch::Tensor item) {
  // A non-positive items_to_hold_ means "keep everything". The explicit
  // sign check replaces the former reliance on -1 wrapping to a huge
  // unsigned value; the outcome is the same.
  const bool is_full =
      items_to_hold_ > 0 &&
      items_.size() == static_cast<size_t>(items_to_hold_);
  if (is_full) {
    items_.pop_front();
    ++first_available_index_;
  }

  // `item` was received by value, so hand it over instead of copying it.
  items_.push_back(std::move(item));
}
|
||||||
|
|
||||||
|
int32_t RecyclingVector::Size() const {
|
||||||
|
return first_available_index_ + static_cast<int32_t>(items_.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class C>
|
||||||
|
OnlineGenericBaseFeature<C>::OnlineGenericBaseFeature(
|
||||||
|
const typename C::Options &opts)
|
||||||
|
: computer_(opts),
|
||||||
|
window_function_(opts.frame_opts, opts.device),
|
||||||
|
features_(opts.frame_opts.max_feature_vectors),
|
||||||
|
input_finished_(false),
|
||||||
|
waveform_offset_(0) {}
|
||||||
|
|
||||||
|
template <class C>
|
||||||
|
void OnlineGenericBaseFeature<C>::AcceptWaveform(
|
||||||
|
float sampling_rate, const torch::Tensor &original_waveform) {
|
||||||
|
if (original_waveform.numel() == 0) return; // Nothing to do.
|
||||||
|
|
||||||
|
KALDIFEAT_ASSERT(original_waveform.dim() == 1);
|
||||||
|
KALDIFEAT_ASSERT(sampling_rate == computer_.GetFrameOptions().samp_freq);
|
||||||
|
|
||||||
|
if (input_finished_)
|
||||||
|
KALDIFEAT_ERR << "AcceptWaveform called after InputFinished() was called.";
|
||||||
|
|
||||||
|
if (waveform_remainder_.numel() == 0) {
|
||||||
|
waveform_remainder_ = original_waveform;
|
||||||
|
} else {
|
||||||
|
waveform_remainder_ =
|
||||||
|
torch::cat({waveform_remainder_, original_waveform}, /*dim*/ 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
ComputeFeatures();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class C>
|
||||||
|
void OnlineGenericBaseFeature<C>::InputFinished() {
|
||||||
|
input_finished_ = true;
|
||||||
|
ComputeFeatures();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class C>
|
||||||
|
void OnlineGenericBaseFeature<C>::ComputeFeatures() {
|
||||||
|
const FrameExtractionOptions &frame_opts = computer_.GetFrameOptions();
|
||||||
|
|
||||||
|
int64_t num_samples_total = waveform_offset_ + waveform_remainder_.numel();
|
||||||
|
int32_t num_frames_old = features_.Size();
|
||||||
|
int32_t num_frames_new =
|
||||||
|
NumFrames(num_samples_total, frame_opts, input_finished_);
|
||||||
|
|
||||||
|
KALDIFEAT_ASSERT(num_frames_new >= num_frames_old);
|
||||||
|
|
||||||
|
// note: this online feature-extraction code does not support VTLN.
|
||||||
|
float vtln_warp = 1.0;
|
||||||
|
|
||||||
|
for (int32_t frame = num_frames_old; frame < num_frames_new; ++frame) {
|
||||||
|
torch::Tensor window =
|
||||||
|
ExtractWindow(waveform_offset_, waveform_remainder_, frame, frame_opts);
|
||||||
|
|
||||||
|
// TODO(fangjun): We can compute all feature frames at once
|
||||||
|
torch::Tensor this_feature =
|
||||||
|
computer_.ComputeFeatures(window.unsqueeze(0), vtln_warp);
|
||||||
|
features_.PushBack(this_feature);
|
||||||
|
}
|
||||||
|
|
||||||
|
// OK, we will now discard any portion of the signal that will not be
|
||||||
|
// necessary to compute frames in the future.
|
||||||
|
int64_t first_sample_of_next_frame =
|
||||||
|
FirstSampleOfFrame(num_frames_new, frame_opts);
|
||||||
|
int32_t samples_to_discard = first_sample_of_next_frame - waveform_offset_;
|
||||||
|
if (samples_to_discard > 0) {
|
||||||
|
// discard the leftmost part of the waveform that we no longer need.
|
||||||
|
int32_t new_num_samples = waveform_remainder_.numel() - samples_to_discard;
|
||||||
|
if (new_num_samples <= 0) {
|
||||||
|
// odd, but we'll try to handle it.
|
||||||
|
waveform_offset_ += waveform_remainder_.numel();
|
||||||
|
waveform_remainder_.resize_({0});
|
||||||
|
} else {
|
||||||
|
using torch::indexing::None;
|
||||||
|
using torch::indexing::Slice;
|
||||||
|
|
||||||
|
waveform_remainder_ =
|
||||||
|
waveform_remainder_.index({Slice(samples_to_discard, None)});
|
||||||
|
|
||||||
|
waveform_offset_ += samples_to_discard;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// instantiate the templates defined here for MFCC, PLP and filterbank classes.
|
||||||
|
template class OnlineGenericBaseFeature<Mfcc>;
|
||||||
|
template class OnlineGenericBaseFeature<Plp>;
|
||||||
|
template class OnlineGenericBaseFeature<Fbank>;
|
||||||
|
|
||||||
|
} // namespace kaldifeat
|
127
kaldifeat/csrc/online-feature.h
Normal file
127
kaldifeat/csrc/online-feature.h
Normal file
@ -0,0 +1,127 @@
|
|||||||
|
// kaldifeat/csrc/online-feature.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
// This file is copied/modified from kaldi/src/feat/online-feature.h
|
||||||
|
|
||||||
|
#ifndef KALDIFEAT_CSRC_ONLINE_FEATURE_H_
|
||||||
|
#define KALDIFEAT_CSRC_ONLINE_FEATURE_H_
|
||||||
|
|
||||||
|
#include <deque>
|
||||||
|
|
||||||
|
#include "kaldifeat/csrc/feature-fbank.h"
|
||||||
|
#include "kaldifeat/csrc/feature-mfcc.h"
|
||||||
|
#include "kaldifeat/csrc/feature-plp.h"
|
||||||
|
#include "kaldifeat/csrc/feature-window.h"
|
||||||
|
#include "kaldifeat/csrc/online-feature-itf.h"
|
||||||
|
|
||||||
|
namespace kaldifeat {
|
||||||
|
|
||||||
|
/// This class serves as a storage for feature vectors with an option to limit
|
||||||
|
/// the memory usage by removing old elements. The deleted frames indices are
|
||||||
|
/// "remembered" so that regardless of the MAX_ITEMS setting, the user always
|
||||||
|
/// provides the indices as if no deletion was being performed.
|
||||||
|
/// This is useful when processing very long recordings which would otherwise
|
||||||
|
/// cause the memory to eventually blow up when the features are not being
|
||||||
|
/// removed.
|
||||||
|
class RecyclingVector {
|
||||||
|
public:
|
||||||
|
/// By default it does not remove any elements.
|
||||||
|
explicit RecyclingVector(int32_t items_to_hold = -1);
|
||||||
|
|
||||||
|
~RecyclingVector() = default;
|
||||||
|
RecyclingVector(const RecyclingVector &) = delete;
|
||||||
|
RecyclingVector &operator=(const RecyclingVector &) = delete;
|
||||||
|
|
||||||
|
torch::Tensor At(int32_t index) const;
|
||||||
|
|
||||||
|
void PushBack(torch::Tensor item);
|
||||||
|
|
||||||
|
/// This method returns the size as if no "recycling" had happened,
|
||||||
|
/// i.e. equivalent to the number of times the PushBack method has been
|
||||||
|
/// called.
|
||||||
|
int32_t Size() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::deque<torch::Tensor> items_;
|
||||||
|
int32_t items_to_hold_;
|
||||||
|
int32_t first_available_index_;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// This is a templated class for online feature extraction;
|
||||||
|
/// it's templated on a class like MfccComputer or PlpComputer
|
||||||
|
/// that does the basic feature extraction.
|
||||||
|
template <class C>
|
||||||
|
class OnlineGenericBaseFeature : public OnlineFeatureInterface {
|
||||||
|
public:
|
||||||
|
// Constructor from options class
|
||||||
|
explicit OnlineGenericBaseFeature(const typename C::Options &opts);
|
||||||
|
|
||||||
|
int32_t Dim() const override { return computer_.Dim(); }
|
||||||
|
|
||||||
|
float FrameShiftInSeconds() const override {
|
||||||
|
return computer_.GetFrameOptions().frame_shift_ms / 1000.0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t NumFramesReady() const override { return features_.Size(); }
|
||||||
|
|
||||||
|
// Note: IsLastFrame() will only ever return true if you have called
|
||||||
|
// InputFinished() (and this frame is the last frame).
|
||||||
|
bool IsLastFrame(int32_t frame) const override {
|
||||||
|
return input_finished_ && frame == NumFramesReady() - 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
torch::Tensor GetFrame(int32_t frame) override { return features_.At(frame); }
|
||||||
|
|
||||||
|
// This would be called from the application, when you get
|
||||||
|
// more wave data. Note: the sampling_rate is only provided so
|
||||||
|
// the code can assert that it matches the sampling rate
|
||||||
|
// expected in the options.
|
||||||
|
void AcceptWaveform(float sampling_rate,
|
||||||
|
const torch::Tensor &waveform) override;
|
||||||
|
|
||||||
|
// InputFinished() tells the class you won't be providing any
|
||||||
|
// more waveform. This will help flush out the last frame or two
|
||||||
|
// of features, in the case where snip-edges == false; it also
|
||||||
|
// affects the return value of IsLastFrame().
|
||||||
|
void InputFinished() override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
// This function computes any additional feature frames that it is possible to
|
||||||
|
// compute from 'waveform_remainder_', which at this point may contain more
|
||||||
|
// than just a remainder-sized quantity (because AcceptWaveform() appends to
|
||||||
|
// waveform_remainder_ before calling this function). It adds these feature
|
||||||
|
// frames to features_, and shifts off any now-unneeded samples of input from
|
||||||
|
// waveform_remainder_ while incrementing waveform_offset_ by the same amount.
|
||||||
|
void ComputeFeatures();
|
||||||
|
|
||||||
|
C computer_; // class that does the MFCC or PLP or filterbank computation
|
||||||
|
|
||||||
|
FeatureWindowFunction window_function_;
|
||||||
|
|
||||||
|
// features_ is the Mfcc or Plp or Fbank features that we have already
|
||||||
|
// computed.
|
||||||
|
|
||||||
|
RecyclingVector features_;
|
||||||
|
|
||||||
|
// True if the user has called "InputFinished()"
|
||||||
|
bool input_finished_;
|
||||||
|
|
||||||
|
// waveform_offset_ is the number of samples of waveform that we have
|
||||||
|
// already discarded, i.e. that were prior to 'waveform_remainder_'.
|
||||||
|
int64_t waveform_offset_;
|
||||||
|
|
||||||
|
// waveform_remainder_ is a short piece of waveform that we may need to keep
|
||||||
|
// after extracting all the whole frames we can (whatever length of feature
|
||||||
|
// will be required for the next phase of computation).
|
||||||
|
// It is a 1-D tensor
|
||||||
|
torch::Tensor waveform_remainder_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Convenience names for the streaming extractors.
using OnlineFbank = OnlineGenericBaseFeature<Fbank>;
using OnlineMfcc = OnlineGenericBaseFeature<Mfcc>;
using OnlinePlp = OnlineGenericBaseFeature<Plp>;
|
||||||
|
|
||||||
|
} // namespace kaldifeat
|
||||||
|
|
||||||
|
#endif // KALDIFEAT_CSRC_ONLINE_FEATURE_H_
|
@ -7,6 +7,7 @@ pybind11_add_module(_kaldifeat
|
|||||||
feature-window.cc
|
feature-window.cc
|
||||||
kaldifeat.cc
|
kaldifeat.cc
|
||||||
mel-computations.cc
|
mel-computations.cc
|
||||||
|
online-feature.cc
|
||||||
utils.cc
|
utils.cc
|
||||||
)
|
)
|
||||||
target_link_libraries(_kaldifeat PRIVATE kaldifeat_core)
|
target_link_libraries(_kaldifeat PRIVATE kaldifeat_core)
|
||||||
|
@ -25,6 +25,7 @@ static void PybindFrameExtractionOptions(py::module &m) {
|
|||||||
.def_readwrite("round_to_power_of_two", &PyClass::round_to_power_of_two)
|
.def_readwrite("round_to_power_of_two", &PyClass::round_to_power_of_two)
|
||||||
.def_readwrite("blackman_coeff", &PyClass::blackman_coeff)
|
.def_readwrite("blackman_coeff", &PyClass::blackman_coeff)
|
||||||
.def_readwrite("snip_edges", &PyClass::snip_edges)
|
.def_readwrite("snip_edges", &PyClass::snip_edges)
|
||||||
|
.def_readwrite("max_feature_vectors", &PyClass::max_feature_vectors)
|
||||||
.def("as_dict",
|
.def("as_dict",
|
||||||
[](const PyClass &self) -> py::dict { return AsDict(self); })
|
[](const PyClass &self) -> py::dict { return AsDict(self); })
|
||||||
.def_static("from_dict",
|
.def_static("from_dict",
|
||||||
@ -35,8 +36,6 @@ static void PybindFrameExtractionOptions(py::module &m) {
|
|||||||
.def_readwrite("allow_downsample",
|
.def_readwrite("allow_downsample",
|
||||||
&PyClass::allow_downsample)
|
&PyClass::allow_downsample)
|
||||||
.def_readwrite("allow_upsample", &PyClass::allow_upsample)
|
.def_readwrite("allow_upsample", &PyClass::allow_upsample)
|
||||||
.def_readwrite("max_feature_vectors",
|
|
||||||
&PyClass::max_feature_vectors)
|
|
||||||
#endif
|
#endif
|
||||||
.def("__str__",
|
.def("__str__",
|
||||||
[](const PyClass &self) -> std::string { return self.ToString(); })
|
[](const PyClass &self) -> std::string { return self.ToString(); })
|
||||||
|
@ -11,6 +11,7 @@
|
|||||||
#include "kaldifeat/python/csrc/feature-spectrogram.h"
|
#include "kaldifeat/python/csrc/feature-spectrogram.h"
|
||||||
#include "kaldifeat/python/csrc/feature-window.h"
|
#include "kaldifeat/python/csrc/feature-window.h"
|
||||||
#include "kaldifeat/python/csrc/mel-computations.h"
|
#include "kaldifeat/python/csrc/mel-computations.h"
|
||||||
|
#include "kaldifeat/python/csrc/online-feature.h"
|
||||||
#include "torch/torch.h"
|
#include "torch/torch.h"
|
||||||
|
|
||||||
namespace kaldifeat {
|
namespace kaldifeat {
|
||||||
@ -24,6 +25,7 @@ PYBIND11_MODULE(_kaldifeat, m) {
|
|||||||
PybindFeatureMfcc(m);
|
PybindFeatureMfcc(m);
|
||||||
PybindFeaturePlp(m);
|
PybindFeaturePlp(m);
|
||||||
PybindFeatureSpectrogram(m);
|
PybindFeatureSpectrogram(m);
|
||||||
|
PybindOnlineFeature(m);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace kaldifeat
|
} // namespace kaldifeat
|
||||||
|
37
kaldifeat/python/csrc/online-feature.cc
Normal file
37
kaldifeat/python/csrc/online-feature.cc
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
// kaldifeat/python/csrc/online-feature.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
#include "kaldifeat/python/csrc/online-feature.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "kaldifeat/csrc/online-feature.h"
|
||||||
|
namespace kaldifeat {
|
||||||
|
|
||||||
|
template <typename C>
|
||||||
|
void PybindOnlineFeatureTpl(py::module &m, const std::string &class_name,
|
||||||
|
const std::string &class_help_doc = "") {
|
||||||
|
using PyClass = OnlineGenericBaseFeature<C>;
|
||||||
|
using Options = typename C::Options;
|
||||||
|
py::class_<PyClass>(m, class_name.c_str(), class_help_doc.c_str())
|
||||||
|
.def(py::init<const Options &>(), py::arg("opts"))
|
||||||
|
.def_property_readonly("dim", &PyClass::Dim)
|
||||||
|
.def_property_readonly("frame_shift_in_seconds",
|
||||||
|
&PyClass::FrameShiftInSeconds)
|
||||||
|
.def_property_readonly("num_frames_ready", &PyClass::NumFramesReady)
|
||||||
|
.def("is_last_frame", &PyClass::IsLastFrame, py::arg("frame"))
|
||||||
|
.def("get_frame", &PyClass::GetFrame, py::arg("frame"))
|
||||||
|
.def("get_frames", &PyClass::GetFrames, py::arg("frames"))
|
||||||
|
.def("accept_waveform", &PyClass::AcceptWaveform,
|
||||||
|
py::arg("sampling_rate"), py::arg("waveform"))
|
||||||
|
.def("input_finished", &PyClass::InputFinished);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Registers the streaming feature extractors in the Python module `m`.
// One Python class is created per feature computer type.
void PybindOnlineFeature(py::module &m) {
  PybindOnlineFeatureTpl<Mfcc>(m, /*class_name=*/"OnlineMfcc");
  PybindOnlineFeatureTpl<Fbank>(m, /*class_name=*/"OnlineFbank");
  PybindOnlineFeatureTpl<Plp>(m, /*class_name=*/"OnlinePlp");
}
|
||||||
|
|
||||||
|
} // namespace kaldifeat
|
16
kaldifeat/python/csrc/online-feature.h
Normal file
16
kaldifeat/python/csrc/online-feature.h
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
// kaldifeat/python/csrc/online-feature.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
#ifndef KALDIFEAT_PYTHON_CSRC_ONLINE_FEATURE_H_
|
||||||
|
#define KALDIFEAT_PYTHON_CSRC_ONLINE_FEATURE_H_
|
||||||
|
|
||||||
|
#include "kaldifeat/python/csrc/kaldifeat.h"
|
||||||
|
|
||||||
|
namespace kaldifeat {
|
||||||
|
|
||||||
|
void PybindOnlineFeature(py::module &m);
|
||||||
|
|
||||||
|
} // namespace kaldifeat
|
||||||
|
|
||||||
|
#endif // KALDIFEAT_PYTHON_CSRC_ONLINE_FEATURE_H_
|
@ -30,6 +30,7 @@ FrameExtractionOptions FrameExtractionOptionsFromDict(py::dict dict) {
|
|||||||
FROM_DICT(bool_, round_to_power_of_two);
|
FROM_DICT(bool_, round_to_power_of_two);
|
||||||
FROM_DICT(float_, blackman_coeff);
|
FROM_DICT(float_, blackman_coeff);
|
||||||
FROM_DICT(bool_, snip_edges);
|
FROM_DICT(bool_, snip_edges);
|
||||||
|
FROM_DICT(int_, max_feature_vectors);
|
||||||
|
|
||||||
return opts;
|
return opts;
|
||||||
}
|
}
|
||||||
@ -47,6 +48,7 @@ py::dict AsDict(const FrameExtractionOptions &opts) {
|
|||||||
AS_DICT(round_to_power_of_two);
|
AS_DICT(round_to_power_of_two);
|
||||||
AS_DICT(blackman_coeff);
|
AS_DICT(blackman_coeff);
|
||||||
AS_DICT(snip_edges);
|
AS_DICT(snip_edges);
|
||||||
|
AS_DICT(max_feature_vectors);
|
||||||
|
|
||||||
return dict;
|
return dict;
|
||||||
}
|
}
|
||||||
|
@ -8,7 +8,7 @@ from _kaldifeat import (
|
|||||||
SpectrogramOptions,
|
SpectrogramOptions,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .fbank import Fbank
|
from .fbank import Fbank, OnlineFbank
|
||||||
from .mfcc import Mfcc
|
from .mfcc import Mfcc, OnlineMfcc
|
||||||
from .plp import Plp
|
from .plp import OnlinePlp, Plp
|
||||||
from .spectrogram import Spectrogram
|
from .spectrogram import Spectrogram
|
||||||
|
@ -4,9 +4,20 @@
|
|||||||
import _kaldifeat
|
import _kaldifeat
|
||||||
|
|
||||||
from .offline_feature import OfflineFeature
|
from .offline_feature import OfflineFeature
|
||||||
|
from .online_feature import OnlineFeature
|
||||||
|
|
||||||
|
|
||||||
class Fbank(OfflineFeature):
|
class Fbank(OfflineFeature):
|
||||||
def __init__(self, opts: _kaldifeat.FbankOptions):
|
def __init__(self, opts: _kaldifeat.FbankOptions):
|
||||||
super().__init__(opts)
|
super().__init__(opts)
|
||||||
self.computer = _kaldifeat.Fbank(opts)
|
self.computer = _kaldifeat.Fbank(opts)
|
||||||
|
|
||||||
|
|
||||||
|
class OnlineFbank(OnlineFeature):
    """Streaming fbank extractor backed by ``_kaldifeat.OnlineFbank``."""

    def __init__(self, opts: _kaldifeat.FbankOptions):
        super().__init__(opts)
        # The base class leaves ``computer`` unset; create the real one here.
        self.computer = _kaldifeat.OnlineFbank(opts)

    def __setstate__(self, state):
        # ``state`` is a dict of options; rebuild the options object from it
        # and create a fresh extractor with those options.
        restored_opts = _kaldifeat.FbankOptions.from_dict(state)
        self.opts = restored_opts
        self.computer = _kaldifeat.OnlineFbank(restored_opts)
|
||||||
|
@ -4,9 +4,20 @@
|
|||||||
import _kaldifeat
|
import _kaldifeat
|
||||||
|
|
||||||
from .offline_feature import OfflineFeature
|
from .offline_feature import OfflineFeature
|
||||||
|
from .online_feature import OnlineFeature
|
||||||
|
|
||||||
|
|
||||||
class Mfcc(OfflineFeature):
|
class Mfcc(OfflineFeature):
|
||||||
def __init__(self, opts: _kaldifeat.MfccOptions):
|
def __init__(self, opts: _kaldifeat.MfccOptions):
|
||||||
super().__init__(opts)
|
super().__init__(opts)
|
||||||
self.computer = _kaldifeat.Mfcc(opts)
|
self.computer = _kaldifeat.Mfcc(opts)
|
||||||
|
|
||||||
|
|
||||||
|
class OnlineMfcc(OnlineFeature):
    """Streaming MFCC extractor backed by ``_kaldifeat.OnlineMfcc``."""

    def __init__(self, opts: _kaldifeat.MfccOptions):
        super().__init__(opts)
        # The base class leaves ``computer`` unset; create the real one here.
        self.computer = _kaldifeat.OnlineMfcc(opts)

    def __setstate__(self, state):
        # ``state`` is a dict of options; rebuild the options object from it
        # and create a fresh extractor with those options.
        restored_opts = _kaldifeat.MfccOptions.from_dict(state)
        self.opts = restored_opts
        self.computer = _kaldifeat.OnlineMfcc(restored_opts)
|
||||||
|
95
kaldifeat/python/kaldifeat/online_feature.py
Normal file
95
kaldifeat/python/kaldifeat/online_feature.py
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
# Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class OnlineFeature(object):
    """Online feature is a base class of streaming feature computers,
    e.g., OnlineFbank, OnlineMfcc.

    This class has two fields:

    (1) opts. It contains the options for the feature computer.
    (2) computer. The actual feature computer. It should be
        instantiated by subclasses.

    Caution:
      It supports only CPU at present.
    """

    def __init__(self, opts):
        """
        Args:
          opts:
            The options for the feature computer. ``opts.device`` has to be
            a CPU device.
        """
        assert (
            opts.device.type == "cpu"
        ), "Online feature extractors support only CPU at present"

        self.opts = opts

        # self.computer is expected to be set by subclasses
        self.computer = None

    @property
    def dim(self) -> int:
        """Return the feature dimension reported by the underlying computer."""
        return self.computer.dim

    @property
    def frame_shift_in_seconds(self) -> float:
        """Return the frame shift, in seconds, of the underlying computer."""
        return self.computer.frame_shift_in_seconds

    @property
    def num_frames_ready(self) -> int:
        """Return the number of ready frames.

        It can be updated by :meth:`accept_waveform`.

        Note:
          If you set ``opts.frame_opts.max_feature_vectors``, then
          the valid frame indexes are in the range
          ``[num_frames_ready - max_feature_vectors, num_frames_ready)``.

          If you leave ``opts.frame_opts.max_feature_vectors`` at its default
          value, then the range is ``[0, num_frames_ready)``.
        """
        return self.computer.num_frames_ready

    def is_last_frame(self, frame: int) -> bool:
        """Return True if the given frame is the last frame."""
        return self.computer.is_last_frame(frame)

    def get_frame(self, frame: int) -> torch.Tensor:
        """Get the frame by its index.

        Args:
          frame:
            The frame index. If ``opts.frame_opts.max_feature_vectors`` is
            -1, then its valid values are in the range
            ``[0, num_frames_ready)``. Otherwise, the range is
            ``[num_frames_ready - max_feature_vectors, num_frames_ready)``.
        Returns:
          Return a 2-D tensor with shape ``(1, feature_dim)``
        """
        return self.computer.get_frame(frame)

    def get_frames(self, frames: List[int]) -> List[torch.Tensor]:
        """Get frames at the given frame indexes.

        Args:
          frames:
            Frames whose indexes are in this list are returned.
        Returns:
          Return a list of feature frames at the given indexes.
        """
        return self.computer.get_frames(frames)

    def accept_waveform(
        self, sampling_rate: float, waveform: torch.Tensor
    ) -> None:
        """Send audio samples to the extractor.

        Args:
          sampling_rate:
            The sampling rate of the given audio samples. It has to be equal
            to ``opts.frame_opts.samp_freq``.
          waveform:
            A 1-D tensor of shape (num_samples,). Its dtype is torch.float32
            and has to be on CPU.
        """
        self.computer.accept_waveform(sampling_rate, waveform)

    def input_finished(self) -> None:
        """Tell the extractor that no more audio samples will be available.

        After calling this function, you cannot invoke ``accept_waveform``
        again.
        """
        self.computer.input_finished()

    def __getstate__(self):
        # Only the options are pickled; the computer itself is not part of
        # the pickled state.
        return self.opts.as_dict()
|
@ -4,9 +4,20 @@
|
|||||||
import _kaldifeat
|
import _kaldifeat
|
||||||
|
|
||||||
from .offline_feature import OfflineFeature
|
from .offline_feature import OfflineFeature
|
||||||
|
from .online_feature import OnlineFeature
|
||||||
|
|
||||||
|
|
||||||
class Plp(OfflineFeature):
|
class Plp(OfflineFeature):
|
||||||
def __init__(self, opts: _kaldifeat.PlpOptions):
|
def __init__(self, opts: _kaldifeat.PlpOptions):
|
||||||
super().__init__(opts)
|
super().__init__(opts)
|
||||||
self.computer = _kaldifeat.Plp(opts)
|
self.computer = _kaldifeat.Plp(opts)
|
||||||
|
|
||||||
|
|
||||||
|
class OnlinePlp(OnlineFeature):
    """Streaming PLP extractor backed by ``_kaldifeat.OnlinePlp``."""

    def __init__(self, opts: _kaldifeat.PlpOptions):
        super().__init__(opts)
        # The base class leaves ``computer`` unset; create the real one here.
        self.computer = _kaldifeat.OnlinePlp(opts)

    def __setstate__(self, state):
        # ``state`` is a dict of options; rebuild the options object from it
        # and create a fresh extractor with those options.
        restored_opts = _kaldifeat.PlpOptions.from_dict(state)
        self.opts = restored_opts
        self.computer = _kaldifeat.OnlinePlp(restored_opts)
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
# Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
# Copyright 2021-2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
|
||||||
import pickle
|
import pickle
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -13,30 +13,88 @@ import kaldifeat
|
|||||||
cur_dir = Path(__file__).resolve().parent
|
cur_dir = Path(__file__).resolve().parent
|
||||||
|
|
||||||
|
|
||||||
|
def test_online_fbank(
    opts: kaldifeat.FbankOptions,
    wave: torch.Tensor,
    cpu_features: torch.Tensor,
):
    """Check that streaming fbank extraction reproduces offline features.

    The waveform is sent to the extractor in randomly sized chunks. Every
    frame that becomes ready is compared with the matching row of
    ``cpu_features``.

    Args:
      opts:
        The options to create the online fbank extractor.
      wave:
        The input 1-D waveform.
      cpu_features:
        The ground truth features that are computed offline.
    """
    online_fbank = kaldifeat.OnlineFbank(opts)

    # accept_waveform() expects the rate configured in the options, so read
    # it from there instead of hard-coding 16 kHz.
    sampling_rate = opts.frame_opts.samp_freq

    num_processed_frames = 0
    i = 0  # current sample index to feed
    while not online_fbank.is_last_frame(num_processed_frames - 1):
        while num_processed_frames < online_fbank.num_frames_ready:
            # There are new frames to be processed
            frame = online_fbank.get_frame(num_processed_frames)
            assert torch.allclose(
                frame.squeeze(0), cpu_features[num_processed_frames]
            )
            num_processed_frames += 1

        # Simulate streaming. Send a random number of audio samples
        # to the extractor
        num_samples = torch.randint(300, 1000, (1,)).item()

        samples = wave[i : (i + num_samples)]  # noqa
        i += num_samples
        if len(samples) == 0:
            # The whole waveform has been consumed.
            online_fbank.input_finished()
            continue

        online_fbank.accept_waveform(sampling_rate, samples)

    assert num_processed_frames == online_fbank.num_frames_ready
    assert num_processed_frames == cpu_features.size(0)
|
||||||
|
|
||||||
|
|
||||||
def test_fbank_default():
|
def test_fbank_default():
|
||||||
print("=====test_fbank_default=====")
|
print("=====test_fbank_default=====")
|
||||||
|
filename = cur_dir / "test_data/test.wav"
|
||||||
|
wave = read_wave(filename)
|
||||||
|
gt = read_ark_txt(cur_dir / "test_data/test.txt")
|
||||||
|
|
||||||
|
cpu_features = None
|
||||||
for device in get_devices():
|
for device in get_devices():
|
||||||
print("device", device)
|
print("device", device)
|
||||||
opts = kaldifeat.FbankOptions()
|
opts = kaldifeat.FbankOptions()
|
||||||
opts.device = device
|
opts.device = device
|
||||||
opts.frame_opts.dither = 0
|
opts.frame_opts.dither = 0
|
||||||
fbank = kaldifeat.Fbank(opts)
|
fbank = kaldifeat.Fbank(opts)
|
||||||
filename = cur_dir / "test_data/test.wav"
|
|
||||||
wave = read_wave(filename)
|
|
||||||
|
|
||||||
features = fbank(wave)
|
features = fbank(wave)
|
||||||
assert features.device.type == "cpu"
|
assert features.device.type == "cpu"
|
||||||
gt = read_ark_txt(cur_dir / "test_data/test.txt")
|
|
||||||
assert torch.allclose(features, gt, rtol=1e-1)
|
assert torch.allclose(features, gt, rtol=1e-1)
|
||||||
|
if cpu_features is None:
|
||||||
|
cpu_features = features
|
||||||
|
|
||||||
wave = wave.to(device)
|
features = fbank(wave.to(device))
|
||||||
features = fbank(wave)
|
|
||||||
assert features.device == device
|
assert features.device == device
|
||||||
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
||||||
|
|
||||||
|
# Now for online fbank
|
||||||
|
opts = kaldifeat.FbankOptions()
|
||||||
|
opts.frame_opts.dither = 0
|
||||||
|
opts.frame_opts.max_feature_vectors = 100
|
||||||
|
|
||||||
|
test_online_fbank(opts, wave, cpu_features)
|
||||||
|
|
||||||
|
|
||||||
def test_fbank_htk():
|
def test_fbank_htk():
|
||||||
print("=====test_fbank_htk=====")
|
print("=====test_fbank_htk=====")
|
||||||
|
filename = cur_dir / "test_data/test.wav"
|
||||||
|
wave = read_wave(filename)
|
||||||
|
gt = read_ark_txt(cur_dir / "test_data/test-htk.txt")
|
||||||
|
|
||||||
|
cpu_features = None
|
||||||
for device in get_devices():
|
for device in get_devices():
|
||||||
print("device", device)
|
print("device", device)
|
||||||
opts = kaldifeat.FbankOptions()
|
opts = kaldifeat.FbankOptions()
|
||||||
@ -46,22 +104,32 @@ def test_fbank_htk():
|
|||||||
opts.htk_compat = True
|
opts.htk_compat = True
|
||||||
|
|
||||||
fbank = kaldifeat.Fbank(opts)
|
fbank = kaldifeat.Fbank(opts)
|
||||||
filename = cur_dir / "test_data/test.wav"
|
|
||||||
wave = read_wave(filename)
|
|
||||||
|
|
||||||
features = fbank(wave)
|
features = fbank(wave)
|
||||||
assert features.device.type == "cpu"
|
assert features.device.type == "cpu"
|
||||||
gt = read_ark_txt(cur_dir / "test_data/test-htk.txt")
|
|
||||||
assert torch.allclose(features, gt, rtol=1e-1)
|
assert torch.allclose(features, gt, rtol=1e-1)
|
||||||
|
if cpu_features is None:
|
||||||
|
cpu_features = features
|
||||||
|
|
||||||
wave = wave.to(device)
|
features = fbank(wave.to(device))
|
||||||
features = fbank(wave)
|
|
||||||
assert features.device == device
|
assert features.device == device
|
||||||
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
||||||
|
|
||||||
|
opts = kaldifeat.FbankOptions()
|
||||||
|
opts.frame_opts.dither = 0
|
||||||
|
opts.use_energy = True
|
||||||
|
opts.htk_compat = True
|
||||||
|
|
||||||
|
test_online_fbank(opts, wave, cpu_features)
|
||||||
|
|
||||||
|
|
||||||
def test_fbank_with_energy():
|
def test_fbank_with_energy():
|
||||||
print("=====test_fbank_with_energy=====")
|
print("=====test_fbank_with_energy=====")
|
||||||
|
filename = cur_dir / "test_data/test.wav"
|
||||||
|
wave = read_wave(filename)
|
||||||
|
gt = read_ark_txt(cur_dir / "test_data/test-with-energy.txt")
|
||||||
|
|
||||||
|
cpu_features = None
|
||||||
for device in get_devices():
|
for device in get_devices():
|
||||||
print("device", device)
|
print("device", device)
|
||||||
opts = kaldifeat.FbankOptions()
|
opts = kaldifeat.FbankOptions()
|
||||||
@ -70,22 +138,31 @@ def test_fbank_with_energy():
|
|||||||
opts.use_energy = True
|
opts.use_energy = True
|
||||||
|
|
||||||
fbank = kaldifeat.Fbank(opts)
|
fbank = kaldifeat.Fbank(opts)
|
||||||
filename = cur_dir / "test_data/test.wav"
|
|
||||||
wave = read_wave(filename)
|
|
||||||
|
|
||||||
features = fbank(wave)
|
features = fbank(wave)
|
||||||
gt = read_ark_txt(cur_dir / "test_data/test-with-energy.txt")
|
|
||||||
assert torch.allclose(features, gt, rtol=1e-1)
|
assert torch.allclose(features, gt, rtol=1e-1)
|
||||||
assert features.device.type == "cpu"
|
assert features.device.type == "cpu"
|
||||||
|
if cpu_features is None:
|
||||||
|
cpu_features = features
|
||||||
|
|
||||||
wave = wave.to(device)
|
features = fbank(wave.to(device))
|
||||||
features = fbank(wave)
|
|
||||||
assert features.device == device
|
assert features.device == device
|
||||||
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
||||||
|
|
||||||
|
opts = kaldifeat.FbankOptions()
|
||||||
|
opts.frame_opts.dither = 0
|
||||||
|
opts.use_energy = True
|
||||||
|
|
||||||
|
test_online_fbank(opts, wave, cpu_features)
|
||||||
|
|
||||||
|
|
||||||
def test_fbank_40_bins():
|
def test_fbank_40_bins():
|
||||||
print("=====test_fbank_40_bins=====")
|
print("=====test_fbank_40_bins=====")
|
||||||
|
filename = cur_dir / "test_data/test.wav"
|
||||||
|
wave = read_wave(filename)
|
||||||
|
gt = read_ark_txt(cur_dir / "test_data/test-40.txt")
|
||||||
|
|
||||||
|
cpu_features = None
|
||||||
for device in get_devices():
|
for device in get_devices():
|
||||||
print("device", device)
|
print("device", device)
|
||||||
opts = kaldifeat.FbankOptions()
|
opts = kaldifeat.FbankOptions()
|
||||||
@ -94,22 +171,31 @@ def test_fbank_40_bins():
|
|||||||
opts.mel_opts.num_bins = 40
|
opts.mel_opts.num_bins = 40
|
||||||
|
|
||||||
fbank = kaldifeat.Fbank(opts)
|
fbank = kaldifeat.Fbank(opts)
|
||||||
filename = cur_dir / "test_data/test.wav"
|
|
||||||
wave = read_wave(filename)
|
|
||||||
|
|
||||||
features = fbank(wave)
|
features = fbank(wave)
|
||||||
assert features.device.type == "cpu"
|
assert features.device.type == "cpu"
|
||||||
gt = read_ark_txt(cur_dir / "test_data/test-40.txt")
|
|
||||||
assert torch.allclose(features, gt, rtol=1e-1)
|
assert torch.allclose(features, gt, rtol=1e-1)
|
||||||
|
if cpu_features is None:
|
||||||
|
cpu_features = features
|
||||||
|
|
||||||
wave = wave.to(device)
|
features = fbank(wave.to(device))
|
||||||
features = fbank(wave)
|
|
||||||
assert features.device == device
|
assert features.device == device
|
||||||
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
||||||
|
|
||||||
|
opts = kaldifeat.FbankOptions()
|
||||||
|
opts.frame_opts.dither = 0
|
||||||
|
opts.mel_opts.num_bins = 40
|
||||||
|
|
||||||
|
test_online_fbank(opts, wave, cpu_features)
|
||||||
|
|
||||||
|
|
||||||
def test_fbank_40_bins_no_snip_edges():
|
def test_fbank_40_bins_no_snip_edges():
|
||||||
print("=====test_fbank_40_bins_no_snip_edges=====")
|
print("=====test_fbank_40_bins_no_snip_edges=====")
|
||||||
|
filename = cur_dir / "test_data/test.wav"
|
||||||
|
wave = read_wave(filename)
|
||||||
|
gt = read_ark_txt(cur_dir / "test_data/test-40-no-snip-edges.txt")
|
||||||
|
|
||||||
|
cpu_features = None
|
||||||
for device in get_devices():
|
for device in get_devices():
|
||||||
print("device", device)
|
print("device", device)
|
||||||
opts = kaldifeat.FbankOptions()
|
opts = kaldifeat.FbankOptions()
|
||||||
@ -119,19 +205,24 @@ def test_fbank_40_bins_no_snip_edges():
|
|||||||
opts.frame_opts.snip_edges = False
|
opts.frame_opts.snip_edges = False
|
||||||
|
|
||||||
fbank = kaldifeat.Fbank(opts)
|
fbank = kaldifeat.Fbank(opts)
|
||||||
filename = cur_dir / "test_data/test.wav"
|
|
||||||
wave = read_wave(filename)
|
|
||||||
|
|
||||||
features = fbank(wave)
|
features = fbank(wave)
|
||||||
assert features.device.type == "cpu"
|
assert features.device.type == "cpu"
|
||||||
gt = read_ark_txt(cur_dir / "test_data/test-40-no-snip-edges.txt")
|
|
||||||
assert torch.allclose(features, gt, rtol=1e-1)
|
assert torch.allclose(features, gt, rtol=1e-1)
|
||||||
|
if cpu_features is None:
|
||||||
|
cpu_features = features
|
||||||
|
|
||||||
wave = wave.to(device)
|
features = fbank(wave.to(device))
|
||||||
features = fbank(wave)
|
|
||||||
assert features.device == device
|
assert features.device == device
|
||||||
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
||||||
|
|
||||||
|
opts = kaldifeat.FbankOptions()
|
||||||
|
opts.frame_opts.dither = 0
|
||||||
|
opts.mel_opts.num_bins = 40
|
||||||
|
opts.frame_opts.snip_edges = False
|
||||||
|
|
||||||
|
test_online_fbank(opts, wave, cpu_features)
|
||||||
|
|
||||||
|
|
||||||
def test_fbank_chunk():
|
def test_fbank_chunk():
|
||||||
print("=====test_fbank_chunk=====")
|
print("=====test_fbank_chunk=====")
|
||||||
@ -223,6 +314,16 @@ def test_pickle():
|
|||||||
|
|
||||||
assert str(fbank.opts) == str(fbank2.opts)
|
assert str(fbank.opts) == str(fbank2.opts)
|
||||||
|
|
||||||
|
opts = kaldifeat.FbankOptions()
|
||||||
|
opts.use_energy = True
|
||||||
|
opts.use_power = False
|
||||||
|
|
||||||
|
fbank = kaldifeat.OnlineFbank(opts)
|
||||||
|
data = pickle.dumps(fbank)
|
||||||
|
fbank2 = pickle.loads(data)
|
||||||
|
|
||||||
|
assert str(fbank.opts) == str(fbank2.opts)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_fbank_default()
|
test_fbank_default()
|
||||||
|
@ -13,24 +13,82 @@ import kaldifeat
|
|||||||
cur_dir = Path(__file__).resolve().parent
|
cur_dir = Path(__file__).resolve().parent
|
||||||
|
|
||||||
|
|
||||||
|
def test_online_mfcc(
    opts: kaldifeat.MfccOptions,
    wave: torch.Tensor,
    cpu_features: torch.Tensor,
):
    """Check that streaming MFCC extraction reproduces offline features.

    The waveform is sent to the extractor in randomly sized chunks. Every
    frame that becomes ready is compared with the matching row of
    ``cpu_features``.

    Args:
      opts:
        The options to create the online mfcc extractor.
      wave:
        The input 1-D waveform.
      cpu_features:
        The ground truth features that are computed offline.
    """
    online_mfcc = kaldifeat.OnlineMfcc(opts)

    # accept_waveform() expects the rate configured in the options, so read
    # it from there instead of hard-coding 16 kHz.
    sampling_rate = opts.frame_opts.samp_freq

    num_processed_frames = 0
    i = 0  # current sample index to feed
    while not online_mfcc.is_last_frame(num_processed_frames - 1):
        while num_processed_frames < online_mfcc.num_frames_ready:
            # There are new frames to be processed
            frame = online_mfcc.get_frame(num_processed_frames)
            assert torch.allclose(
                frame.squeeze(0), cpu_features[num_processed_frames], atol=1e-3
            )
            num_processed_frames += 1

        # Simulate streaming. Send a random number of audio samples
        # to the extractor
        num_samples = torch.randint(300, 1000, (1,)).item()

        samples = wave[i : (i + num_samples)]  # noqa
        i += num_samples
        if len(samples) == 0:
            # The whole waveform has been consumed.
            online_mfcc.input_finished()
            continue

        online_mfcc.accept_waveform(sampling_rate, samples)

    assert num_processed_frames == online_mfcc.num_frames_ready
    assert num_processed_frames == cpu_features.size(0)
|
||||||
|
|
||||||
|
|
||||||
def test_mfcc_default():
|
def test_mfcc_default():
|
||||||
print("=====test_mfcc_default=====")
|
print("=====test_mfcc_default=====")
|
||||||
|
filename = cur_dir / "test_data/test.wav"
|
||||||
|
wave = read_wave(filename)
|
||||||
|
gt = read_ark_txt(cur_dir / "test_data/test-mfcc.txt")
|
||||||
|
|
||||||
|
cpu_features = None
|
||||||
for device in get_devices():
|
for device in get_devices():
|
||||||
print("device", device)
|
print("device", device)
|
||||||
opts = kaldifeat.MfccOptions()
|
opts = kaldifeat.MfccOptions()
|
||||||
opts.device = device
|
opts.device = device
|
||||||
opts.frame_opts.dither = 0
|
opts.frame_opts.dither = 0
|
||||||
mfcc = kaldifeat.Mfcc(opts)
|
mfcc = kaldifeat.Mfcc(opts)
|
||||||
filename = cur_dir / "test_data/test.wav"
|
|
||||||
wave = read_wave(filename).to(device)
|
|
||||||
|
|
||||||
features = mfcc(wave)
|
features = mfcc(wave.to(device))
|
||||||
gt = read_ark_txt(cur_dir / "test_data/test-mfcc.txt")
|
if device.type == "cpu":
|
||||||
|
cpu_features = features
|
||||||
|
|
||||||
assert torch.allclose(features.cpu(), gt, atol=1e-1)
|
assert torch.allclose(features.cpu(), gt, atol=1e-1)
|
||||||
|
|
||||||
|
opts = kaldifeat.MfccOptions()
|
||||||
|
opts.frame_opts.dither = 0
|
||||||
|
|
||||||
|
test_online_mfcc(opts, wave, cpu_features)
|
||||||
|
|
||||||
|
|
||||||
def test_mfcc_no_snip_edges():
|
def test_mfcc_no_snip_edges():
|
||||||
print("=====test_mfcc_no_snip_edges=====")
|
print("=====test_mfcc_no_snip_edges=====")
|
||||||
|
filename = cur_dir / "test_data/test.wav"
|
||||||
|
wave = read_wave(filename)
|
||||||
|
gt = read_ark_txt(cur_dir / "test_data/test-mfcc-no-snip-edges.txt")
|
||||||
|
|
||||||
|
cpu_features = None
|
||||||
for device in get_devices():
|
for device in get_devices():
|
||||||
print("device", device)
|
print("device", device)
|
||||||
opts = kaldifeat.MfccOptions()
|
opts = kaldifeat.MfccOptions()
|
||||||
@ -39,13 +97,19 @@ def test_mfcc_no_snip_edges():
|
|||||||
opts.frame_opts.snip_edges = False
|
opts.frame_opts.snip_edges = False
|
||||||
|
|
||||||
mfcc = kaldifeat.Mfcc(opts)
|
mfcc = kaldifeat.Mfcc(opts)
|
||||||
filename = cur_dir / "test_data/test.wav"
|
|
||||||
wave = read_wave(filename).to(device)
|
|
||||||
|
|
||||||
features = mfcc(wave)
|
features = mfcc(wave.to(device))
|
||||||
gt = read_ark_txt(cur_dir / "test_data/test-mfcc-no-snip-edges.txt")
|
if device.type == "cpu":
|
||||||
|
cpu_features = features
|
||||||
|
|
||||||
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
||||||
|
|
||||||
|
opts = kaldifeat.MfccOptions()
|
||||||
|
opts.frame_opts.dither = 0
|
||||||
|
opts.frame_opts.snip_edges = False
|
||||||
|
|
||||||
|
test_online_mfcc(opts, wave, cpu_features)
|
||||||
|
|
||||||
|
|
||||||
def test_pickle():
|
def test_pickle():
|
||||||
for device in get_devices():
|
for device in get_devices():
|
||||||
@ -60,6 +124,16 @@ def test_pickle():
|
|||||||
|
|
||||||
assert str(mfcc.opts) == str(mfcc2.opts)
|
assert str(mfcc.opts) == str(mfcc2.opts)
|
||||||
|
|
||||||
|
opts = kaldifeat.MfccOptions()
|
||||||
|
opts.frame_opts.dither = 0
|
||||||
|
opts.frame_opts.snip_edges = False
|
||||||
|
|
||||||
|
mfcc = kaldifeat.OnlineMfcc(opts)
|
||||||
|
data = pickle.dumps(mfcc)
|
||||||
|
mfcc2 = pickle.loads(data)
|
||||||
|
|
||||||
|
assert str(mfcc.opts) == str(mfcc2.opts)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_mfcc_default()
|
test_mfcc_default()
|
||||||
|
@ -13,24 +13,82 @@ import kaldifeat
|
|||||||
cur_dir = Path(__file__).resolve().parent
|
cur_dir = Path(__file__).resolve().parent
|
||||||
|
|
||||||
|
|
||||||
|
def test_online_plp(
    opts: kaldifeat.PlpOptions,
    wave: torch.Tensor,
    cpu_features: torch.Tensor,
):
    """Check that streaming PLP extraction reproduces offline features.

    The waveform is sent to the extractor in randomly sized chunks. Every
    frame that becomes ready is compared with the matching row of
    ``cpu_features``.

    Args:
      opts:
        The options to create the online plp extractor.
      wave:
        The input 1-D waveform.
      cpu_features:
        The ground truth features that are computed offline.
    """
    online_plp = kaldifeat.OnlinePlp(opts)

    # accept_waveform() expects the rate configured in the options, so read
    # it from there instead of hard-coding 16 kHz.
    sampling_rate = opts.frame_opts.samp_freq

    num_processed_frames = 0
    i = 0  # current sample index to feed
    while not online_plp.is_last_frame(num_processed_frames - 1):
        while num_processed_frames < online_plp.num_frames_ready:
            # There are new frames to be processed
            frame = online_plp.get_frame(num_processed_frames)
            assert torch.allclose(
                frame.squeeze(0), cpu_features[num_processed_frames], atol=1e-3
            )
            num_processed_frames += 1

        # Simulate streaming. Send a random number of audio samples
        # to the extractor
        num_samples = torch.randint(300, 1000, (1,)).item()

        samples = wave[i : (i + num_samples)]  # noqa
        i += num_samples
        if len(samples) == 0:
            # The whole waveform has been consumed.
            online_plp.input_finished()
            continue

        online_plp.accept_waveform(sampling_rate, samples)

    assert num_processed_frames == online_plp.num_frames_ready
    assert num_processed_frames == cpu_features.size(0)
|
||||||
|
|
||||||
|
|
||||||
def test_plp_default():
|
def test_plp_default():
|
||||||
print("=====test_plp_default=====")
|
print("=====test_plp_default=====")
|
||||||
|
filename = cur_dir / "test_data/test.wav"
|
||||||
|
wave = read_wave(filename)
|
||||||
|
gt = read_ark_txt(cur_dir / "test_data/test-plp.txt")
|
||||||
|
|
||||||
|
cpu_features = None
|
||||||
for device in get_devices():
|
for device in get_devices():
|
||||||
print("device", device)
|
print("device", device)
|
||||||
opts = kaldifeat.PlpOptions()
|
opts = kaldifeat.PlpOptions()
|
||||||
opts.frame_opts.dither = 0
|
opts.frame_opts.dither = 0
|
||||||
opts.device = device
|
opts.device = device
|
||||||
plp = kaldifeat.Plp(opts)
|
plp = kaldifeat.Plp(opts)
|
||||||
filename = cur_dir / "test_data/test.wav"
|
|
||||||
wave = read_wave(filename).to(device)
|
|
||||||
|
|
||||||
features = plp(wave)
|
features = plp(wave.to(device))
|
||||||
gt = read_ark_txt(cur_dir / "test_data/test-plp.txt")
|
if device.type == "cpu":
|
||||||
|
cpu_features = features
|
||||||
|
|
||||||
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
||||||
|
|
||||||
|
opts = kaldifeat.PlpOptions()
|
||||||
|
opts.frame_opts.dither = 0
|
||||||
|
|
||||||
|
test_online_plp(opts, wave, cpu_features)
|
||||||
|
|
||||||
|
|
||||||
def test_plp_no_snip_edges():
|
def test_plp_no_snip_edges():
|
||||||
print("=====test_plp_no_snip_edges=====")
|
print("=====test_plp_no_snip_edges=====")
|
||||||
|
filename = cur_dir / "test_data/test.wav"
|
||||||
|
wave = read_wave(filename)
|
||||||
|
gt = read_ark_txt(cur_dir / "test_data/test-plp-no-snip-edges.txt")
|
||||||
|
|
||||||
|
cpu_features = None
|
||||||
for device in get_devices():
|
for device in get_devices():
|
||||||
print("device", device)
|
print("device", device)
|
||||||
opts = kaldifeat.PlpOptions()
|
opts = kaldifeat.PlpOptions()
|
||||||
@ -39,16 +97,26 @@ def test_plp_no_snip_edges():
|
|||||||
opts.frame_opts.snip_edges = False
|
opts.frame_opts.snip_edges = False
|
||||||
|
|
||||||
plp = kaldifeat.Plp(opts)
|
plp = kaldifeat.Plp(opts)
|
||||||
filename = cur_dir / "test_data/test.wav"
|
|
||||||
wave = read_wave(filename).to(device)
|
|
||||||
|
|
||||||
features = plp(wave)
|
features = plp(wave.to(device))
|
||||||
gt = read_ark_txt(cur_dir / "test_data/test-plp-no-snip-edges.txt")
|
if device.type == "cpu":
|
||||||
|
cpu_features = features
|
||||||
assert torch.allclose(features.cpu(), gt, atol=1e-1)
|
assert torch.allclose(features.cpu(), gt, atol=1e-1)
|
||||||
|
|
||||||
|
opts = kaldifeat.PlpOptions()
|
||||||
|
opts.frame_opts.dither = 0
|
||||||
|
opts.frame_opts.snip_edges = False
|
||||||
|
|
||||||
|
test_online_plp(opts, wave, cpu_features)
|
||||||
|
|
||||||
|
|
||||||
def test_plp_htk_10_ceps():
|
def test_plp_htk_10_ceps():
|
||||||
print("=====test_plp_htk_10_ceps=====")
|
print("=====test_plp_htk_10_ceps=====")
|
||||||
|
filename = cur_dir / "test_data/test.wav"
|
||||||
|
wave = read_wave(filename)
|
||||||
|
gt = read_ark_txt(cur_dir / "test_data/test-plp-htk-10-ceps.txt")
|
||||||
|
|
||||||
|
cpu_features = None
|
||||||
for device in get_devices():
|
for device in get_devices():
|
||||||
print("device", device)
|
print("device", device)
|
||||||
opts = kaldifeat.PlpOptions()
|
opts = kaldifeat.PlpOptions()
|
||||||
@ -58,13 +126,19 @@ def test_plp_htk_10_ceps():
|
|||||||
opts.frame_opts.dither = 0
|
opts.frame_opts.dither = 0
|
||||||
|
|
||||||
plp = kaldifeat.Plp(opts)
|
plp = kaldifeat.Plp(opts)
|
||||||
filename = cur_dir / "test_data/test.wav"
|
|
||||||
wave = read_wave(filename).to(device)
|
|
||||||
|
|
||||||
features = plp(wave)
|
features = plp(wave.to(device))
|
||||||
gt = read_ark_txt(cur_dir / "test_data/test-plp-htk-10-ceps.txt")
|
if device.type == "cpu":
|
||||||
|
cpu_features = features
|
||||||
assert torch.allclose(features.cpu(), gt, atol=1e-1)
|
assert torch.allclose(features.cpu(), gt, atol=1e-1)
|
||||||
|
|
||||||
|
opts = kaldifeat.PlpOptions()
|
||||||
|
opts.htk_compat = True
|
||||||
|
opts.num_ceps = 10
|
||||||
|
opts.frame_opts.dither = 0
|
||||||
|
|
||||||
|
test_online_plp(opts, wave, cpu_features)
|
||||||
|
|
||||||
|
|
||||||
def test_pickle():
|
def test_pickle():
|
||||||
for device in get_devices():
|
for device in get_devices():
|
||||||
@ -79,6 +153,16 @@ def test_pickle():
|
|||||||
|
|
||||||
assert str(plp.opts) == str(plp2.opts)
|
assert str(plp.opts) == str(plp2.opts)
|
||||||
|
|
||||||
|
opts = kaldifeat.PlpOptions()
|
||||||
|
opts.frame_opts.dither = 0
|
||||||
|
opts.frame_opts.snip_edges = False
|
||||||
|
|
||||||
|
plp = kaldifeat.OnlinePlp(opts)
|
||||||
|
data = pickle.dumps(plp)
|
||||||
|
plp2 = pickle.loads(data)
|
||||||
|
|
||||||
|
assert str(plp.opts) == str(plp2.opts)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_plp_default()
|
test_plp_default()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user