Implement ExtractWindow.

This commit is contained in:
Fangjun Kuang 2021-02-25 20:15:15 +08:00
parent 753f47e89d
commit 9bd6ee0c5f
8 changed files with 314 additions and 6 deletions

View File

@ -6,3 +6,6 @@ set(kaldifeat_srcs
# Shared library built from the sources listed in kaldifeat_srcs; the Torch
# libraries are linked PUBLIC so that targets linking kaldifeat_core get them
# as well.
add_library(kaldifeat_core SHARED ${kaldifeat_srcs})
target_link_libraries(kaldifeat_core PUBLIC ${TORCH_LIBRARIES})
# Stand-alone test executable built from test_kaldifeat.cc and linked against
# the core library.
add_executable(test_kaldifeat test_kaldifeat.cc)
target_link_libraries(test_kaldifeat PRIVATE kaldifeat_core)

View File

@ -0,0 +1,87 @@
// kaldifeat/csrc/feature-common-inl.h
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
// This file is copied/modified from kaldi/src/feat/feature-common-inl.h
#ifndef KALDIFEAT_CSRC_FEATURE_COMMON_INL_H_
#define KALDIFEAT_CSRC_FEATURE_COMMON_INL_H_
#include "kaldifeat/csrc/feature-window.h"
namespace kaldifeat {
/**
 * Frames a waveform and applies the per-frame preprocessing steps
 * (dithering, DC-offset removal, pre-emphasis and windowing).
 *
 * All steps are out-of-place: GetStrided() returns an as_strided view whose
 * rows share memory with `wave` and with each other, so writing into it in
 * place would modify the caller's (const) waveform and the neighbouring
 * frames.
 *
 * @param wave       The input waveform; must be a 1-D tensor (asserted).
 * @param vtln_warp  The VTLN warping factor. It is not consumed yet, see the
 *                   TODO at the end of the function.
 * @return A 2-D tensor with one preprocessed, windowed frame per row.
 *
 * NOTE(review): the raw log-energy is computed when the computer asks for it
 * but is not consumed yet either; confirm the intended hand-off to
 * `computer_` before relying on the return value as final features.
 */
template <class F>
torch::Tensor OfflineFeatureTpl<F>::ComputeFeatures(const torch::Tensor &wave,
                                                    float vtln_warp) {
  KALDIFEAT_ASSERT(wave.dim() == 1);

  // GetFrameOptions() is what is handed to NumFrames() and to the
  // FeatureWindowFunction constructor, i.e. it already is the
  // FrameExtractionOptions object.
  const FrameExtractionOptions &frame_opts = computer_.GetFrameOptions();

  // One frame per row; this is still a view into `wave`.
  torch::Tensor strided_input = GetStrided(wave, frame_opts);

  if (frame_opts.dither != 0)
    strided_input = Dither(strided_input, frame_opts.dither);

  if (frame_opts.remove_dc_offset) {
    torch::Tensor row_means = strided_input.mean(1).unsqueeze(1);
    // Out-of-place on purpose: `strided_input` may still alias `wave`.
    strided_input = strided_input - row_means;
  }

  torch::Tensor log_energy_pre_window;

  // torch.finfo(torch.float32).eps
  constexpr float kEps = 1.1920928955078125e-07f;

  if (computer_.NeedRawLogEnergy()) {
    // Clamp before log() so that an all-zero frame does not produce -inf.
    log_energy_pre_window =
        torch::clamp_min(strided_input.pow(2).sum(1), kEps).log();
  }

  if (frame_opts.preemph_coeff != 0.0f) {
    KALDIFEAT_ASSERT(frame_opts.preemph_coeff >= 0.0f &&
                     frame_opts.preemph_coeff <= 1.0f);

    using torch::indexing::Slice;

    // previous[:, 0] = x[:, 0] and previous[:, j] = x[:, j - 1] for j >= 1,
    // so that y = x - coeff * previous gives
    //   y[:, 0] = (1 - coeff) * x[:, 0]
    //   y[:, j] = x[:, j] - coeff * x[:, j - 1]
    torch::Tensor first_column = strided_input.index({"...", Slice(0, 1)});
    torch::Tensor all_but_last = strided_input.index({"...", Slice(0, -1)});
    torch::Tensor previous = torch::cat({first_column, all_but_last}, 1);

    strided_input = strided_input - frame_opts.preemph_coeff * previous;
  }

  strided_input = feature_window_function_.Apply(strided_input);

  // TODO: pass `log_energy_pre_window`, `vtln_warp` and the windowed frames
  // on to `computer_` once its tensor-based Compute() is wired in here.
  // Until then the windowed frames are returned so that the function has a
  // well-defined result.
  (void)vtln_warp;
  (void)log_energy_pre_window;

  return strided_input;
}
} // namespace kaldifeat
#endif // KALDIFEAT_CSRC_FEATURE_COMMON_INL_H_

View File

@ -0,0 +1,61 @@
// kaldifeat/csrc/feature-common.h
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
// This file is copied/modified from kaldi/src/feat/feature-common.h
#ifndef KALDIFEAT_CSRC_FEATURE_COMMON_H_
#define KALDIFEAT_CSRC_FEATURE_COMMON_H_
#include "kaldifeat/csrc/feature-window.h"
namespace kaldifeat {
/**
 * Generic offline feature extractor.
 *
 * `F` is the feature computer type. This template uses `F::Options`,
 * construction of `F` from an options object, and the member functions
 * `GetFrameOptions()`, `Dim()` and `NeedRawLogEnergy()`.
 */
template <class F>
class OfflineFeatureTpl {
 public:
  using Options = typename F::Options;

  // Note: feature_window_function_ is the windowing function, which is
  // initialized from the frame options of the computer and cached at this
  // level. computer_ is declared before feature_window_function_, so it is
  // already constructed when its frame options are read here.
  //
  // Left non-explicit on purpose so that existing implicit conversions from
  // Options keep compiling.
  OfflineFeatureTpl(const Options &opts)  // NOLINT
      : computer_(opts),
        feature_window_function_(computer_.GetFrameOptions()) {}

  // Member-wise copy of the computer and of the cached window function.
  OfflineFeatureTpl(const OfflineFeatureTpl<F> &other) = default;

  // Assignment is disallowed.
  OfflineFeatureTpl<F> &operator=(const OfflineFeatureTpl<F> &other) = delete;

  /**
     Computes the features for one waveform (one sequence of features).

     @param [in] wave        The input waveform (a 1-D tensor).
     @param [in] vtln_warp   The VTLN warping factor (will normally
                             be 1.0)
     @return The result tensor, where the row-index is the frame index.
  */
  torch::Tensor ComputeFeatures(const torch::Tensor &wave, float vtln_warp);

  // Dimension of the feature, as reported by the computer.
  int32_t Dim() const { return computer_.Dim(); }

 private:
  F computer_;
  FeatureWindowFunction feature_window_function_;
};
} // namespace kaldifeat
#include "kaldifeat/csrc/feature-common-inl.h"
#endif // KALDIFEAT_CSRC_FEATURE_COMMON_H_

View File

@ -64,9 +64,8 @@ torch::Tensor FbankComputer::Compute(torch::Tensor signal_raw_log_energy,
// Compute energy after window function (not the raw one).
if (opts_.use_energy && !opts_.raw_energy) {
// signal_raw_log_energy = torch::max(signal_frame.pow(2).sum(1),
// kEps).log();
signal_raw_log_energy = signal_frame.pow(2).sum(1).log();
signal_raw_log_energy =
torch::clamp_min(signal_frame.pow(2).sum(1), kEps).log();
}
// note spectrum is in magnitude, not power, because of `abs()`
@ -83,8 +82,7 @@ torch::Tensor FbankComputer::Compute(torch::Tensor signal_raw_log_energy,
torch::Tensor mel_energies = mel_banks.Compute(spectrum);
if (opts_.use_log_fbank) {
// Avoid log of zero (which should be prevented anyway by dithering).
// mel_energies = torch::max(mel_energies, kEps).log()
mel_energies = mel_energies.log();
mel_energies = torch::clamp_min(mel_energies, kEps).log();
}
// Copy energy as first value (or the last, if htk_compat == true).

View File

@ -9,6 +9,7 @@
#include <map>
#include "kaldifeat/csrc/feature-common.h"
#include "kaldifeat/csrc/feature-window.h"
#include "kaldifeat/csrc/mel-computations.h"
@ -37,6 +38,8 @@ struct FbankOptions {
class FbankComputer {
public:
using Options = FbankOptions;
explicit FbankComputer(const FbankOptions &opts);
~FbankComputer();
@ -65,6 +68,8 @@ class FbankComputer {
std::map<float, MelBanks *> mel_banks_; // float is VTLN coefficient.
};
using Fbank = OfflineFeatureTpl<FbankComputer>;
} // namespace kaldifeat
#endif // KALDIFEAT_CSRC_FEATURE_FBANK_H_

View File

@ -47,6 +47,94 @@ FeatureWindowFunction::FeatureWindowFunction(
KALDIFEAT_ERR << "Invalid window type " << opts.window_type;
}
}
window = window.unsqueeze(0);
}
// Multiplies `input` element-wise by the cached window.
//
// `input` has to be 2-D and its second dimension has to match the second
// dimension of the window (both are asserted).
torch::Tensor FeatureWindowFunction::Apply(const torch::Tensor &input) const {
  KALDIFEAT_ASSERT(input.dim() == 2);
  KALDIFEAT_ASSERT(input.sizes()[1] == window.sizes()[1]);

  torch::Tensor windowed = input.mul(window);
  return windowed;
}
// Returns the index of the first sample of frame number `frame`.
//
// With snip_edges the frame starts at a multiple of the window shift.
// Otherwise the frame is centred on the midpoint of its shift interval, so
// the returned index is that midpoint minus half of the window size.
static int64_t FirstSampleOfFrame(int32_t frame,
                                  const FrameExtractionOptions &opts) {
  const int64_t shift = opts.WindowShift();

  if (opts.snip_edges) return frame * shift;

  const int64_t midpoint = shift * frame + shift / 2;
  const int64_t first_sample = midpoint - opts.WindowSize() / 2;
  return first_sample;
}
// Returns how many frames can be extracted from `num_samples` samples.
//
// snip_edges == true:  every frame has to fit completely into the waveform
//                      and the first frame begins at sample zero.
// snip_edges == false: the count is (num_samples / shift) rounded to the
//                      nearest integer; if `flush` is false, frames whose
//                      end lies past the last sample are dropped.
int32_t NumFrames(int64_t num_samples, const FrameExtractionOptions &opts,
                  bool flush /*= true*/) {
  const int64_t shift = opts.WindowShift();
  const int64_t length = opts.WindowSize();

  if (opts.snip_edges) {
    if (num_samples < length) return 0;

    // 'num_samples - length' is the room available for shifting the frame
    // inside the waveform and 'shift' is the step; integer division rounds
    // the number of possible shifts down.
    return 1 + (num_samples - length) / shift;
  }

  // Integer division rounds toward zero, so half of the shift is added
  // first to round to the closest integer instead.
  int32_t count = (num_samples + shift / 2) / shift;

  if (flush) return count;

  // 'end' means one past the last sample of the frame.
  int64_t end_of_last = FirstSampleOfFrame(count - 1, opts) + length;

  // Without flushing, frames that extend past the end of the signal cannot
  // be output yet. Written for clarity rather than speed.
  for (; count > 0 && end_of_last > num_samples; --count) {
    end_of_last -= shift;
  }

  return count;
}
// Returns the frames of `wave` as a 2-D as_strided() result of shape
// [num_frames, opts.WindowSize()]. Consecutive rows start WindowShift()
// samples apart.
//
// `wave` has to be 1-D and only snip_edges == true is supported for now
// (both are asserted).
torch::Tensor GetStrided(const torch::Tensor &wave,
                         const FrameExtractionOptions &opts) {
  KALDIFEAT_ASSERT(wave.dim() == 1);
  KALDIFEAT_ASSERT(opts.snip_edges == true);  // FIXME(fangjun)

  const int64_t sample_stride = wave.strides()[0];
  const int64_t sample_count = wave.sizes()[0];
  const int32_t frame_count = NumFrames(sample_count, opts);

  const std::vector<int64_t> shape = {frame_count, opts.WindowSize()};
  const std::vector<int64_t> steps = {opts.WindowShift() * sample_stride,
                                      sample_stride};

  return wave.as_strided(shape, steps);
}
// Adds Gaussian noise, scaled by `dither_value`, to `wave`.
//
// @param wave          The tensor to dither.
// @param dither_value  Scale applied to the random noise.
// @return `wave + randn_like(wave) * dither_value`. When dither_value is
//         exactly 0 the input tensor itself is returned and no random
//         numbers are drawn.
torch::Tensor Dither(const torch::Tensor &wave, float dither_value) {
  // The original statement here was `wave;` without `return`, which made the
  // early exit a no-op.
  if (dither_value == 0.0f) return wave;

  torch::Tensor rand_gauss = torch::randn_like(wave);
  return wave + rand_gauss * dither_value;
}
} // namespace kaldifeat

View File

@ -55,12 +55,41 @@ struct FrameExtractionOptions {
}
};
struct FeatureWindowFunction {
// Holds a window tensor and applies it to 2-D input via Apply().
class FeatureWindowFunction {
public:
// Default construction leaves `window` as a default-constructed tensor.
FeatureWindowFunction() = default;
// Builds the window from the given frame extraction options.
explicit FeatureWindowFunction(const FrameExtractionOptions &opts);
// Returns `input` multiplied by the window. `input` must be 2-D and its
// second dimension must match that of the window (asserted in the
// definition).
torch::Tensor Apply(const torch::Tensor &input) const;
private:
// The cached window; the constructor stores it with a leading dimension of
// size 1.
torch::Tensor window;
};
/**
This function returns the number of frames that we can extract from a wave
file with the given number of samples in it (assumed to have the same
sampling rate as specified in 'opts').
@param [in] num_samples The number of samples in the wave file.
@param [in] opts The frame-extraction options class
@param [in] flush True if we are asserting that this number of samples
is 'all there is', false if we are expecting more data to possibly come in.
This only makes a difference to the answer if opts.snip_edges
== false. For offline feature extraction you always want flush ==
true. In an online-decoding context, once you know (or decide)
that no more data is coming in, you'd call it with flush == true at the end
to flush out any remaining data.
@return The number of frames.
*/
int32_t NumFrames(int64_t num_samples, const FrameExtractionOptions &opts,
bool flush = true);
// Returns a 2-D tensor with shape [num_frames, opts.WindowSize()], where
// num_frames is NumFrames(number of samples in wave, opts).
//
// `wave` must be a 1-D tensor. The result is produced with as_strided() on
// `wave`. Only opts.snip_edges == true is currently supported (asserted in
// the definition).
torch::Tensor GetStrided(const torch::Tensor &wave,
const FrameExtractionOptions &opts);
// Returns `wave` plus standard Gaussian noise (torch::randn_like) scaled by
// `dither_value`.
torch::Tensor Dither(const torch::Tensor &wave, float dither_value);
} // namespace kaldifeat
#endif // KALDIFEAT_CSRC_FEATURE_WINDOW_H_

View File

@ -0,0 +1,37 @@
// kaldifeat/csrc/test_kaldifeat.cc
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#include "torch/torch.h"
static void TestPreemph() {
torch::Tensor a = torch::arange(0, 12).reshape({3, 4}).to(torch::kFloat);
torch::Tensor b =
a.index({"...", torch::indexing::Slice(1, torch::indexing::None,
torch::indexing::None)});
torch::Tensor c = a.index({"...", torch::indexing::Slice(0, -1, 1)});
a.index({"...", torch::indexing::Slice(1, torch::indexing::None,
torch::indexing::None)}) =
b - 0.97 * c;
a.index({"...", 0}) *= 0.97;
std::cout << a << "\n";
std::cout << b << "\n";
std::cout << "c: \n" << c << "\n";
torch::Tensor d = b - 0.97 * c;
std::cout << d << "\n";
}
int main() {
torch::Tensor a = torch::arange(0, 6).reshape({2, 3}).to(torch::kFloat);
torch::Tensor b = torch::arange(1, 4).to(torch::kFloat).unsqueeze(0);
std::cout << a << "\n";
std::cout << b << "\n";
std::cout << a * b << "\n";
return 0;
}