From 9bd6ee0c5fac915a551df8099d20c0dd450da040 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 25 Feb 2021 20:15:15 +0800 Subject: [PATCH] Implement ExtractWindow. --- kaldifeat/csrc/CMakeLists.txt | 3 + kaldifeat/csrc/feature-common-inl.h | 87 ++++++++++++++++++++++++++++ kaldifeat/csrc/feature-common.h | 61 ++++++++++++++++++++ kaldifeat/csrc/feature-fbank.cc | 8 +-- kaldifeat/csrc/feature-fbank.h | 5 ++ kaldifeat/csrc/feature-window.cc | 88 +++++++++++++++++++++++++++++ kaldifeat/csrc/feature-window.h | 31 +++++++++- kaldifeat/csrc/test_kaldifeat.cc | 37 ++++++++++++ 8 files changed, 314 insertions(+), 6 deletions(-) create mode 100644 kaldifeat/csrc/feature-common-inl.h create mode 100644 kaldifeat/csrc/feature-common.h create mode 100644 kaldifeat/csrc/test_kaldifeat.cc diff --git a/kaldifeat/csrc/CMakeLists.txt b/kaldifeat/csrc/CMakeLists.txt index aa0d41b..2f7c8bf 100644 --- a/kaldifeat/csrc/CMakeLists.txt +++ b/kaldifeat/csrc/CMakeLists.txt @@ -6,3 +6,6 @@ set(kaldifeat_srcs add_library(kaldifeat_core SHARED ${kaldifeat_srcs}) target_link_libraries(kaldifeat_core PUBLIC ${TORCH_LIBRARIES}) + +add_executable(test_kaldifeat test_kaldifeat.cc) +target_link_libraries(test_kaldifeat PRIVATE kaldifeat_core) diff --git a/kaldifeat/csrc/feature-common-inl.h b/kaldifeat/csrc/feature-common-inl.h new file mode 100644 index 0000000..5416cf3 --- /dev/null +++ b/kaldifeat/csrc/feature-common-inl.h @@ -0,0 +1,87 @@ +// kaldifeat/csrc/feature-common-inl.h +// +// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +// This file is copied/modified from kaldi/src/feat/feature-common-inl.h + +#ifndef KALDIFEAT_CSRC_FEATURE_COMMON_INL_H_ +#define KALDIFEAT_CSRC_FEATURE_COMMON_INL_H_ + +#include "kaldifeat/csrc/feature-window.h" + +namespace kaldifeat { + +template +torch::Tensor OfflineFeatureTpl::ComputeFeatures(const torch::Tensor &wave, + float vtln_warp) { + KALDIFEAT_ASSERT(wave.dim() == 1); + int32_t rows_out = NumFrames(wave.sizes()[0], computer_.GetFrameOptions()); + int32_t cols_out = computer_.Dim(); + + const FrameExtractionOptions &frame_opts = + computer_.GetFrameOptions().frame_opts; + + torch::Tensor strided_input = GetStrided(wave, frame_opts); + + if (frame_opts.dither != 0) + strided_input = Dither(strided_input, frame_opts.dither); + + if (frame_opts.remove_dc_offset) { + torch::Tensor row_means = strided_input.mean(1).unsqueeze(1); + strided_input -= row_means; + } + + bool use_raw_log_energy = computer_.NeedRawLogEnergy(); + torch::Tensor log_energy_pre_window; + + // torch.finfo(torch.float32).eps + constexpr float kEps = 1.1920928955078125e-07f; + + if (use_raw_log_energy) { + log_energy_pre_window = + torch::clamp_min(strided_input.pow(2).sum(1), kEps).log(); + } + + if (frame_opts.preemph_coeff != 0.0f) { + KALDIFEAT_ASSERT(frame_opts.preemph_coeff >= 0.0f && + frame_opts.preemph_coeff <= 1.0f); + + // right = strided_input[:, 1:] + torch::Tensor right = strided_input.index( + {"...", torch::indexing::Slice(1, torch::indexing::None, + torch::indexing::None)}); + + // current = strided_input[:, 0:-1] + torch::Tensor current = + strided_input.index({"...", torch::indexing::Slice(0, -1, 1)}); + + // strided_input[1:, :] = + // strided_input[:, 1:] - preemph_coeff * strided_input[:, 0:-1] + strided_input.index( + {"...", torch::indexing::Slice(1, torch::indexing::None, + torch::indexing::None)}) = + right - frame_opts.preemph_coeff * current; + + strided_input.index({"...", 0}) *= frame_opts.preemph_coeff; + } + + strided_input = feature_window_function_.Apply(strided_input); + +#if 0 + Vector window; // windowed waveform. + bool use_raw_log_energy = computer_.NeedRawLogEnergy(); + for (int32 r = 0; r < rows_out; r++) { // r is frame index. + BaseFloat raw_log_energy = 0.0; + ExtractWindow(0, wave, r, computer_.GetFrameOptions(), + feature_window_function_, &window, + (use_raw_log_energy ? &raw_log_energy : NULL)); + + SubVector output_row(*output, r); + computer_.Compute(raw_log_energy, vtln_warp, &window, &output_row); + } +#endif +} + +} // namespace kaldifeat + +#endif // KALDIFEAT_CSRC_FEATURE_COMMON_INL_H_ diff --git a/kaldifeat/csrc/feature-common.h b/kaldifeat/csrc/feature-common.h new file mode 100644 index 0000000..485ea31 --- /dev/null +++ b/kaldifeat/csrc/feature-common.h @@ -0,0 +1,61 @@ +// kaldifeat/csrc/feature-common.h +// +// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +// This file is copied/modified from kaldi/src/feat/feature-common.h + +#ifndef KALDIFEAT_CSRC_FEATURE_COMMON_H_ +#define KALDIFEAT_CSRC_FEATURE_COMMON_H_ + +#include "kaldifeat/csrc/feature-window.h" + +namespace kaldifeat { + +template +class OfflineFeatureTpl { + public: + using Options = typename F::Options; + + // Note: feature_window_function_ is the windowing function, which initialized + // using the options class, that we cache at this level. + OfflineFeatureTpl(const Options &opts) + : computer_(opts), + feature_window_function_(computer_.GetFrameOptions()) {} + + /** + Computes the features for one file (one sequence of features). + This is the newer interface where you specify the sample frequency + of the input waveform. + @param [in] wave The input waveform + @param [in] sample_freq The sampling frequency with which + 'wave' was sampled. + if sample_freq is higher than the frequency + specified in the config, we will downsample + the waveform, but if lower, it's an error. + @param [in] vtln_warp The VTLN warping factor (will normally + be 1.0) + @param [out] output The matrix of features, where the row-index + is the frame index. + */ + torch::Tensor ComputeFeatures(const torch::Tensor &wave, float vtln_warp); + + int32_t Dim() const { return computer_.Dim(); } + + // Copy constructor. + OfflineFeatureTpl(const OfflineFeatureTpl &other) + : computer_(other.computer_), + feature_window_function_(other.feature_window_function_) {} + + private: + // Disallow assignment. + OfflineFeatureTpl &operator=(const OfflineFeatureTpl &other); + + F computer_; + FeatureWindowFunction feature_window_function_; +}; + +} // namespace kaldifeat + +#include "kaldifeat/csrc/feature-common-inl.h" + +#endif // KALDIFEAT_CSRC_FEATURE_COMMON_H_ diff --git a/kaldifeat/csrc/feature-fbank.cc b/kaldifeat/csrc/feature-fbank.cc index ba95330..5b1c623 100644 --- a/kaldifeat/csrc/feature-fbank.cc +++ b/kaldifeat/csrc/feature-fbank.cc @@ -64,9 +64,8 @@ torch::Tensor FbankComputer::Compute(torch::Tensor signal_raw_log_energy, // Compute energy after window function (not the raw one). if (opts_.use_energy && !opts_.raw_energy) { - // signal_raw_log_energy = torch::max(signal_frame.pow(2).sum(1), - // kEps).log(); - signal_raw_log_energy = signal_frame.pow(2).sum(1).log(); + signal_raw_log_energy = + torch::clamp_min(signal_frame.pow(2).sum(1), kEps).log(); } // note spectrum is in magnitude, not power, because of `abs()` @@ -83,8 +82,7 @@ torch::Tensor FbankComputer::Compute(torch::Tensor signal_raw_log_energy, torch::Tensor mel_energies = mel_banks.Compute(spectrum); if (opts_.use_log_fbank) { // Avoid log of zero (which should be prevented anyway by dithering). - // mel_energies = torch::max(mel_energies, kEps).log() - mel_energies = mel_energies.log(); + mel_energies = torch::clamp_min(mel_energies, kEps).log(); } // Copy energy as first value (or the last, if htk_compat == true). diff --git a/kaldifeat/csrc/feature-fbank.h b/kaldifeat/csrc/feature-fbank.h index 83a8165..b1da6bb 100644 --- a/kaldifeat/csrc/feature-fbank.h +++ b/kaldifeat/csrc/feature-fbank.h @@ -9,6 +9,7 @@ #include +#include "kaldifeat/csrc/feature-common.h" #include "kaldifeat/csrc/feature-window.h" #include "kaldifeat/csrc/mel-computations.h" @@ -37,6 +38,8 @@ struct FbankOptions { class FbankComputer { public: + using Options = FbankOptions; + explicit FbankComputer(const FbankOptions &opts); ~FbankComputer(); @@ -65,6 +68,8 @@ class FbankComputer { std::map mel_banks_; // float is VTLN coefficient. }; +using Fbank = OfflineFeatureTpl; + } // namespace kaldifeat #endif // KALDIFEAT_CSRC_FEATURE_FBANK_H_ diff --git a/kaldifeat/csrc/feature-window.cc b/kaldifeat/csrc/feature-window.cc index a9afcfb..00ace47 100644 --- a/kaldifeat/csrc/feature-window.cc +++ b/kaldifeat/csrc/feature-window.cc @@ -47,6 +47,94 @@ FeatureWindowFunction::FeatureWindowFunction( KALDIFEAT_ERR << "Invalid window type " << opts.window_type; } } + + window = window.unsqueeze(0); +} + +torch::Tensor FeatureWindowFunction::Apply(const torch::Tensor &input) const { + KALDIFEAT_ASSERT(input.dim() == 2); + KALDIFEAT_ASSERT(input.sizes()[1] == window.sizes()[1]); + return input * window; +} + +static int64_t FirstSampleOfFrame(int32_t frame, + const FrameExtractionOptions &opts) { + int64_t frame_shift = opts.WindowShift(); + if (opts.snip_edges) { + return frame * frame_shift; + } else { + int64_t midpoint_of_frame = frame_shift * frame + frame_shift / 2, + beginning_of_frame = midpoint_of_frame - opts.WindowSize() / 2; + return beginning_of_frame; + } +} + +int32_t NumFrames(int64_t num_samples, const FrameExtractionOptions &opts, + bool flush /*= true*/) { + int64_t frame_shift = opts.WindowShift(); + int64_t frame_length = opts.WindowSize(); + if (opts.snip_edges) { + // with --snip-edges=true (the default), we use a HTK-like approach to + // determining the number of frames-- all frames have to fit completely into + // the waveform, and the first frame begins at sample zero. + if (num_samples < frame_length) + return 0; + else + return (1 + ((num_samples - frame_length) / frame_shift)); + // You can understand the expression above as follows: 'num_samples - + // frame_length' is how much room we have to shift the frame within the + // waveform; 'frame_shift' is how much we shift it each time; and the ratio + // is how many times we can shift it (integer arithmetic rounds down). + } else { + // if --snip-edges=false, the number of frames is determined by rounding the + // (file-length / frame-shift) to the nearest integer. The point of this + // formula is to make the number of frames an obvious and predictable + // function of the frame shift and signal length, which makes many + // segmentation-related questions simpler. + // + // Because integer division in C++ rounds toward zero, we add (half the + // frame-shift minus epsilon) before dividing, to have the effect of + // rounding towards the closest integer. + int32_t num_frames = (num_samples + (frame_shift / 2)) / frame_shift; + + if (flush) return num_frames; + + // note: 'end' always means the last plus one, i.e. one past the last. + int64_t end_sample_of_last_frame = + FirstSampleOfFrame(num_frames - 1, opts) + frame_length; + + // the following code is optimized more for clarity than efficiency. + // If flush == false, we can't output frames that extend past the end + // of the signal. + while (num_frames > 0 && end_sample_of_last_frame > num_samples) { + num_frames--; + end_sample_of_last_frame -= frame_shift; + } + return num_frames; + } +} + +torch::Tensor GetStrided(const torch::Tensor &wave, + const FrameExtractionOptions &opts) { + KALDIFEAT_ASSERT(wave.dim() == 1); + + std::vector strides = {opts.WindowShift() * wave.strides()[0], + wave.strides()[0]}; + + KALDIFEAT_ASSERT(opts.snip_edges == true); // FIXME(fangjun) + + int64_t num_samples = wave.sizes()[0]; + int32_t num_frames = NumFrames(num_samples, opts); + std::vector sizes = {num_frames, opts.WindowSize()}; + + return wave.as_strided(sizes, strides); +} + +torch::Tensor Dither(const torch::Tensor &wave, float dither_value) { + if (dither_value == 0.0f) wave; + + torch::Tensor rand_gauss = torch::randn_like(wave); + return wave + rand_gauss * dither_value; } } // namespace kaldifeat diff --git a/kaldifeat/csrc/feature-window.h b/kaldifeat/csrc/feature-window.h index d8229c5..586b43c 100644 --- a/kaldifeat/csrc/feature-window.h +++ b/kaldifeat/csrc/feature-window.h @@ -55,12 +55,41 @@ struct FrameExtractionOptions { } }; -struct FeatureWindowFunction { +class FeatureWindowFunction { + public: FeatureWindowFunction() = default; explicit FeatureWindowFunction(const FrameExtractionOptions &opts); + torch::Tensor Apply(const torch::Tensor &input) const; + + private: torch::Tensor window; }; +/** + This function returns the number of frames that we can extract from a wave + file with the given number of samples in it (assumed to have the same + sampling rate as specified in 'opts'). + + @param [in] num_samples The number of samples in the wave file. + @param [in] opts The frame-extraction options class + + @param [in] flush True if we are asserting that this number of samples + is 'all there is', false if we expecting more data to possibly come in. This + only makes a difference to the answer if opts.snips_edges + == false. For offline feature extraction you always want flush == + true. In an online-decoding context, once you know (or decide) + that no more data is coming in, you'd call it with flush == true at the end + to flush out any remaining data. +*/ +int32_t NumFrames(int64_t num_samples, const FrameExtractionOptions &opts, + bool flush = true); + +// return a 2-d tensor with shape [num_frames, opts.WindowSize()] +torch::Tensor GetStrided(const torch::Tensor &wave, + const FrameExtractionOptions &opts); + +torch::Tensor Dither(const torch::Tensor &wave, float dither_value); + } // namespace kaldifeat #endif // KALDIFEAT_CSRC_FEATURE_WINDOW_H_ diff --git a/kaldifeat/csrc/test_kaldifeat.cc b/kaldifeat/csrc/test_kaldifeat.cc new file mode 100644 index 0000000..a9bc12c --- /dev/null +++ b/kaldifeat/csrc/test_kaldifeat.cc @@ -0,0 +1,37 @@ +// kaldifeat/csrc/test_kaldifeat.cc +// +// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +#include "torch/torch.h" + +static void TestPreemph() { + torch::Tensor a = torch::arange(0, 12).reshape({3, 4}).to(torch::kFloat); + + torch::Tensor b = + a.index({"...", torch::indexing::Slice(1, torch::indexing::None, + torch::indexing::None)}); + + torch::Tensor c = a.index({"...", torch::indexing::Slice(0, -1, 1)}); + + a.index({"...", torch::indexing::Slice(1, torch::indexing::None, + torch::indexing::None)}) = + b - 0.97 * c; + + a.index({"...", 0}) *= 0.97; + + std::cout << a << "\n"; + std::cout << b << "\n"; + std::cout << "c: \n" << c << "\n"; + torch::Tensor d = b - 0.97 * c; + std::cout << d << "\n"; +} + +int main() { + torch::Tensor a = torch::arange(0, 6).reshape({2, 3}).to(torch::kFloat); + torch::Tensor b = torch::arange(1, 4).to(torch::kFloat).unsqueeze(0); + std::cout << a << "\n"; + std::cout << b << "\n"; + std::cout << a * b << "\n"; + + return 0; +}