mirror of
https://github.com/csukuangfj/kaldifeat.git
synced 2025-08-09 01:52:39 +00:00
145 lines
4.9 KiB
C++
145 lines
4.9 KiB
C++
// kaldifeat/csrc/mel-computations.h
|
|
//
|
|
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
|
//
|
|
// This file is copied/modified from kaldi/src/feat/mel-computations.h
|
|
|
|
#include <cmath>
#include <cstdint>
#include <sstream>
#include <string>

#include "kaldifeat/csrc/feature-window.h"
|
|
|
|
#ifndef KALDIFEAT_CSRC_MEL_COMPUTATIONS_H_
|
|
#define KALDIFEAT_CSRC_MEL_COMPUTATIONS_H_
|
|
|
|
namespace kaldifeat {
|
|
|
|
// Configuration of the mel filter bank.
struct MelBanksOptions {
  // Number of triangular bins, e.g. 25.
  int32_t num_bins = 25;

  // Lower frequency cutoff, e.g. 20.
  float low_freq = 20.0f;

  // Upper frequency cutoff. 0 means no cutoff; a negative value is added to
  // the Nyquist frequency to obtain the cutoff.
  float high_freq = 0.0f;

  // Lower cutoff of the VTLN warping function.
  float vtln_low = 100.0f;

  // Upper cutoff of the VTLN warping function. A negative value is added to
  // the Nyquist frequency to obtain the cutoff.
  float vtln_high = -500.0f;

  bool debug_mel = false;

  // htk_mode is a "hidden" config, it does not show up on command line.
  // Enables more exact compatibility with HTK, for testing purposes. Affects
  // mel-energy flooring and reproduces a bug in HTK.
  bool htk_mode = false;

  // Returns a one-line, human-readable description of all fields, formatted
  // as "MelBanksOptions(name=value, ...)". Booleans are rendered as
  // "True"/"False".
  std::string ToString() const {
    const auto as_text = [](bool flag) { return flag ? "True" : "False"; };

    std::ostringstream out;
    out << "MelBanksOptions("
        << "num_bins=" << num_bins << ", "
        << "low_freq=" << low_freq << ", "
        << "high_freq=" << high_freq << ", "
        << "vtln_low=" << vtln_low << ", "
        << "vtln_high=" << vtln_high << ", "
        << "debug_mel=" << as_text(debug_mel) << ", "
        << "htk_mode=" << as_text(htk_mode) << ")";
    return out.str();
  }
};
|
|
|
|
// Stream-insertion operator for MelBanksOptions. Only declared here; the
// definition lives outside this header.
// NOTE(review): presumably it writes opts.ToString() to `os` -- confirm
// against the definition.
std::ostream &operator<<(std::ostream &os, const MelBanksOptions &opts);
|
|
|
|
class MelBanks {
|
|
public:
|
|
static inline float InverseMelScale(float mel_freq) {
|
|
return 700.0f * (expf(mel_freq / 1127.0f) - 1.0f);
|
|
}
|
|
|
|
static inline float MelScale(float freq) {
|
|
return 1127.0f * logf(1.0f + freq / 700.0f);
|
|
}
|
|
|
|
static float VtlnWarpFreq(
|
|
float vtln_low_cutoff,
|
|
float vtln_high_cutoff, // discontinuities in warp func
|
|
float low_freq,
|
|
float high_freq, // upper+lower frequency cutoffs in
|
|
// the mel computation
|
|
float vtln_warp_factor, float freq);
|
|
|
|
static float VtlnWarpMelFreq(float vtln_low_cutoff, float vtln_high_cutoff,
|
|
float low_freq, float high_freq,
|
|
float vtln_warp_factor, float mel_freq);
|
|
|
|
MelBanks(const MelBanksOptions &opts,
|
|
const FrameExtractionOptions &frame_opts, float vtln_warp_factor,
|
|
torch::Device device);
|
|
|
|
// Initialize with a 2-d weights matrix
|
|
//
|
|
// Note: This constructor is for Whisper. It does not initialize
|
|
// center_freqs_.
|
|
//
|
|
// @param weights Pointer to the start address of the matrix
|
|
// @param num_rows It equals to number of mel bins
|
|
// @param num_cols It equals to (number of fft bins)/2+1
|
|
MelBanks(const float *weights, int32_t num_rows, int32_t num_cols,
|
|
torch::Device device);
|
|
|
|
// CAUTION: we save a transposed version of bins_mat_, so return size(1) here
|
|
int32_t NumBins() const { return static_cast<int32_t>(bins_mat_.size(1)); }
|
|
|
|
// returns vector of central freq of each bin; needed by plp code.
|
|
const torch::Tensor &GetCenterFreqs() const { return center_freqs_; }
|
|
|
|
torch::Tensor Compute(const torch::Tensor &spectrum) const;
|
|
|
|
// for debug only
|
|
const torch::Tensor &GetBinsMat() const { return bins_mat_; }
|
|
|
|
private:
|
|
// A 2-D matrix. Its shape is NOT [num_bins, num_fft_bins]
|
|
// Its shape is [num_fft_bins, num_bins] for non-whisper.
|
|
// For whisper, its shape is [num_fft_bins/2+1, num_bins]
|
|
torch::Tensor bins_mat_;
|
|
|
|
// center frequencies of bins, numbered from 0 ... num_bins-1.
|
|
// Needed by GetCenterFreqs().
|
|
torch::Tensor center_freqs_; // It's always on CPU
|
|
|
|
bool debug_;
|
|
bool htk_mode_;
|
|
};
|
|
|
|
// Compute liftering coefficients (scaling on cepstral coeffs)
|
|
// coeffs are numbered slightly differently from HTK: the zeroth
|
|
// index is C0, which is not affected.
|
|
//
|
|
// coeffs is a 1-D float tensor
|
|
void ComputeLifterCoeffs(float Q, torch::Tensor *coeffs);
|
|
|
|
// Writes an equal-loudness vector for `mel_banks` into `*ans`.
// Only declared here; the definition lives outside this header.
// NOTE(review): presumably one weight per mel bin, derived from
// mel_banks.GetCenterFreqs() -- confirm against the definition.
void GetEqualLoudnessVector(const MelBanks &mel_banks, torch::Tensor *ans);
|
|
|
|
/* Compute LP coefficients from autocorrelation coefficients.
|
|
*
|
|
* @param [in] autocorr_in A 2-D tensor. Each row is a frame. Its number of
|
|
* columns is lpc_order + 1
|
|
* @param [out] lpc_coeffs A 2-D tensor. On return, it has as many rows as the
|
|
* input tensor. Its number of columns is lpc_order.
|
|
*
|
|
* @return Returns log energy of residual in a 1-D tensor. It has as many
|
|
* elements as the number of rows in `autocorr_in`.
|
|
*/
|
|
torch::Tensor ComputeLpc(const torch::Tensor &autocorr_in,
|
|
torch::Tensor *lpc_coeffs);
|
|
|
|
/*
|
|
* @param [in] lpc It is the output argument `lpc_coeffs` in ComputeLpc().
|
|
*/
|
|
torch::Tensor Lpc2Cepstrum(const torch::Tensor &lpc);
|
|
|
|
} // namespace kaldifeat
|
|
|
|
#endif // KALDIFEAT_CSRC_MEL_COMPUTATIONS_H_
|