mirror of
https://github.com/csukuangfj/kaldifeat.git
synced 2025-08-10 02:22:16 +00:00
Implement ExtractWindow.
This commit is contained in:
parent
753f47e89d
commit
9bd6ee0c5f
@ -6,3 +6,6 @@ set(kaldifeat_srcs
|
||||
|
||||
add_library(kaldifeat_core SHARED ${kaldifeat_srcs})
|
||||
target_link_libraries(kaldifeat_core PUBLIC ${TORCH_LIBRARIES})
|
||||
|
||||
add_executable(test_kaldifeat test_kaldifeat.cc)
|
||||
target_link_libraries(test_kaldifeat PRIVATE kaldifeat_core)
|
||||
|
87
kaldifeat/csrc/feature-common-inl.h
Normal file
87
kaldifeat/csrc/feature-common-inl.h
Normal file
@ -0,0 +1,87 @@
|
||||
// kaldifeat/csrc/feature-common-inl.h
|
||||
//
|
||||
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
|
||||
// This file is copied/modified from kaldi/src/feat/feature-common-inl.h
|
||||
|
||||
#ifndef KALDIFEAT_CSRC_FEATURE_COMMON_INL_H_
|
||||
#define KALDIFEAT_CSRC_FEATURE_COMMON_INL_H_
|
||||
|
||||
#include "kaldifeat/csrc/feature-window.h"
|
||||
|
||||
namespace kaldifeat {
|
||||
|
||||
template <class F>
|
||||
torch::Tensor OfflineFeatureTpl<F>::ComputeFeatures(const torch::Tensor &wave,
|
||||
float vtln_warp) {
|
||||
KALDIFEAT_ASSERT(wave.dim() == 1);
|
||||
int32_t rows_out = NumFrames(wave.sizes()[0], computer_.GetFrameOptions());
|
||||
int32_t cols_out = computer_.Dim();
|
||||
|
||||
const FrameExtractionOptions &frame_opts =
|
||||
computer_.GetFrameOptions().frame_opts;
|
||||
|
||||
torch::Tensor strided_input = GetStrided(wave, frame_opts);
|
||||
|
||||
if (frame_opts.dither != 0)
|
||||
strided_input = Dither(strided_input, frame_opts.dither);
|
||||
|
||||
if (frame_opts.remove_dc_offset) {
|
||||
torch::Tensor row_means = strided_input.mean(1).unsqueeze(1);
|
||||
strided_input -= row_means;
|
||||
}
|
||||
|
||||
bool use_raw_log_energy = computer_.NeedRawLogEnergy();
|
||||
torch::Tensor log_energy_pre_window;
|
||||
|
||||
// torch.finfo(torch.float32).eps
|
||||
constexpr float kEps = 1.1920928955078125e-07f;
|
||||
|
||||
if (use_raw_log_energy) {
|
||||
log_energy_pre_window =
|
||||
torch::clamp_min(strided_input.pow(2).sum(1), kEps).log();
|
||||
}
|
||||
|
||||
if (frame_opts.preemph_coeff != 0.0f) {
|
||||
KALDIFEAT_ASSERT(frame_opts.preemph_coeff >= 0.0f &&
|
||||
frame_opts.preemph_coeff <= 1.0f);
|
||||
|
||||
// right = strided_input[:, 1:]
|
||||
torch::Tensor right = strided_input.index(
|
||||
{"...", torch::indexing::Slice(1, torch::indexing::None,
|
||||
torch::indexing::None)});
|
||||
|
||||
// current = strided_input[:, 0:-1]
|
||||
torch::Tensor current =
|
||||
strided_input.index({"...", torch::indexing::Slice(0, -1, 1)});
|
||||
|
||||
// strided_input[1:, :] =
|
||||
// strided_input[:, 1:] - preemph_coeff * strided_input[:, 0:-1]
|
||||
strided_input.index(
|
||||
{"...", torch::indexing::Slice(1, torch::indexing::None,
|
||||
torch::indexing::None)}) =
|
||||
right - frame_opts.preemph_coeff * current;
|
||||
|
||||
strided_input.index({"...", 0}) *= frame_opts.preemph_coeff;
|
||||
}
|
||||
|
||||
strided_input = feature_window_function_.Apply(strided_input);
|
||||
|
||||
#if 0
|
||||
Vector<BaseFloat> window; // windowed waveform.
|
||||
bool use_raw_log_energy = computer_.NeedRawLogEnergy();
|
||||
for (int32 r = 0; r < rows_out; r++) { // r is frame index.
|
||||
BaseFloat raw_log_energy = 0.0;
|
||||
ExtractWindow(0, wave, r, computer_.GetFrameOptions(),
|
||||
feature_window_function_, &window,
|
||||
(use_raw_log_energy ? &raw_log_energy : NULL));
|
||||
|
||||
SubVector<BaseFloat> output_row(*output, r);
|
||||
computer_.Compute(raw_log_energy, vtln_warp, &window, &output_row);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace kaldifeat
|
||||
|
||||
#endif // KALDIFEAT_CSRC_FEATURE_COMMON_INL_H_
|
61
kaldifeat/csrc/feature-common.h
Normal file
61
kaldifeat/csrc/feature-common.h
Normal file
@ -0,0 +1,61 @@
|
||||
// kaldifeat/csrc/feature-common.h
|
||||
//
|
||||
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
|
||||
// This file is copied/modified from kaldi/src/feat/feature-common.h
|
||||
|
||||
#ifndef KALDIFEAT_CSRC_FEATURE_COMMON_H_
|
||||
#define KALDIFEAT_CSRC_FEATURE_COMMON_H_
|
||||
|
||||
#include "kaldifeat/csrc/feature-window.h"
|
||||
|
||||
namespace kaldifeat {
|
||||
|
||||
template <class F>
|
||||
class OfflineFeatureTpl {
|
||||
public:
|
||||
using Options = typename F::Options;
|
||||
|
||||
// Note: feature_window_function_ is the windowing function, which initialized
|
||||
// using the options class, that we cache at this level.
|
||||
OfflineFeatureTpl(const Options &opts)
|
||||
: computer_(opts),
|
||||
feature_window_function_(computer_.GetFrameOptions()) {}
|
||||
|
||||
/**
|
||||
Computes the features for one file (one sequence of features).
|
||||
This is the newer interface where you specify the sample frequency
|
||||
of the input waveform.
|
||||
@param [in] wave The input waveform
|
||||
@param [in] sample_freq The sampling frequency with which
|
||||
'wave' was sampled.
|
||||
if sample_freq is higher than the frequency
|
||||
specified in the config, we will downsample
|
||||
the waveform, but if lower, it's an error.
|
||||
@param [in] vtln_warp The VTLN warping factor (will normally
|
||||
be 1.0)
|
||||
@param [out] output The matrix of features, where the row-index
|
||||
is the frame index.
|
||||
*/
|
||||
torch::Tensor ComputeFeatures(const torch::Tensor &wave, float vtln_warp);
|
||||
|
||||
int32_t Dim() const { return computer_.Dim(); }
|
||||
|
||||
// Copy constructor.
|
||||
OfflineFeatureTpl(const OfflineFeatureTpl<F> &other)
|
||||
: computer_(other.computer_),
|
||||
feature_window_function_(other.feature_window_function_) {}
|
||||
|
||||
private:
|
||||
// Disallow assignment.
|
||||
OfflineFeatureTpl<F> &operator=(const OfflineFeatureTpl<F> &other);
|
||||
|
||||
F computer_;
|
||||
FeatureWindowFunction feature_window_function_;
|
||||
};
|
||||
|
||||
} // namespace kaldifeat
|
||||
|
||||
#include "kaldifeat/csrc/feature-common-inl.h"
|
||||
|
||||
#endif // KALDIFEAT_CSRC_FEATURE_COMMON_H_
|
@ -64,9 +64,8 @@ torch::Tensor FbankComputer::Compute(torch::Tensor signal_raw_log_energy,
|
||||
|
||||
// Compute energy after window function (not the raw one).
|
||||
if (opts_.use_energy && !opts_.raw_energy) {
|
||||
// signal_raw_log_energy = torch::max(signal_frame.pow(2).sum(1),
|
||||
// kEps).log();
|
||||
signal_raw_log_energy = signal_frame.pow(2).sum(1).log();
|
||||
signal_raw_log_energy =
|
||||
torch::clamp_min(signal_frame.pow(2).sum(1), kEps).log();
|
||||
}
|
||||
|
||||
// note spectrum is in magnitude, not power, because of `abs()`
|
||||
@ -83,8 +82,7 @@ torch::Tensor FbankComputer::Compute(torch::Tensor signal_raw_log_energy,
|
||||
torch::Tensor mel_energies = mel_banks.Compute(spectrum);
|
||||
if (opts_.use_log_fbank) {
|
||||
// Avoid log of zero (which should be prevented anyway by dithering).
|
||||
// mel_energies = torch::max(mel_energies, kEps).log()
|
||||
mel_energies = mel_energies.log();
|
||||
mel_energies = torch::clamp_min(mel_energies, kEps).log();
|
||||
}
|
||||
|
||||
// Copy energy as first value (or the last, if htk_compat == true).
|
||||
|
@ -9,6 +9,7 @@
|
||||
|
||||
#include <map>
|
||||
|
||||
#include "kaldifeat/csrc/feature-common.h"
|
||||
#include "kaldifeat/csrc/feature-window.h"
|
||||
#include "kaldifeat/csrc/mel-computations.h"
|
||||
|
||||
@ -37,6 +38,8 @@ struct FbankOptions {
|
||||
|
||||
class FbankComputer {
|
||||
public:
|
||||
using Options = FbankOptions;
|
||||
|
||||
explicit FbankComputer(const FbankOptions &opts);
|
||||
~FbankComputer();
|
||||
|
||||
@ -65,6 +68,8 @@ class FbankComputer {
|
||||
std::map<float, MelBanks *> mel_banks_; // float is VTLN coefficient.
|
||||
};
|
||||
|
||||
using Fbank = OfflineFeatureTpl<FbankComputer>;
|
||||
|
||||
} // namespace kaldifeat
|
||||
|
||||
#endif // KALDIFEAT_CSRC_FEATURE_FBANK_H_
|
||||
|
@ -47,6 +47,94 @@ FeatureWindowFunction::FeatureWindowFunction(
|
||||
KALDIFEAT_ERR << "Invalid window type " << opts.window_type;
|
||||
}
|
||||
}
|
||||
|
||||
window = window.unsqueeze(0);
|
||||
}
|
||||
|
||||
// Multiplies `input` element-wise by the cached window and returns the
// product as a new tensor; `input` itself is not modified.
//
// `input` must be 2-D and its second dimension must equal the second
// dimension of the window (both asserted).
torch::Tensor FeatureWindowFunction::Apply(const torch::Tensor &input) const {
  KALDIFEAT_ASSERT(input.dim() == 2);
  KALDIFEAT_ASSERT(input.sizes()[1] == window.sizes()[1]);

  torch::Tensor windowed = input.mul(window);
  return windowed;
}
|
||||
|
||||
static int64_t FirstSampleOfFrame(int32_t frame,
|
||||
const FrameExtractionOptions &opts) {
|
||||
int64_t frame_shift = opts.WindowShift();
|
||||
if (opts.snip_edges) {
|
||||
return frame * frame_shift;
|
||||
} else {
|
||||
int64_t midpoint_of_frame = frame_shift * frame + frame_shift / 2,
|
||||
beginning_of_frame = midpoint_of_frame - opts.WindowSize() / 2;
|
||||
return beginning_of_frame;
|
||||
}
|
||||
}
|
||||
|
||||
int32_t NumFrames(int64_t num_samples, const FrameExtractionOptions &opts,
|
||||
bool flush /*= true*/) {
|
||||
int64_t frame_shift = opts.WindowShift();
|
||||
int64_t frame_length = opts.WindowSize();
|
||||
if (opts.snip_edges) {
|
||||
// with --snip-edges=true (the default), we use a HTK-like approach to
|
||||
// determining the number of frames-- all frames have to fit completely into
|
||||
// the waveform, and the first frame begins at sample zero.
|
||||
if (num_samples < frame_length)
|
||||
return 0;
|
||||
else
|
||||
return (1 + ((num_samples - frame_length) / frame_shift));
|
||||
// You can understand the expression above as follows: 'num_samples -
|
||||
// frame_length' is how much room we have to shift the frame within the
|
||||
// waveform; 'frame_shift' is how much we shift it each time; and the ratio
|
||||
// is how many times we can shift it (integer arithmetic rounds down).
|
||||
} else {
|
||||
// if --snip-edges=false, the number of frames is determined by rounding the
|
||||
// (file-length / frame-shift) to the nearest integer. The point of this
|
||||
// formula is to make the number of frames an obvious and predictable
|
||||
// function of the frame shift and signal length, which makes many
|
||||
// segmentation-related questions simpler.
|
||||
//
|
||||
// Because integer division in C++ rounds toward zero, we add (half the
|
||||
// frame-shift minus epsilon) before dividing, to have the effect of
|
||||
// rounding towards the closest integer.
|
||||
int32_t num_frames = (num_samples + (frame_shift / 2)) / frame_shift;
|
||||
|
||||
if (flush) return num_frames;
|
||||
|
||||
// note: 'end' always means the last plus one, i.e. one past the last.
|
||||
int64_t end_sample_of_last_frame =
|
||||
FirstSampleOfFrame(num_frames - 1, opts) + frame_length;
|
||||
|
||||
// the following code is optimized more for clarity than efficiency.
|
||||
// If flush == false, we can't output frames that extend past the end
|
||||
// of the signal.
|
||||
while (num_frames > 0 && end_sample_of_last_frame > num_samples) {
|
||||
num_frames--;
|
||||
end_sample_of_last_frame -= frame_shift;
|
||||
}
|
||||
return num_frames;
|
||||
}
|
||||
}
|
||||
|
||||
// Splits a 1-D waveform into frames without copying any samples.
//
// The result is an as_strided view of `wave` with sizes
// {NumFrames(...), opts.WindowSize()}: moving one row forward advances by
// opts.WindowShift() samples, and moving one column forward advances by one
// sample. Only opts.snip_edges == true is supported for now (asserted).
torch::Tensor GetStrided(const torch::Tensor &wave,
                         const FrameExtractionOptions &opts) {
  KALDIFEAT_ASSERT(wave.dim() == 1);
  KALDIFEAT_ASSERT(opts.snip_edges == true);  // FIXME(fangjun)

  // Stride of `wave` itself, so that a non-contiguous input is handled too.
  const int64_t sample_stride = wave.strides()[0];
  const int32_t frame_count = NumFrames(wave.sizes()[0], opts);

  const std::vector<int64_t> view_sizes = {frame_count, opts.WindowSize()};
  const std::vector<int64_t> view_strides = {
      opts.WindowShift() * sample_stride, sample_stride};

  return wave.as_strided(view_sizes, view_strides);
}
|
||||
|
||||
// Adds zero-mean Gaussian noise scaled by `dither_value` to `wave`.
//
// @param wave          Input tensor; it is not modified.
// @param dither_value  Scale applied to the standard-normal noise.
// @return `wave` itself when dither_value is 0, otherwise a new tensor
//         holding wave + noise * dither_value.
torch::Tensor Dither(const torch::Tensor &wave, float dither_value) {
  // Nothing to add: skip generating the noise tensor altogether.
  if (dither_value == 0.0f) return wave;

  torch::Tensor rand_gauss = torch::randn_like(wave);
  return wave + rand_gauss * dither_value;
}
|
||||
|
||||
} // namespace kaldifeat
|
||||
|
@ -55,12 +55,41 @@ struct FrameExtractionOptions {
|
||||
}
|
||||
};
|
||||
|
||||
struct FeatureWindowFunction {
|
||||
class FeatureWindowFunction {
|
||||
public:
|
||||
FeatureWindowFunction() = default;
|
||||
explicit FeatureWindowFunction(const FrameExtractionOptions &opts);
|
||||
torch::Tensor Apply(const torch::Tensor &input) const;
|
||||
|
||||
private:
|
||||
torch::Tensor window;
|
||||
};
|
||||
|
||||
/**
|
||||
This function returns the number of frames that we can extract from a wave
|
||||
file with the given number of samples in it (assumed to have the same
|
||||
sampling rate as specified in 'opts').
|
||||
|
||||
@param [in] num_samples The number of samples in the wave file.
|
||||
@param [in] opts The frame-extraction options class
|
||||
|
||||
@param [in] flush True if we are asserting that this number of samples
|
||||
is 'all there is', false if we are expecting more data to possibly come in. This
|
||||
only makes a difference to the answer if opts.snip_edges
|
||||
== false. For offline feature extraction you always want flush ==
|
||||
true. In an online-decoding context, once you know (or decide)
|
||||
that no more data is coming in, you'd call it with flush == true at the end
|
||||
to flush out any remaining data.
|
||||
*/
|
||||
int32_t NumFrames(int64_t num_samples, const FrameExtractionOptions &opts,
|
||||
bool flush = true);
|
||||
|
||||
// return a 2-d tensor with shape [num_frames, opts.WindowSize()]
|
||||
torch::Tensor GetStrided(const torch::Tensor &wave,
|
||||
const FrameExtractionOptions &opts);
|
||||
|
||||
torch::Tensor Dither(const torch::Tensor &wave, float dither_value);
|
||||
|
||||
} // namespace kaldifeat
|
||||
|
||||
#endif // KALDIFEAT_CSRC_FEATURE_WINDOW_H_
|
||||
|
37
kaldifeat/csrc/test_kaldifeat.cc
Normal file
37
kaldifeat/csrc/test_kaldifeat.cc
Normal file
@ -0,0 +1,37 @@
|
||||
// kaldifeat/csrc/test_kaldifeat.cc
|
||||
//
|
||||
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
|
||||
#include "torch/torch.h"
|
||||
|
||||
static void TestPreemph() {
|
||||
torch::Tensor a = torch::arange(0, 12).reshape({3, 4}).to(torch::kFloat);
|
||||
|
||||
torch::Tensor b =
|
||||
a.index({"...", torch::indexing::Slice(1, torch::indexing::None,
|
||||
torch::indexing::None)});
|
||||
|
||||
torch::Tensor c = a.index({"...", torch::indexing::Slice(0, -1, 1)});
|
||||
|
||||
a.index({"...", torch::indexing::Slice(1, torch::indexing::None,
|
||||
torch::indexing::None)}) =
|
||||
b - 0.97 * c;
|
||||
|
||||
a.index({"...", 0}) *= 0.97;
|
||||
|
||||
std::cout << a << "\n";
|
||||
std::cout << b << "\n";
|
||||
std::cout << "c: \n" << c << "\n";
|
||||
torch::Tensor d = b - 0.97 * c;
|
||||
std::cout << d << "\n";
|
||||
}
|
||||
|
||||
int main() {
|
||||
torch::Tensor a = torch::arange(0, 6).reshape({2, 3}).to(torch::kFloat);
|
||||
torch::Tensor b = torch::arange(1, 4).to(torch::kFloat).unsqueeze(0);
|
||||
std::cout << a << "\n";
|
||||
std::cout << b << "\n";
|
||||
std::cout << a * b << "\n";
|
||||
|
||||
return 0;
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user