From 753f47e89daef82b5c9a98cb768b2c83724afa74 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Thu, 25 Feb 2021 17:26:27 +0800
Subject: [PATCH] add fbank computer.

---
Review notes (this revision differs from the commit named above):
 * FbankComputer::Compute(): with use_energy the log-energy is concatenated
   as an extra column (first, or last with htk_compat) instead of overwriting
   mel bin 0, so the output width matches Dim().
 * FbankComputer::Compute(): values are floored at kEps before log(), and
   energy_floor is applied to the log-energy.
 * FbankComputer: log_energy_floor_ has a default member initializer.
 * Line breaks and <...> tokens that were missing from this patch text were
   restored. Indentation and blank lines of context lines are a best guess,
   and the blob ids of the two rewritten new files were dropped.

 kaldifeat/csrc/CMakeLists.txt      |   1 +
 kaldifeat/csrc/feature-fbank.cc    | 111 +++++++++++++++++++++++++++++
 kaldifeat/csrc/feature-fbank.h     |  81 +++++++++++++++++++++
 kaldifeat/csrc/mel-computations.cc |   5 ++
 kaldifeat/csrc/mel-computations.h  |   2 +
 5 files changed, 200 insertions(+)
 create mode 100644 kaldifeat/csrc/feature-fbank.cc
 create mode 100644 kaldifeat/csrc/feature-fbank.h

diff --git a/kaldifeat/csrc/CMakeLists.txt b/kaldifeat/csrc/CMakeLists.txt
index e370851..aa0d41b 100644
--- a/kaldifeat/csrc/CMakeLists.txt
+++ b/kaldifeat/csrc/CMakeLists.txt
@@ -1,4 +1,5 @@
 set(kaldifeat_srcs
+  feature-fbank.cc
   feature-window.cc
   mel-computations.cc
 )
diff --git a/kaldifeat/csrc/feature-fbank.cc b/kaldifeat/csrc/feature-fbank.cc
new file mode 100644
--- /dev/null
+++ b/kaldifeat/csrc/feature-fbank.cc
@@ -0,0 +1,111 @@
+// kaldifeat/csrc/feature-fbank.cc
+//
+// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
+
+// This file is copied/modified from kaldi/src/feat/feature-fbank.cc
+
+#include "kaldifeat/csrc/feature-fbank.h"
+
+#include <cmath>
+
+#include "torch/fft.h"
+#include "torch/torch.h"
+
+namespace kaldifeat {
+
+// Caches log(energy_floor) when a positive floor is configured, then builds
+// the mel filterbank for the unwarped case.
+FbankComputer::FbankComputer(const FbankOptions &opts) : opts_(opts) {
+  if (opts.energy_floor > 0.0f) log_energy_floor_ = logf(opts.energy_floor);
+
+  // We'll definitely need the filterbanks info for VTLN warping factor 1.0.
+  // [note: this call caches it.]
+  GetMelBanks(1.0f);
+}
+
+// Deep copy: the map is copied first (its pointers still alias `other`), then
+// each entry is replaced by a clone owned by this object.
+FbankComputer::FbankComputer(const FbankComputer &other)
+    : opts_(other.opts_),
+      log_energy_floor_(other.log_energy_floor_),
+      mel_banks_(other.mel_banks_) {
+  for (auto &entry : mel_banks_) entry.second = new MelBanks(*entry.second);
+}
+
+// Releases every cached MelBanks.
+FbankComputer::~FbankComputer() {
+  for (auto &entry : mel_banks_) delete entry.second;
+}
+
+// Returns the MelBanks cached for `vtln_warp`, creating and caching it on
+// first use. The returned object stays owned by this FbankComputer.
+const MelBanks *FbankComputer::GetMelBanks(float vtln_warp) {
+  auto iter = mel_banks_.find(vtln_warp);
+  if (iter != mel_banks_.end()) return iter->second;
+
+  MelBanks *this_mel_banks =
+      new MelBanks(opts_.mel_opts, opts_.frame_opts, vtln_warp);
+  mel_banks_[vtln_warp] = this_mel_banks;
+  return this_mel_banks;
+}
+
+// Computes filterbank features for a batch of windowed frames.
+//
+// signal_raw_log_energy: log-energy of each frame; read only when use_energy
+//   is set, and recomputed from signal_frame when raw_energy is false.
+// vtln_warp: VTLN warping factor selecting the (cached) mel filterbank.
+// signal_frame: 2-D, one frame per row, PaddedWindowSize() columns.
+//
+// ans.shape [signal_frame.sizes()[0], this->Dim()]
+torch::Tensor FbankComputer::Compute(torch::Tensor signal_raw_log_energy,
+                                     float vtln_warp,
+                                     const torch::Tensor &signal_frame) {
+  const MelBanks &mel_banks = *(GetMelBanks(vtln_warp));
+
+  KALDIFEAT_ASSERT(signal_frame.dim() == 2);
+
+  KALDIFEAT_ASSERT(signal_frame.sizes()[1] ==
+                   opts_.frame_opts.PaddedWindowSize());
+
+  // torch.finfo(torch.float32).eps
+  constexpr float kEps = 1.1920928955078125e-07f;
+
+  // Compute energy after window function (not the raw one). The energy is
+  // floored at kEps so an all-zero frame gives a finite value, not log(0).
+  if (opts_.use_energy && !opts_.raw_energy) {
+    signal_raw_log_energy =
+        torch::clamp_min(signal_frame.pow(2).sum(1), kEps).log();
+  }
+
+  // note spectrum is in magnitude, not power, because of `abs()`
+  // NOTE(review): rfft returns PaddedWindowSize() / 2 + 1 bins per frame;
+  // confirm that MelBanks::bins_mat_ has the same number of columns.
+  torch::Tensor spectrum = torch::fft::rfft(signal_frame).abs();
+
+  // Use power instead of magnitude if requested.
+  if (opts_.use_power) spectrum = spectrum.pow(2);
+
+  torch::Tensor mel_energies = mel_banks.Compute(spectrum);
+  if (opts_.use_log_fbank) {
+    // Avoid log of zero (which should be prevented anyway by dithering).
+    mel_energies = torch::clamp_min(mel_energies, kEps).log();
+  }
+
+  if (!opts_.use_energy) return mel_energies;
+
+  if (opts_.energy_floor > 0.0f) {
+    signal_raw_log_energy =
+        torch::clamp_min(signal_raw_log_energy, log_energy_floor_);
+  }
+
+  // One log-energy per frame, shaped as a column so it can be concatenated.
+  torch::Tensor energy = signal_raw_log_energy.reshape({-1, 1});
+  KALDIFEAT_ASSERT(energy.sizes()[0] == mel_energies.sizes()[0]);
+
+  // Energy is an extra column (first, or last if htk_compat == true); it must
+  // not overwrite a mel bin, otherwise the width would not match Dim().
+  if (opts_.htk_compat) return torch::cat({mel_energies, energy}, 1);
+  return torch::cat({energy, mel_energies}, 1);
+}
+
+}  // namespace kaldifeat
diff --git a/kaldifeat/csrc/feature-fbank.h b/kaldifeat/csrc/feature-fbank.h
new file mode 100644
--- /dev/null
+++ b/kaldifeat/csrc/feature-fbank.h
@@ -0,0 +1,81 @@
+// kaldifeat/csrc/feature-fbank.h
+//
+// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
+
+// This file is copied/modified from kaldi/src/feat/feature-fbank.h
+
+#ifndef KALDIFEAT_CSRC_FEATURE_FBANK_H_
+#define KALDIFEAT_CSRC_FEATURE_FBANK_H_
+
+#include <map>
+
+#include "kaldifeat/csrc/feature-window.h"
+#include "kaldifeat/csrc/mel-computations.h"
+
+namespace kaldifeat {
+
+// Options controlling filterbank (fbank) feature extraction.
+struct FbankOptions {
+  FrameExtractionOptions frame_opts;
+  MelBanksOptions mel_opts;
+  // append an extra dimension with energy to the filter banks
+  bool use_energy = false;
+  // If positive, the log-energy is floored at log(energy_floor).
+  float energy_floor = 0.0f;
+
+  // If true, compute energy before preemphasis and windowing
+  bool raw_energy = true;
+  // If true, put energy last (if using energy)
+  bool htk_compat = false;
+  // if true (default), produce log-filterbank, else linear
+  bool use_log_fbank = true;
+
+  // if true (default), use power in filterbank
+  // analysis, else magnitude.
+  bool use_power = true;
+
+  FbankOptions() { mel_opts.num_bins = 23; }
+};
+
+// Computes fbank features for batches of frames, caching one MelBanks per
+// VTLN warping factor.
+class FbankComputer {
+ public:
+  explicit FbankComputer(const FbankOptions &opts);
+  ~FbankComputer();
+
+  FbankComputer &operator=(const FbankComputer &) = delete;
+
+  // Deep-copies the cached MelBanks objects.
+  FbankComputer(const FbankComputer &other);
+
+  // Number of columns returned by Compute().
+  int32_t Dim() const {
+    return opts_.mel_opts.num_bins + (opts_.use_energy ? 1 : 0);
+  }
+
+  // True if Compute() reads the caller-supplied signal_raw_log_energy.
+  bool NeedRawLogEnergy() const { return opts_.use_energy && opts_.raw_energy; }
+
+  const FrameExtractionOptions &GetFrameOptions() const {
+    return opts_.frame_opts;
+  }
+
+  // signal_frame is 2-D with one windowed frame per row. Returns a tensor of
+  // shape [signal_frame.sizes()[0], Dim()].
+  torch::Tensor Compute(torch::Tensor signal_raw_log_energy, float vtln_warp,
+                        const torch::Tensor &signal_frame);
+
+ private:
+  // Returns the cached MelBanks for vtln_warp, creating it if needed.
+  const MelBanks *GetMelBanks(float vtln_warp);
+
+  FbankOptions opts_;
+  // log(opts_.energy_floor); only meaningful when opts_.energy_floor > 0.
+  float log_energy_floor_ = 0.0f;
+  std::map<float, MelBanks *> mel_banks_;  // float is VTLN coefficient.
+};
+
+}  // namespace kaldifeat
+
+#endif  // KALDIFEAT_CSRC_FEATURE_FBANK_H_
diff --git a/kaldifeat/csrc/mel-computations.cc b/kaldifeat/csrc/mel-computations.cc
index ccd3be3..0d4ca0b 100644
--- a/kaldifeat/csrc/mel-computations.cc
+++ b/kaldifeat/csrc/mel-computations.cc
@@ -177,4 +177,9 @@ MelBanks::MelBanks(const MelBanksOptions &opts,
   if (debug_) KALDIFEAT_LOG << bins_mat_;
 }
 
+torch::Tensor MelBanks::Compute(const torch::Tensor &spectrum) const {
+  // TODO(fangjun): save a transposed version of `bins_mat_`.
+  return torch::mm(spectrum, bins_mat_.t());
+}
+
 }  // namespace kaldifeat
diff --git a/kaldifeat/csrc/mel-computations.h b/kaldifeat/csrc/mel-computations.h
index 3c26bfd..b64c27c 100644
--- a/kaldifeat/csrc/mel-computations.h
+++ b/kaldifeat/csrc/mel-computations.h
@@ -61,6 +61,8 @@ class MelBanks {
 
   int32_t NumBins() const { return static_cast<int32_t>(bins_mat_.sizes()[0]); }
 
+  torch::Tensor Compute(const torch::Tensor &spectrum) const;
+
  private:
   // A 2-D matrix of shape [num_bins, num_fft_bins]
   torch::Tensor bins_mat_;