Merge pull request #3 from csukuangfj/mfcc

Add MFCC features.
This commit is contained in:
Fangjun Kuang 2021-07-17 18:13:35 +08:00 committed by GitHub
commit 10c9d75919
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
32 changed files with 1147 additions and 453 deletions

2
.gitignore vendored
View File

@ -3,3 +3,5 @@ build*/
*.egg-info*/
dist/
__pycache__/
test-1hour.wav
path.sh

View File

@ -2,19 +2,15 @@
set(kaldifeat_srcs
feature-fbank.cc
feature-mfcc.cc
feature-window.cc
matrix-functions.cc
mel-computations.cc
)
add_library(kaldifeat_core SHARED ${kaldifeat_srcs})
target_link_libraries(kaldifeat_core PUBLIC ${TORCH_LIBRARIES})
# PYTHON_INCLUDE_DIRS is set by pybind11
target_include_directories(kaldifeat_core PUBLIC ${PYTHON_INCLUDE_DIRS})
# PYTHON_LIBRARY is set by pybind11
target_link_libraries(kaldifeat_core PUBLIC ${PYTHON_LIBRARY})
add_executable(test_kaldifeat test_kaldifeat.cc)
target_link_libraries(test_kaldifeat PRIVATE kaldifeat_core)

View File

@ -13,11 +13,8 @@
#include "kaldifeat/csrc/feature-common.h"
#include "kaldifeat/csrc/feature-window.h"
#include "kaldifeat/csrc/mel-computations.h"
#include "pybind11/pybind11.h"
#include "torch/torch.h"
namespace py = pybind11;
namespace kaldifeat {
struct FbankOptions {
@ -42,19 +39,9 @@ struct FbankOptions {
// analysis, else magnitude.
bool use_power = true;
torch::Device device;
torch::Device device{"cpu"};
FbankOptions() : device("cpu") { mel_opts.num_bins = 23; }
// Get/Set methods are for implementing properties in Python
py::object GetDevice() const {
py::object ans = py::module_::import("torch").attr("device");
return ans(device.str());
}
void SetDevice(py::object obj) {
std::string s = static_cast<py::str>(obj);
device = torch::Device(s);
}
FbankOptions() { mel_opts.num_bins = 23; }
std::string ToString() const {
std::ostringstream os;

View File

@ -0,0 +1,150 @@
// kaldifeat/csrc/feature-mfcc.cc
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
// This file is copied/modified from kaldi/src/feat/feature-mfcc.cc
#include "kaldifeat/csrc/feature-mfcc.h"
#include "kaldifeat/csrc/matrix-functions.h"
namespace kaldifeat {
// Writes the textual form of `opts` (as produced by MfccOptions::ToString())
// to `os` and returns `os` so that insertions can be chained.
std::ostream &operator<<(std::ostream &os, const MfccOptions &opts) {
  return os << opts.ToString();
}
// Constructs an MFCC computer from `opts`.
//
// It prepares, in this order:
//  - dct_matrix_: the first opts.num_ceps rows of the square DCT matrix,
//    stored transposed and moved to opts.device;
//  - lifter_coeffs_: only when opts.cepstral_lifter is non-zero;
//  - log_energy_floor_: only when opts.energy_floor is positive;
//  - the mel banks for VTLN warp factor 1.0 (cached by GetMelBanks()).
//
// An error is reported through KALDIFEAT_ERR when opts.num_ceps exceeds
// the number of mel bins.
MfccComputer::MfccComputer(const MfccOptions &opts) : opts_(opts) {
  const int32_t num_mel_bins = opts.mel_opts.num_bins;

  if (opts.num_ceps > num_mel_bins) {
    KALDIFEAT_ERR << "num-ceps cannot be larger than num-mel-bins."
                  << " It should be smaller or equal. You provided num-ceps: "
                  << opts.num_ceps << " and num-mel-bins: " << num_mel_bins;
  }

  // The full square DCT matrix is filled in by ComputeDctMatrix(), which
  // writes through a raw float pointer.
  // NOTE(review): this presumably requires the tensor to live in host
  // memory; it is only moved to opts.device afterwards.
  torch::Tensor full_dct =
      torch::empty({num_mel_bins, num_mel_bins}, torch::kFloat);
  ComputeDctMatrix(&full_dct);

  // Note that we include zeroth dct in either case. If using the
  // energy we replace this with the energy. This means a different
  // ordering of features than HTK.

  // full_dct[:opts.num_ceps, :]
  torch::Tensor kept_rows = full_dct.index(
      {torch::indexing::Slice(0, opts.num_ceps, torch::indexing::None),
       "..."});

  // Stored transposed so that Compute() can right-multiply by it.
  dct_matrix_ = kept_rows.clone().t().to(opts.device);

  if (opts.cepstral_lifter != 0.0) {
    torch::Tensor lifter = torch::empty({1, opts.num_ceps}, torch::kFloat32);
    ComputeLifterCoeffs(opts.cepstral_lifter, &lifter);
    lifter_coeffs_ = lifter.to(opts.device);
  }

  if (opts.energy_floor > 0.0) {
    log_energy_floor_ = logf(opts.energy_floor);
  }

  // We'll definitely need the filterbanks info for VTLN warping factor 1.0.
  // [note: this call caches it.]
  GetMelBanks(1.0);
}
// Returns the mel banks for the given VTLN warp factor.
//
// The banks are created on first request for a given `vtln_warp` and kept in
// mel_banks_, so later calls with the same factor return the cached object.
// The returned pointer is owned by this MfccComputer (it is deleted in the
// destructor); callers must not delete it.
const MelBanks *MfccComputer::GetMelBanks(float vtln_warp) {
  auto pos = mel_banks_.find(vtln_warp);
  if (pos != mel_banks_.end()) {
    return pos->second;
  }

  MelBanks *created =
      new MelBanks(opts_.mel_opts, opts_.frame_opts, vtln_warp, opts_.device);
  mel_banks_[vtln_warp] = created;
  return created;
}
// Releases every MelBanks object that GetMelBanks() allocated and cached.
MfccComputer::~MfccComputer() {
  for (auto &entry : mel_banks_) {
    delete entry.second;
  }
}
// Computes MFCC features for a batch of frames.
//
// Args:
//   signal_raw_log_energy: log energy computed before windowing. It is only
//       used when opts_.use_energy is true, and it is replaced by the energy
//       of `signal_frame` when opts_.raw_energy is false.
//   vtln_warp: VTLN warp factor; selects the mel banks via GetMelBanks(),
//       which creates and caches them on first use.
//   signal_frame: 2-D tensor with one frame per row; the number of columns
//       must equal opts_.frame_opts.PaddedWindowSize().
//
// Returns:
//   A 2-D tensor of features.
//   ans.shape [signal_frame.size(0), this->Dim()]
torch::Tensor MfccComputer::Compute(torch::Tensor signal_raw_log_energy,
float vtln_warp,
const torch::Tensor &signal_frame) {
const MelBanks &mel_banks = *(GetMelBanks(vtln_warp));
KALDIFEAT_ASSERT(signal_frame.dim() == 2);
KALDIFEAT_ASSERT(signal_frame.size(1) == opts_.frame_opts.PaddedWindowSize());
// torch.finfo(torch.float32).eps
// Used below as the lower clamp before every log() call.
constexpr float kEps = 1.1920928955078125e-07f;
// Compute energy after window function (not the raw one).
if (opts_.use_energy && !opts_.raw_energy) {
signal_raw_log_energy =
torch::clamp_min(signal_frame.pow(2).sum(1), kEps).log();
}
// note spectrum is in magnitude, not power, because of `abs()`
torch::Tensor spectrum = torch::fft::rfft(signal_frame).abs();
// remove the last column, i.e., the highest fft bin
spectrum = spectrum.index(
{"...", torch::indexing::Slice(0, -1, torch::indexing::None)});
// Use power instead of magnitude
spectrum.pow_(2);
torch::Tensor mel_energies = mel_banks.Compute(spectrum);
// Avoid log of zero (which should be prevented anyway by dithering).
mel_energies = torch::clamp_min(mel_energies, kEps).log();
// dct_matrix_ is stored transposed (see the constructor), so this
// right-multiplication applies the DCT to every row of mel_energies.
torch::Tensor features = torch::mm(mel_energies, dct_matrix_);
// lifter_coeffs_ is only initialized by the constructor when
// cepstral_lifter is non-zero, hence the same guard here.
if (opts_.cepstral_lifter != 0.0) {
features = torch::mul(features, lifter_coeffs_);
}
if (opts_.use_energy) {
// log_energy_floor_ is only set by the constructor when
// energy_floor > 0, which is the same condition checked here.
if (opts_.energy_floor > 0.0f) {
signal_raw_log_energy =
torch::clamp_min(signal_raw_log_energy, log_energy_floor_);
}
// column 0 is replaced by signal_raw_log_energy
//
// features[:, 0] = signal_raw_log_energy
//
// NOTE(review): this relies on assignment to the temporary returned by
// index() copying into `features` rather than rebinding the temporary --
// confirm against the libtorch version in use.
features.index({"...", 0}) = signal_raw_log_energy;
}
if (opts_.htk_compat) {
// energy = features[:, 0]
// features[:, :-1] = features[:, 1:]
// features[:, -1] = energy *sqrt(2)
//
// shift left, so the original 0th column
// becomes the last column;
// the original first column becomes the 0th column
features = torch::roll(features, -1, 1);
if (!opts_.use_energy) {
// TODO(fangjun): change the DCT matrix so that we don't need
// to do an extra multiplication here.
//
// scale on C0 (actually removing a scale
// we previously added that's part of one common definition of
// the cosine transform.)
features.index({"...", -1}) *= M_SQRT2;
}
}
return features;
}
} // namespace kaldifeat

View File

@ -0,0 +1,116 @@
// kaldifeat/csrc/feature-mfcc.h
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
// This file is copied/modified from kaldi/src/feat/feature-mfcc.h
#ifndef KALDIFEAT_CSRC_FEATURE_MFCC_H_
#define KALDIFEAT_CSRC_FEATURE_MFCC_H_
#include "kaldifeat/csrc/feature-common.h"
#include "kaldifeat/csrc/feature-window.h"
#include "kaldifeat/csrc/mel-computations.h"
#include "torch/torch.h"
namespace kaldifeat {
/// MfccOptions contains basic options for computing MFCC features.
// (this class is copied from kaldi)
struct MfccOptions {
  // Options controlling how the signal is split into frames.
  FrameExtractionOptions frame_opts;

  // Options for the mel filter banks; the constructor sets num_bins to 23.
  MelBanksOptions mel_opts;

  // Number of cepstra in MFCC computation (including C0)
  int32_t num_ceps = 13;

  // Use energy (not C0) in MFCC computation
  bool use_energy = true;

  // Floor on energy (absolute, not relative) in MFCC
  // computation. Only makes a difference if use_energy=true;
  // only necessary if dither=0.0.
  // Suggested values: 0.1 or 1.0
  float energy_floor = 0.0;

  // If true, compute energy before preemphasis and windowing
  bool raw_energy = true;

  // Constant that controls scaling of MFCCs
  float cepstral_lifter = 22.0;

  // If true, put energy or C0 last and use a factor of
  // sqrt(2) on C0.
  // Warning: not sufficient to get HTK compatible features
  // (need to change other parameters)
  bool htk_compat = false;

  // Device on which the feature computation tensors are placed.
  torch::Device device{"cpu"};

  MfccOptions() { mel_opts.num_bins = 23; }

  // Returns a multi-line, human readable dump of all options.
  std::string ToString() const {
    std::ostringstream os;
    os << "frame_opts: \n" << frame_opts << "\n\n";
    os << "mel_opts: \n" << mel_opts << "\n";
    os << "num_ceps: " << num_ceps << "\n"
       << "use_energy: " << use_energy << "\n"
       << "energy_floor: " << energy_floor << "\n"
       << "raw_energy: " << raw_energy << "\n"
       << "cepstral_lifter: " << cepstral_lifter << "\n"
       << "htk_compat: " << htk_compat << "\n"
       << "device: " << device << "\n";
    return os.str();
  }
};
std::ostream &operator<<(std::ostream &os, const MfccOptions &opts);
// Computes MFCC features from frames of audio samples.
//
// The computer is configured once from an MfccOptions object and is not
// copyable. It is used through OfflineFeatureTpl (see the `Mfcc` alias).
class MfccComputer {
public:
using Options = MfccOptions;
explicit MfccComputer(const MfccOptions &opts);
// Deletes the mel banks cached in mel_banks_.
~MfccComputer();
// Copying is disabled: mel_banks_ holds raw pointers that the destructor
// deletes.
MfccComputer &operator=(const MfccComputer &) = delete;
MfccComputer(const MfccComputer &) = delete;
// Feature dimension, i.e., the number of cepstra.
int32_t Dim() const { return opts_.num_ceps; }
// True iff the caller must provide the log energy computed before
// windowing when calling Compute().
bool NeedRawLogEnergy() const { return opts_.use_energy && opts_.raw_energy; }
const FrameExtractionOptions &GetFrameOptions() const {
return opts_.frame_opts;
}
const MfccOptions &GetOptions() const { return opts_; }
// signal_raw_log_energy is log_energy_pre_window, which is not empty
// iff NeedRawLogEnergy() returns true.
torch::Tensor Compute(torch::Tensor signal_raw_log_energy, float vtln_warp,
const torch::Tensor &signal_frame);
private:
// Returns the mel banks for `vtln_warp`, creating and caching them in
// mel_banks_ on first use.
const MelBanks *GetMelBanks(float vtln_warp);
MfccOptions opts_;
// Liftering coefficients; allocated by the constructor with shape
// [1, num_ceps], and only when cepstral_lifter != 0.
torch::Tensor lifter_coeffs_;
// Note we save a transposed version of dct_matrix_
// dct_matrix_.rows is num_mel_bins
// dct_matrix_.cols is num_ceps
torch::Tensor dct_matrix_; // matrix we right-multiply by to perform DCT.
// log(energy_floor); only assigned by the constructor when
// energy_floor > 0, and only read under the same condition.
float log_energy_floor_;
std::map<float, MelBanks *> mel_banks_; // float is VTLN coefficient.
};
using Mfcc = OfflineFeatureTpl<MfccComputer>;
} // namespace kaldifeat
#endif // KALDIFEAT_CSRC_FEATURE_MFCC_H_

View File

@ -0,0 +1,45 @@
// kaldifeat/csrc/matrix-functions.cc
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
// This file is copied/modified from kaldi/src/matrix/matrix-functions.cc
#include "kaldifeat/csrc/matrix-functions.h"
#include <cmath>
#include "kaldifeat/csrc/log.h"
namespace kaldifeat {
// Fills `mat` with the type-II DCT matrix (with normalization), so that
// `mat * v` is the DCT of the vector `v`:
//
//   mat[0, c] = sqrt(1/N)
//   mat[r, c] = sqrt(2/N) * cos(pi/N * (c + 0.5) * r),  for r >= 1
//
// where N is the number of columns.
//
// Args:
//   mat: a 2-D, square, non-empty float tensor; it is overwritten in place.
//        Rows >= 1 are written through a raw pointer, honoring both the row
//        and the column stride, so the tensor need not be contiguous.
//        NOTE(review): raw pointer access presumably requires the tensor to
//        live in host memory -- callers in this file create it with
//        torch::empty() before moving it to the target device.
void ComputeDctMatrix(torch::Tensor *mat) {
  KALDIFEAT_ASSERT(mat->dim() == 2);

  int32_t num_rows = mat->size(0);
  int32_t num_cols = mat->size(1);
  KALDIFEAT_ASSERT(num_rows == num_cols);
  KALDIFEAT_ASSERT(num_rows > 0);

  // Strides are 64-bit in torch; keep them that way and use both of them so
  // that a non-contiguous (e.g. transposed) tensor is filled correctly.
  const int64_t row_stride = mat->stride(0);
  const int64_t col_stride = mat->stride(1);

  // normalizer for X_0
  float normalizer = std::sqrt(1.0f / num_cols);

  // mat[0, :] = normalizer
  mat->index({0, "..."}) = normalizer;

  // normalizer for other elements
  normalizer = std::sqrt(2.0f / num_cols);

  float *data = mat->data_ptr<float>();
  for (int32_t r = 1; r < num_rows; ++r) {
    float *this_row = data + r * row_stride;
    for (int32_t c = 0; c < num_cols; ++c) {
      float v = std::cos(static_cast<double>(M_PI) / num_cols * (c + 0.5) * r);
      this_row[c * col_stride] = normalizer * v;
    }
  }
}
} // namespace kaldifeat

View File

@ -0,0 +1,26 @@
// kaldifeat/csrc/matrix-functions.h
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
// This file is copied/modified from kaldi/src/matrix/matrix-functions.h
#ifndef KALDIFEAT_CSRC_MATRIX_FUNCTIONS_H_
#define KALDIFEAT_CSRC_MATRIX_FUNCTIONS_H_
#include "torch/torch.h"
namespace kaldifeat {
/// ComputeDctMatrix computes a matrix corresponding to the DCT, such that
/// M * v equals the DCT of vector v. M must be square at input.
/// This is the type = II DCT with normalization, corresponding to the
/// following equations, where x is the signal and X is the DCT:
/// X_0 = sqrt(1/N) \sum_{n = 0}^{N-1} x_n
/// X_k = sqrt(2/N) \sum_{n = 0}^{N-1} x_n cos( \pi/N (n + 1/2) k )
/// See also
/// https://docs.scipy.org/doc/scipy/reference/generated/scipy.fftpack.dct.html
///
/// @param M  A 2-D, square float tensor with at least one row. It is
///           overwritten in place; its shape determines N.
void ComputeDctMatrix(torch::Tensor *M);
} // namespace kaldifeat
#endif // KALDIFEAT_CSRC_MATRIX_FUNCTIONS_H_

View File

@ -192,4 +192,15 @@ torch::Tensor MelBanks::Compute(const torch::Tensor &spectrum) const {
return torch::mm(spectrum, bins_mat_);
}
// Compute liftering coefficients (scaling on cepstral coeffs)
// coeffs are numbered slightly differently from HTK: the zeroth
// index is C0, which is not affected.
//
// Args:
//   Q:      the cepstral lifter constant.
//   coeffs: float tensor to fill; element i (in storage order, for all
//           numel() elements) is set to 1 + 0.5 * Q * sin(pi * i / Q).
void ComputeLifterCoeffs(float Q, torch::Tensor *coeffs) {
  const int32_t count = coeffs->numel();
  float *out = coeffs->data_ptr<float>();

  for (int32_t k = 0; k < count; ++k) {
    out[k] = 1.0 + 0.5 * Q * sin(M_PI * k / Q);
  }
}
} // namespace kaldifeat

View File

@ -89,6 +89,13 @@ class MelBanks {
bool htk_mode_;
};
// Compute liftering coefficients (scaling on cepstral coeffs)
// coeffs are numbered slightly differently from HTK: the zeroth
// index is C0, which is not affected.
//
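// coeffs is a float tensor; all numel() elements are overwritten in
// storage order, so its shape may be [num_ceps] or [1, num_ceps].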
void ComputeLifterCoeffs(float Q, torch::Tensor *coeffs);
} // namespace kaldifeat
#endif // KALDIFEAT_CSRC_MEL_COMPUTATIONS_H_

View File

@ -1,6 +1,7 @@
add_definitions(-DTORCH_API_INCLUDE_EXTENSION_H)
pybind11_add_module(_kaldifeat
feature-fbank.cc
feature-mfcc.cc
feature-window.cc
kaldifeat.cc
mel-computations.cc

View File

@ -4,34 +4,51 @@
#include "kaldifeat/python/csrc/feature-fbank.h"
#include <string>
#include "kaldifeat/csrc/feature-fbank.h"
namespace kaldifeat {
void PybindFbankOptions(py::module &m) {
py::class_<FbankOptions>(m, "FbankOptions")
static void PybindFbankOptions(py::module &m) {
using PyClass = FbankOptions;
py::class_<PyClass>(m, "FbankOptions")
.def(py::init<>())
.def_readwrite("frame_opts", &FbankOptions::frame_opts)
.def_readwrite("mel_opts", &FbankOptions::mel_opts)
.def_readwrite("use_energy", &FbankOptions::use_energy)
.def_readwrite("energy_floor", &FbankOptions::energy_floor)
.def_readwrite("raw_energy", &FbankOptions::raw_energy)
.def_readwrite("htk_compat", &FbankOptions::htk_compat)
.def_readwrite("use_log_fbank", &FbankOptions::use_log_fbank)
.def_readwrite("use_power", &FbankOptions::use_power)
.def_property("device", &FbankOptions::GetDevice,
&FbankOptions::SetDevice)
.def("__str__", [](const FbankOptions &self) -> std::string {
return self.ToString();
});
.def_readwrite("frame_opts", &PyClass::frame_opts)
.def_readwrite("mel_opts", &PyClass::mel_opts)
.def_readwrite("use_energy", &PyClass::use_energy)
.def_readwrite("energy_floor", &PyClass::energy_floor)
.def_readwrite("raw_energy", &PyClass::raw_energy)
.def_readwrite("htk_compat", &PyClass::htk_compat)
.def_readwrite("use_log_fbank", &PyClass::use_log_fbank)
.def_readwrite("use_power", &PyClass::use_power)
.def_property(
"device",
[](const PyClass &self) -> py::object {
py::object ans = py::module_::import("torch").attr("device");
return ans(self.device.str());
},
[](PyClass &self, py::object obj) -> void {
std::string s = static_cast<py::str>(obj);
self.device = torch::Device(s);
})
.def("__str__",
[](const PyClass &self) -> std::string { return self.ToString(); });
}
py::class_<Fbank>(m, "Fbank")
static void PybindFbank(py::module &m) {
using PyClass = Fbank;
py::class_<PyClass>(m, "Fbank")
.def(py::init<const FbankOptions &>(), py::arg("opts"))
.def("dim", &Fbank::Dim)
.def("options", &Fbank::GetOptions,
py::return_value_policy::reference_internal)
.def("compute_features", &Fbank::ComputeFeatures, py::arg("wave"),
.def("dim", &PyClass::Dim)
.def_property_readonly("options", &PyClass::GetOptions)
.def("compute_features", &PyClass::ComputeFeatures, py::arg("wave"),
py::arg("vtln_warp"));
}
void PybindFeatureFbank(py::module &m) {
PybindFbankOptions(m);
PybindFbank(m);
}
} // namespace kaldifeat

View File

@ -9,7 +9,7 @@
namespace kaldifeat {
void PybindFbankOptions(py::module &m);
void PybindFeatureFbank(py::module &m);
} // namespace kaldifeat

View File

@ -0,0 +1,52 @@
// kaldifeat/python/csrc/feature-mfcc.cc
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#include "kaldifeat/python/csrc/feature-mfcc.h"
#include "kaldifeat/csrc/feature-mfcc.h"
namespace kaldifeat {
// Registers the Python class `MfccOptions` in module `m`.
//
// All plain fields are exposed read/write. `device` is exposed as a property
// that converts between torch::Device and Python's `torch.device`:
//  - the getter returns `torch.device(<device string>)`;
//  - the setter accepts any object whose string form is a valid device
//    string (e.g. "cpu", "cuda:0", or a `torch.device`).
//
// This helper is file-local (static), like the other per-class binding
// helpers; only PybindFeatureMfcc() is declared in the header.
static void PybindMfccOptions(py::module &m) {
  using PyClass = MfccOptions;
  py::class_<PyClass>(m, "MfccOptions")
      .def(py::init<>())
      .def_readwrite("frame_opts", &PyClass::frame_opts)
      .def_readwrite("mel_opts", &PyClass::mel_opts)
      .def_readwrite("num_ceps", &PyClass::num_ceps)
      .def_readwrite("use_energy", &PyClass::use_energy)
      .def_readwrite("energy_floor", &PyClass::energy_floor)
      .def_readwrite("raw_energy", &PyClass::raw_energy)
      .def_readwrite("cepstral_lifter", &PyClass::cepstral_lifter)
      .def_readwrite("htk_compat", &PyClass::htk_compat)
      .def_property(
          "device",
          [](const PyClass &self) -> py::object {
            py::object ans = py::module_::import("torch").attr("device");
            return ans(self.device.str());
          },
          [](PyClass &self, py::object obj) -> void {
            std::string s = static_cast<py::str>(obj);
            self.device = torch::Device(s);
          })
      .def("__str__",
           [](const PyClass &self) -> std::string { return self.ToString(); });
}
// Registers the Python class `Mfcc` in module `m`.
//
// Exposed members:
//  - Mfcc(opts): constructor taking an MfccOptions;
//  - dim(): forwards to Mfcc::Dim;
//  - options: read-only property forwarding to Mfcc::GetOptions;
//  - compute_features(wave, vtln_warp): forwards to Mfcc::ComputeFeatures.
static void PybindMfcc(py::module &m) {
  py::class_<Mfcc> mfcc(m, "Mfcc");

  mfcc.def(py::init<const MfccOptions &>(), py::arg("opts"));
  mfcc.def("dim", &Mfcc::Dim);
  mfcc.def_property_readonly("options", &Mfcc::GetOptions);
  mfcc.def("compute_features", &Mfcc::ComputeFeatures, py::arg("wave"),
           py::arg("vtln_warp"));
}
// Entry point for the MFCC bindings: registers `MfccOptions` first and then
// `Mfcc` (whose constructor takes an MfccOptions) in module `m`.
void PybindFeatureMfcc(py::module &m) {
  PybindMfccOptions(m);

  PybindMfcc(m);
}
} // namespace kaldifeat

View File

@ -0,0 +1,16 @@
// kaldifeat/python/csrc/feature-mfcc.h
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#ifndef KALDIFEAT_PYTHON_CSRC_FEATURE_MFCC_H_
#define KALDIFEAT_PYTHON_CSRC_FEATURE_MFCC_H_
#include "kaldifeat/python/csrc/kaldifeat.h"
namespace kaldifeat {
void PybindFeatureMfcc(py::module &m);
} // namespace kaldifeat
#endif // KALDIFEAT_PYTHON_CSRC_FEATURE_MFCC_H_

View File

@ -8,7 +8,7 @@
namespace kaldifeat {
void PybindFrameExtractionOptions(py::module &m) {
static void PybindFrameExtractionOptions(py::module &m) {
py::class_<FrameExtractionOptions>(m, "FrameExtractionOptions")
.def(py::init<>())
.def_readwrite("samp_freq", &FrameExtractionOptions::samp_freq)
@ -41,4 +41,13 @@ void PybindFrameExtractionOptions(py::module &m) {
m.def("get_strided", &GetStrided, py::arg("wave"), py::arg("opts"));
}
void PybindFeatureWindow(py::module &m) {
PybindFrameExtractionOptions(m);
m.def("num_frames", &NumFrames, py::arg("num_samples"), py::arg("opts"),
py::arg("flush") = true);
m.def("get_strided", &GetStrided, py::arg("wave"), py::arg("opts"));
}
} // namespace kaldifeat

View File

@ -9,7 +9,7 @@
namespace kaldifeat {
void PybindFrameExtractionOptions(py::module &m);
void PybindFeatureWindow(py::module &m);
} // namespace kaldifeat

View File

@ -4,53 +4,22 @@
#include "kaldifeat/python/csrc/kaldifeat.h"
#include <chrono>
#include "kaldifeat/csrc/feature-fbank.h"
#include "kaldifeat/python/csrc/feature-fbank.h"
#include "kaldifeat/python/csrc/feature-mfcc.h"
#include "kaldifeat/python/csrc/feature-window.h"
#include "kaldifeat/python/csrc/mel-computations.h"
#include "torch/torch.h"
namespace kaldifeat {
static torch::Tensor Compute(const torch::Tensor &wave, Fbank *fbank) {
float vtln_warp = 1.0f;
torch::Tensor ans = fbank->ComputeFeatures(wave, vtln_warp);
return ans;
}
PYBIND11_MODULE(_kaldifeat, m) {
m.doc() = "Python wrapper for kaldifeat";
PybindFrameExtractionOptions(m);
PybindMelBanksOptions(m);
PybindFbankOptions(m);
m.def("compute_fbank_feats", &Compute, py::arg("wave"), py::arg("fbank"));
// It verifies that the reimplementation produces the same output
// as kaldi using default parameters with dither disabled.
m.def(
"_compute_with_elapsed_time", // for benchmark only
[](const torch::Tensor &wave,
Fbank *fbank) -> std::pair<torch::Tensor, double> {
std::chrono::steady_clock::time_point begin =
std::chrono::steady_clock::now();
torch::Tensor ans = Compute(wave, fbank);
std::chrono::steady_clock::time_point end =
std::chrono::steady_clock::now();
double elapsed_seconds =
std::chrono::duration_cast<std::chrono::microseconds>(end - begin)
.count() /
1000000.;
return std::make_pair(ans, elapsed_seconds);
},
py::arg("wave"), py::arg("fbank"));
PybindFeatureWindow(m);
PybindMelComputations(m);
PybindFeatureFbank(m);
PybindFeatureMfcc(m);
}
} // namespace kaldifeat

View File

@ -8,20 +8,22 @@
namespace kaldifeat {
void PybindMelBanksOptions(py::module &m) {
py::class_<MelBanksOptions>(m, "MelBanksOptions")
static void PybindMelBanksOptions(py::module &m) {
using PyClass = MelBanksOptions;
py::class_<PyClass>(m, "MelBanksOptions")
.def(py::init<>())
.def_readwrite("num_bins", &MelBanksOptions::num_bins)
.def_readwrite("low_freq", &MelBanksOptions::low_freq)
.def_readwrite("high_freq", &MelBanksOptions::high_freq)
.def_readwrite("vtln_low", &MelBanksOptions::vtln_low)
.def_readwrite("vtln_high", &MelBanksOptions::vtln_high)
.def_readwrite("debug_mel", &MelBanksOptions::debug_mel)
.def_readwrite("htk_mode", &MelBanksOptions::htk_mode)
.def("__str__", [](const MelBanksOptions &self) -> std::string {
return self.ToString();
});
.def_readwrite("num_bins", &PyClass::num_bins)
.def_readwrite("low_freq", &PyClass::low_freq)
.def_readwrite("high_freq", &PyClass::high_freq)
.def_readwrite("vtln_low", &PyClass::vtln_low)
.def_readwrite("vtln_high", &PyClass::vtln_high)
.def_readwrite("debug_mel", &PyClass::debug_mel)
.def_readwrite("htk_mode", &PyClass::htk_mode)
.def("__str__",
[](const PyClass &self) -> std::string { return self.ToString(); });
;
}
void PybindMelComputations(py::module &m) { PybindMelBanksOptions(m); }
} // namespace kaldifeat

View File

@ -9,7 +9,7 @@
namespace kaldifeat {
void PybindMelBanksOptions(py::module &m);
void PybindMelComputations(py::module &m);
} // namespace kaldifeat

View File

@ -1,4 +1,10 @@
import torch
from _kaldifeat import FbankOptions, FrameExtractionOptions, MelBanksOptions
from _kaldifeat import (
FbankOptions,
FrameExtractionOptions,
MelBanksOptions,
MfccOptions,
)
from .fbank import Fbank
from .mfcc import Mfcc

View File

@ -1,82 +1,12 @@
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
from typing import List, Union
import _kaldifeat
import torch
import torch.nn as nn
from .offline_feature import OfflineFeature
class Fbank(nn.Module):
class Fbank(OfflineFeature):
def __init__(self, opts: _kaldifeat.FbankOptions):
super().__init__()
self.opts = opts
super().__init__(opts)
self.computer = _kaldifeat.Fbank(opts)
def forward(
self, waves: Union[torch.Tensor, List[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Compute the fbank features of a single waveform or
a list of waveforms.
Args:
waves:
A single 1-D tensor or a list of 1-D tensors. Each tensor contains
audio samples of a soundfile. To get a result compatible with Kaldi,
you should scale the samples to [-32768, 32767] before calling this
function. Note: You are not required to scale them if you don't care
about the compatibility with Kaldi.
Returns:
Return a list of 2-D tensors containing the fbank features if the
input is a list of 1-D tensors. The returned list has as many elements
as the input list.
Return a single 2-D tensor if the input is a single tensor.
"""
if isinstance(waves, list):
is_list = True
else:
waves = [waves]
is_list = False
num_frames_per_wave = [
_kaldifeat.num_frames(w.numel(), self.opts.frame_opts)
for w in waves
]
strided = [self.convert_samples_to_frames(w) for w in waves]
strided = torch.cat(strided, dim=0)
features = self.compute(strided)
if is_list:
return list(features.split(num_frames_per_wave))
else:
return features
def compute(self, x: torch.Tensor) -> torch.Tensor:
"""Compute fbank features given a 2-D tensor containing
frames data. Each row is a frame of size frame_lens, specified
in the fbank options.
Args:
x:
A 2-D tensor.
Returns:
Return a 2-D tensor with as many rows as the input tensor. Its
number of columns is the number mel bins.
"""
features = _kaldifeat.compute_fbank_feats(x, self.computer)
return features
def convert_samples_to_frames(self, wave: torch.Tensor) -> torch.Tensor:
"""Convert a 1-D tensor containing audio samples to a 2-D
tensor where each row is a frame of samples of size frame length
specified in the fbank options.
Args:
waves:
A 1-D tensor.
Returns:
Return a 2-D tensor.
"""
return _kaldifeat.get_strided(wave, self.opts.frame_opts)

View File

@ -0,0 +1,12 @@
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
import _kaldifeat
from .offline_feature import OfflineFeature
class Mfcc(OfflineFeature):
    """MFCC feature extractor.

    All the computation logic lives in :class:`OfflineFeature`; this class
    only supplies the native computer, ``_kaldifeat.Mfcc``.
    """

    def __init__(self, opts: _kaldifeat.MfccOptions):
        """
        Args:
          opts:
            The MFCC options. They are stored by the base class as
            ``self.opts`` and used to build the native computer.
        """
        super().__init__(opts)
        # Replace the placeholder set by the base class with the real computer.
        self.computer = _kaldifeat.Mfcc(opts)

View File

@ -0,0 +1,141 @@
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
from typing import List, Optional, Union
import _kaldifeat
import torch
import torch.nn as nn
class OfflineFeature(nn.Module):
    """Offline feature is a base class of other feature computers,
    e.g., Fbank, Mfcc.

    This class has two fields:

    (1) opts. It contains the options for the feature computer.

    (2) computer. The actual feature computer. It should be
        instantiated by subclasses. It must provide a method
        ``compute_features(frames, vtln_warp)``.
    """

    def __init__(self, opts):
        """
        Args:
          opts:
            The options of the feature computer. ``opts.frame_opts`` is used
            to split waveforms into frames.
        """
        super().__init__()

        self.opts = opts

        # self.computer is expected to be set by subclasses
        self.computer = None

    def forward(
        self,
        waves: Union[torch.Tensor, List[torch.Tensor]],
        vtln_warp: float = 1.0,
        chunk_size: Optional[int] = None,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Compute the features of a single waveform or
        a list of waveforms.

        Args:
          waves:
            A single 1-D tensor or a list of 1-D tensors. Each tensor contains
            audio samples of a sound file. To get a result compatible with
            Kaldi, you should scale the samples to [-32768, 32767] before
            calling this function. Note: You are not required to scale them if
            you don't care about the compatibility with Kaldi.
          vtln_warp:
            The VTLN warping factor that the user wants to be applied when
            computing features for this utterance. Will normally be 1.0,
            meaning no warping is to be done. The value will be ignored for
            feature types that don't support VTLN, such as spectrogram
            features.
          chunk_size:
            It specifies the number of frames for each computation.
            If None, it computes features at once (requiring more memory for
            long waves). If not None, each computation takes this number of
            frames (requiring less memory).
        Returns:
          Return a list of 2-D tensors containing the features if the
          input is a list of 1-D tensors. The returned list has as many
          elements as the input list.
          Return a single 2-D tensor if the input is a single tensor.
        """
        is_list = isinstance(waves, list)
        if not is_list:
            waves = [waves]

        num_frames_per_wave = [
            _kaldifeat.num_frames(w.numel(), self.opts.frame_opts)
            for w in waves
        ]

        # Frames of all waves are stacked so that the features can be
        # computed in a single pass and split per wave afterwards.
        strided = [self.convert_samples_to_frames(w) for w in waves]
        strided = torch.cat(strided, dim=0)

        # chunk_size has to be forwarded; otherwise it is silently ignored.
        features = self.compute(strided, vtln_warp, chunk_size=chunk_size)

        if is_list:
            return list(features.split(num_frames_per_wave))
        else:
            return features

    def compute(
        self,
        x: torch.Tensor,
        vtln_warp: float = 1.0,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        """Compute features given a 2-D tensor containing
        frames data. Each row is a frame of size frame_lens, specified
        in the options.

        Args:
          x:
            A 2-D tensor.
          vtln_warp:
            The VTLN warping factor that the user wants to be applied when
            computing features for this utterance. Will normally be 1.0,
            meaning no warping is to be done. The value will be ignored for
            feature types that don't support VTLN, such as spectrogram
            features.
          chunk_size:
            It specifies the number of frames for each computation.
            If None, it computes features at once (requiring more memory for
            long waves). If not None, each computation takes this number of
            frames (requiring less memory).
        Returns:
          Return a 2-D tensor with as many rows as the input tensor. Its
          number of columns is the feature dimension of the computer.
        """
        assert x.ndim == 2

        if chunk_size is None:
            return self.computer.compute_features(x, vtln_warp)

        assert chunk_size > 0

        num_frames = x.size(0)
        if num_frames <= chunk_size:
            # Everything fits into one chunk. This also covers an input with
            # no frames, for which there would be nothing to concatenate.
            return self.computer.compute_features(x, vtln_warp)

        # The last slice is automatically shorter when num_frames is not a
        # multiple of chunk_size.
        chunks = [
            self.computer.compute_features(
                x[start : start + chunk_size], vtln_warp
            )
            for start in range(0, num_frames, chunk_size)
        ]
        return torch.cat(chunks, dim=0)

    def convert_samples_to_frames(self, wave: torch.Tensor) -> torch.Tensor:
        """Convert a 1-D tensor containing audio samples to a 2-D
        tensor where each row is a frame of samples of size frame length
        specified in the options.

        Args:
          wave:
            A 1-D tensor.
        Returns:
          Return a 2-D tensor.
        """
        return _kaldifeat.get_strided(wave, self.opts.frame_opts)

View File

View File

@ -25,6 +25,14 @@ if [ ! -f test.txt ]; then
compute-fbank-feats --dither=0 scp:test.scp ark,t:test.txt
fi
if [ ! -f test-mfcc.txt ]; then
compute-mfcc-feats --dither=0 scp:test.scp ark,t:test-mfcc.txt
fi
if [ ! -f test-mfcc-no-snip-edges.txt ]; then
compute-mfcc-feats --dither=0 --snip-edges=0 scp:test.scp ark,t:test-mfcc-no-snip-edges.txt
fi
if [ ! -f test-htk.txt ]; then
compute-fbank-feats --dither=0 --use-energy=1 --htk-compat=1 scp:test.scp ark,t:test-htk.txt
fi

View File

@ -0,0 +1,121 @@
1 [
25.39608 32.35185 -2.321556 14.60079 -13.74259 -4.045824 -24.44216 -15.79832 -25.61224 -15.68476 -16.97382 -7.824823 -5.879378
25.39381 45.95031 39.74594 24.35997 -0.8152001 -26.2968 -49.12233 -58.09226 -59.16555 -43.86792 -35.18478 -16.45622 -6.582274
25.38237 46.29809 39.69444 21.26501 -0.1167132 -30.20249 -48.83283 -62.01665 -55.13354 -47.37197 -29.98325 -16.55868 -3.574477
25.38654 45.92929 38.62181 19.25543 -2.401595 -31.4079 -50.66693 -60.24544 -56.83351 -46.73343 -31.91438 -13.66359 -4.147025
25.40916 46.29382 37.4265 20.59778 -5.06803 -28.80847 -51.90322 -56.63943 -57.86185 -43.94829 -28.04593 -11.67387 -1.56284
25.38297 46.47021 36.70466 18.70617 -6.900266 -34.04674 -51.83358 -60.35369 -54.66861 -44.32346 -23.42499 -8.679977 2.353952
25.39825 47.23348 35.89182 16.63462 -9.518853 -39.8504 -54.48053 -59.99181 -56.41674 -36.75331 -21.40067 -3.050015 6.079391
25.3968 47.3536 35.66127 14.08512 -14.63117 -42.1903 -59.47116 -62.91116 -53.13708 -35.10442 -16.158 -0.9198647 9.721549
25.3984 46.35926 34.55193 10.56108 -18.3789 -48.74673 -64.26622 -66.57558 -53.05719 -35.04904 -13.73361 2.403883 9.671058
25.38439 45.9486 34.18124 9.057709 -21.17389 -51.51536 -67.24743 -68.28139 -52.69724 -29.56226 -11.40866 4.705441 15.23537
25.40594 46.87861 34.65144 11.24935 -19.8245 -53.85929 -65.52903 -67.52113 -47.5673 -24.9378 -5.569555 11.10318 17.56986
25.39048 47.57718 34.22607 10.6608 -22.98505 -54.48465 -69.44428 -66.70899 -46.99065 -21.9698 -1.225232 9.322051 17.84232
25.3787 48.74944 33.28561 11.38392 -26.49508 -54.00847 -69.92937 -63.2187 -44.14838 -16.61138 2.719506 13.61085 15.68057
25.38585 47.17855 32.05386 5.280269 -29.57172 -59.32993 -71.95599 -65.94704 -45.48616 -17.20598 2.877878 13.4305 14.53257
25.39545 47.243 28.74062 2.290034 -32.83585 -62.3103 -73.15955 -67.27847 -44.69419 -18.05867 2.833896 18.73741 14.05055
25.40651 46.75592 29.52716 2.523693 -32.55991 -59.46822 -70.77213 -62.0563 -41.70202 -12.51747 9.536063 23.47834 24.67166
25.39396 46.68447 30.45101 0.5390255 -33.79927 -57.84798 -71.88931 -56.88164 -36.46875 -8.955982 12.32823 27.9718 25.87831
25.40022 46.45648 27.73865 -2.336577 -38.14998 -65.26423 -71.98856 -56.81524 -33.61908 -7.026244 14.16988 24.44519 23.83253
25.3949 45.99439 26.28573 -3.37527 -42.39796 -68.02455 -71.32982 -52.95689 -27.52928 -0.7912635 17.24487 26.69702 25.87051
25.40088 44.00927 25.87407 -6.363201 -42.76957 -73.92364 -67.43925 -52.57967 -20.44759 7.178041 21.33786 30.62585 27.13247
25.40474 44.24985 24.99637 -7.239056 -41.48176 -70.10802 -65.85403 -47.34629 -12.99007 13.20371 31.48185 33.96679 34.62817
25.39125 45.24401 24.2002 -8.070871 -43.40356 -70.26653 -66.76379 -44.15194 -10.50779 19.83112 28.44634 33.59096 30.22577
25.4066 43.51714 20.934 -15.18678 -53.31227 -75.91238 -72.67536 -46.27955 -11.89614 16.4112 25.45789 27.98687 19.24098
25.3951 45.63427 18.89521 -16.51617 -55.7245 -76.75443 -67.9464 -41.95996 -4.399398 22.20214 33.45767 30.37081 18.19485
25.39434 45.23126 18.64206 -19.19501 -56.65408 -76.82211 -65.10683 -34.32439 1.652178 30.4379 43.23981 33.30499 22.91569
25.39326 45.0127 19.00955 -21.57034 -58.73341 -78.3952 -63.41209 -29.54198 3.416465 36.02323 40.08428 36.15363 18.0722
25.39042 46.19297 17.52902 -20.8206 -65.88988 -77.08572 -62.16529 -25.13289 7.733559 34.61935 40.95137 30.76114 14.85412
25.39459 44.5594 17.77966 -25.04762 -68.03204 -82.85863 -60.69528 -22.475 10.59829 34.55471 36.88003 27.12749 12.73354
25.38486 44.00508 14.8887 -24.82657 -68.32502 -82.30861 -60.15085 -17.06987 18.70377 39.06653 38.03824 30.82622 12.41097
25.39576 43.71902 13.96611 -25.66086 -67.90189 -78.50242 -57.25552 -9.926379 25.96967 42.47167 41.01326 29.76805 12.07698
25.39695 43.41653 12.08988 -28.40393 -69.91819 -80.28224 -52.10577 -5.703701 29.85741 42.48271 33.83108 23.77441 4.621999
25.38689 41.68964 7.535811 -34.86737 -72.20117 -78.49009 -50.62822 -5.51798 29.58163 45.11713 35.3207 14.67932 -4.980578
25.39787 41.81786 8.359 -34.30838 -68.08477 -69.56466 -39.19672 4.961438 40.72778 57.27301 44.86645 19.0232 -4.365663
25.40326 41.66633 4.406162 -38.66424 -78.00182 -70.34608 -39.17017 2.660123 37.61974 50.6637 37.0939 8.616677 -12.67094
25.39464 41.7363 4.227285 -42.76331 -81.96279 -72.10927 -34.43838 10.05296 38.76222 48.2467 30.30499 3.992683 -19.28808
25.3835 41.11358 3.70572 -44.86536 -81.20176 -73.55413 -28.183 23.29612 46.17789 46.38068 30.24077 2.892463 -20.81282
25.39136 41.57816 3.81811 -45.5806 -79.3111 -70.29636 -17.1561 34.84626 55.73431 49.21229 30.18771 5.324198 -20.78599
25.39826 40.60241 -0.9265096 -50.03929 -84.95367 -73.53457 -20.45146 32.95836 52.69015 38.5402 23.32296 -9.152432 -25.56478
25.39433 40.82907 -0.5061941 -48.07018 -84.67371 -67.86473 -15.12541 42.61532 56.66839 45.71279 19.99749 -5.834647 -27.19608
25.39488 40.08175 -2.762471 -51.78362 -81.68471 -61.87555 -8.259188 42.69786 63.23629 50.00811 13.29985 -11.91144 -29.50841
25.39559 38.60545 -6.221917 -57.09323 -86.38364 -58.30085 -7.215154 37.04668 60.9631 43.06159 0.123965 -24.72754 -30.31331
25.38846 38.00803 -5.501757 -56.85784 -82.68424 -50.10187 -0.7668005 42.69961 62.05183 41.20802 -1.993352 -30.22561 -34.91791
25.39359 38.44771 -8.265057 -54.67931 -80.41744 -43.87169 14.42547 54.14865 60.40449 38.37543 4.480012 -33.40733 -34.97999
25.39786 35.19547 -12.29784 -62.75888 -85.39848 -50.94773 14.00066 49.63785 50.21563 23.85303 -6.808732 -42.55332 -40.28651
25.40098 36.47634 -10.52399 -60.02825 -78.02267 -41.74789 24.67215 61.73796 55.9189 29.94274 -5.743024 -36.19523 -31.93987
25.39558 35.54524 -13.13855 -63.11581 -80.99752 -35.54281 30.09668 66.401 50.23518 20.67985 -14.00711 -38.86665 -32.63693
25.39386 34.49712 -17.12528 -71.18398 -85.69121 -35.43599 25.84463 64.70671 46.57549 5.361976 -28.57582 -42.74088 -27.05177
25.39646 36.58266 -16.92528 -71.32783 -78.29313 -21.86354 38.17763 74.43158 60.02375 7.823904 -30.07853 -28.32218 -8.374956
25.39031 35.7993 -19.34796 -73.25931 -78.22402 -19.34886 43.6849 72.14301 51.0216 2.114614 -43.17448 -33.47993 -6.816935
25.39644 34.5329 -20.69224 -72.8614 -79.76266 -15.1011 48.02476 68.49402 44.33848 -0.2977941 -42.14186 -41.20908 2.603081
25.39023 33.62829 -22.30286 -72.10425 -73.52156 -7.060349 59.05154 72.16718 41.47753 -2.558425 -40.5256 -34.35424 13.5494
25.39391 30.22026 -27.18811 -77.48886 -76.1098 -10.05072 53.91814 60.34884 25.63873 -20.7681 -51.0108 -40.48099 12.48811
25.39349 31.81778 -25.61624 -72.39314 -67.259 5.034772 62.47023 67.34049 23.269 -19.30169 -45.74679 -27.11697 25.2069
25.38898 32.00583 -27.77481 -75.70822 -65.59901 9.742023 64.46254 64.51627 16.95127 -32.54502 -46.44199 -23.64906 24.77834
25.39424 31.30311 -28.39444 -79.45353 -58.97827 14.87803 65.79998 64.9374 16.43708 -42.95216 -40.69402 -10.36181 28.52364
25.3911 29.32449 -32.28419 -80.05022 -58.95622 22.17528 72.72678 61.39109 9.213846 -44.4451 -47.88344 2.248398 40.09723
25.39556 28.96247 -35.33858 -81.61469 -58.17877 28.65544 77.47958 55.24949 0.5992745 -45.21356 -47.33035 12.65332 46.42564
25.38706 28.78739 -36.64643 -80.91756 -52.4418 38.08194 81.43094 55.29769 -5.915738 -46.13671 -42.09949 20.64983 52.86951
25.39788 27.56331 -39.11371 -83.97359 -50.41508 38.79966 80.93089 46.31154 -13.89618 -51.27308 -37.04369 23.21174 50.792
25.3968 26.19263 -42.08035 -84.68729 -45.17972 42.09727 79.16079 39.44464 -24.57571 -53.62149 -30.66094 29.2725 50.32141
25.38883 24.48307 -44.83307 -86.77184 -41.25999 42.28427 70.86903 30.42861 -38.70833 -58.88421 -21.19395 28.93901 46.24177
25.38903 25.17598 -42.14454 -81.06771 -30.9208 56.28703 78.06828 32.3428 -33.96818 -52.31734 -4.164658 48.5117 46.72888
25.3899 22.27415 -47.12965 -81.50928 -33.96595 56.36651 68.66307 19.68891 -44.66286 -61.32234 -4.15141 55.11304 28.45146
25.39197 21.70904 -47.20979 -79.77702 -24.4659 65.35997 74.59259 16.09454 -43.09087 -53.9151 9.232623 66.43803 24.4652
25.38938 20.7649 -51.35069 -81.114 -21.37276 65.7199 70.42231 4.628209 -53.03323 -51.66714 12.12255 63.55159 14.46632
25.3979 19.97068 -55.8799 -84.11887 -15.078 72.27465 68.29669 -2.468351 -60.28827 -42.79148 23.19747 60.835 6.945527
25.39607 21.82284 -55.96371 -80.77085 -5.249561 79.81233 69.13963 -6.043231 -60.17236 -29.52646 37.24368 63.51701 5.083388
25.39347 17.64588 -59.96587 -87.09677 -3.594641 76.85334 58.25056 -19.02715 -67.09344 -28.17631 38.09797 52.51156 -10.77667
25.39333 16.98159 -57.03935 -79.00468 4.423308 87.96435 57.2234 -18.0171 -55.0177 -16.05751 65.53757 55.49708 -13.74628
25.39682 15.83234 -58.39169 -76.12865 7.49015 82.94011 47.87926 -33.64468 -62.70134 -12.67456 68.6539 31.38465 -31.07693
25.39563 14.32628 -60.45678 -70.6311 12.10744 85.65196 45.10929 -36.24592 -62.32951 1.713272 72.08761 32.79416 -35.06253
25.39661 13.70812 -63.97197 -68.00484 20.01553 89.11328 40.22752 -39.62369 -54.52652 17.92656 76.70629 30.38804 -35.57701
25.39479 10.83091 -71.481 -73.84635 20.39659 80.00975 25.78455 -56.51125 -64.30695 18.37884 62.30385 12.22257 -48.71682
25.39425 12.78613 -68.40564 -68.83294 37.16087 90.14788 25.0805 -54.19313 -50.42888 38.71355 69.12119 11.70307 -40.25348
25.39023 12.02825 -70.28168 -69.79598 40.78049 89.60378 7.185571 -60.63461 -49.70446 38.91372 69.08297 -13.96448 -40.58236
25.3956 4.760854 -76.64693 -73.26047 35.93605 81.95063 -1.233862 -73.90819 -46.06424 46.97755 55.24388 -32.62456 -33.64398
25.39037 8.905246 -69.64375 -57.18198 50.67647 88.90368 2.843832 -64.58672 -22.77782 70.62299 58.77176 -34.86448 -12.72868
25.39725 7.940741 -71.91322 -51.36573 52.9388 84.9756 -4.101084 -67.40937 -12.37923 75.80219 50.50125 -42.1508 -11.41325
25.39277 3.620197 -78.85619 -55.09002 49.78735 68.9934 -17.75496 -81.43687 -12.13576 65.99275 31.10003 -56.2431 -21.62771
25.39759 2.435049 -77.94907 -52.0281 60.4553 66.13054 -24.33409 -76.27236 -0.7728864 73.98727 23.16094 -54.99735 -10.40479
25.39358 4.273045 -77.24577 -45.15873 72.87798 72.08435 -30.73451 -61.16426 10.21801 87.62292 19.73735 -50.93239 10.27398
25.38981 2.751296 -80.87396 -43.95321 72.19351 67.36891 -50.80131 -64.84616 12.39576 83.41122 -7.147185 -53.67238 12.15022
25.39294 -2.704355 -86.64239 -44.16715 69.00771 61.65969 -61.09535 -70.25542 29.14181 75.42706 -28.96672 -47.29532 17.79565
25.38969 1.415106 -79.73869 -28.73986 78.14098 62.44078 -57.32119 -56.76227 52.57478 74.33283 -31.22603 -34.96442 35.70661
25.39302 -2.079781 -79.92892 -24.442 81.62313 55.30362 -55.7385 -48.35144 66.92049 68.7485 -34.55513 -27.71027 47.65649
25.39511 -4.022318 -80.47609 -21.52332 85.63417 46.43619 -56.02192 -36.65096 74.20319 68.73482 -40.77957 -19.08163 58.90694
25.3896 -5.414461 -81.49599 -23.30378 84.1657 24.62001 -70.31181 -42.96472 62.80655 48.83625 -68.57776 -23.1779 43.46037
25.39262 -7.253727 -83.84801 -17.98841 90.5005 22.48438 -72.22514 -26.68364 74.52739 44.08413 -67.57236 -6.135344 44.65955
25.39382 -12.79493 -90.36889 -19.10461 86.39917 13.03741 -87.50249 -24.15998 81.51659 15.34293 -70.03619 3.049052 27.96535
25.39132 -7.71529 -83.19132 -4.152222 94.4517 16.77529 -89.31504 -3.447665 92.53352 5.591094 -61.65958 22.10788 27.75796
25.39603 -10.74664 -84.79973 -0.2955267 94.26595 6.483277 -86.08201 6.502524 93.05707 -7.934523 -60.08295 36.04617 24.57903
25.39602 -11.27997 -83.69574 6.480924 91.98006 -2.142771 -80.26712 17.66297 91.16392 -19.549 -56.49924 49.83942 21.58446
25.39135 -10.88186 -82.4868 14.21179 93.71255 -8.670269 -68.07482 31.69443 92.93135 -22.26974 -50.52993 71.31429 15.23693
25.39513 -12.35472 -80.73472 16.21279 87.89179 -20.62118 -75.19752 37.67764 77.45379 -38.46849 -51.3957 68.68193 4.069196
25.39181 -13.94713 -79.58489 22.90905 92.27758 -22.07222 -69.42229 57.6824 79.08192 -45.69419 -25.44657 68.77338 8.495295
25.39159 -15.79215 -82.57479 29.49712 91.81357 -28.34865 -69.14436 74.7967 70.16924 -54.18642 -7.831672 68.10078 0.7469958
25.39181 -17.39479 -81.50367 35.05952 90.83553 -39.12706 -66.75391 84.08272 58.46498 -66.06717 7.52849 62.01939 -6.908999
25.39583 -19.17727 -83.71353 37.67658 80.9612 -51.81308 -64.88435 82.18725 40.56393 -79.58513 15.13444 52.05294 -22.78725
25.39167 -20.16072 -78.93819 43.79243 81.6329 -54.67167 -42.00085 87.08875 41.21338 -79.07898 42.15937 50.41256 -20.7732
25.39645 -24.6506 -81.85218 40.83022 69.69142 -69.60789 -38.12477 77.7979 20.76274 -89.65546 46.05492 35.40798 -32.58874
25.39249 -25.25383 -76.05853 45.73799 70.4415 -70.32792 -28.46495 92.16783 5.400962 -74.91256 57.39972 29.2258 -22.31723
25.39421 -25.93147 -74.3922 48.21676 64.91846 -80.30769 -26.48508 95.00976 -17.85125 -70.26611 59.61102 11.04532 -18.38121
25.39548 -26.64165 -72.49001 59.15154 65.25706 -80.5808 -9.466668 104.5628 -25.10479 -54.97909 74.83305 5.088544 -8.382775
25.39491 -30.36923 -74.1627 67.20228 61.70013 -81.2366 7.522305 109.6248 -33.62559 -39.28767 88.20802 -4.186057 -0.4004028
25.39287 -32.73807 -74.41891 70.67566 53.52892 -80.76713 21.84896 106.0331 -41.35748 -24.02505 97.24953 -14.24201 7.01977
25.39086 -29.71002 -67.84211 73.39272 42.87555 -83.16228 27.91243 91.65209 -57.973 -15.09424 88.89481 -30.29982 0.9986133
25.39401 -31.93818 -68.83884 70.44012 29.15322 -90.20501 28.18499 77.21148 -80.70105 -3.166327 68.61609 -41.38475 -8.232524
25.39615 -31.35117 -64.13421 80.55035 29.70428 -79.07188 46.20612 84.83353 -81.80692 30.49987 64.08166 -25.1185 -2.201945
25.39224 -32.09728 -62.95401 86.4822 20.64005 -77.93591 54.82684 78.84743 -90.40573 47.9959 52.74029 -23.11835 1.972355
25.39235 -37.22515 -64.71143 87.3827 9.682838 -85.38517 63.45883 59.93484 -97.20854 54.17868 38.32671 -34.41825 7.559117
25.39638 -37.2728 -60.42402 95.62625 4.738393 -76.32872 78.98089 50.96959 -89.24931 74.13763 31.54243 -33.56698 26.18805
25.39244 -40.40992 -58.6135 95.057 -2.064055 -67.38006 89.76022 36.85495 -79.17003 90.1308 18.34283 -27.10337 39.54053
25.3928 -42.18478 -57.45359 85.07165 -16.65524 -76.95797 83.91347 2.704683 -85.15604 82.49094 -14.48867 -35.29811 25.33562
25.39375 -39.11324 -46.76147 92.30137 -12.4394 -62.43464 103.0772 -0.5291833 -62.03246 98.6534 -11.36484 -16.16133 36.57934
25.39448 -45.75553 -49.88998 86.66156 -27.98986 -66.95346 97.29249 -23.35108 -64.96902 90.48683 -33.67971 -20.86947 26.78923
25.39384 -47.07386 -45.67285 98.4855 -27.88428 -47.48809 115.9556 -22.83762 -37.94284 108.9418 -31.63764 5.350817 35.58423
25.39325 -47.20466 -43.30311 95.80333 -46.68718 -50.55157 107.1271 -51.93818 -36.80504 91.52704 -54.96243 3.622933 25.19625
25.39371 -49.77121 -39.8161 99.29076 -47.84101 -31.96148 111.8069 -56.79999 -9.866325 92.28719 -57.74104 25.95889 27.18584
25.39429 -52.17422 -37.2928 92.46472 -50.7436 -25.99195 111.952 -73.76846 17.54422 73.91215 -55.1922 31.55746 28.6031
25.39458 -34.44326 -28.86453 26.94478 -34.14916 -4.071999 18.66267 -26.84388 0.3021353 16.91182 -25.71729 8.220589 7.145334 ]

View File

@ -0,0 +1,119 @@
1 [
25.38532 46.38532 40.43646 26.04891 0.3232674 -23.76116 -47.58815 -55.77805 -56.82245 -43.12204 -33.96529 -15.93318 -4.92479
25.40668 45.93211 39.33534 20.82029 -1.113101 -30.38894 -49.96596 -62.04239 -57.14521 -47.23344 -32.91168 -18.48427 -5.089862
25.39494 46.07357 38.62945 19.25349 -2.265106 -32.01031 -50.18764 -61.27365 -56.84636 -47.66944 -32.12716 -14.62166 -4.821646
25.38965 45.94759 37.9682 19.80231 -4.0096 -29.4664 -51.8117 -57.52616 -57.92469 -44.71363 -29.21408 -12.42252 -2.346232
25.40345 46.62432 36.93261 20.22779 -6.637779 -31.15936 -51.155 -58.34569 -54.31585 -44.24779 -23.55167 -8.863897 1.932559
25.39414 47.0273 36.15705 16.58885 -8.128201 -39.45121 -53.62976 -60.34694 -56.38583 -38.94326 -22.10081 -4.889903 4.995643
25.41012 46.91393 34.75902 14.06857 -14.90907 -43.02711 -59.8685 -63.81323 -56.25241 -36.58431 -19.64196 -2.910627 7.436204
25.40563 46.02902 33.75841 9.815521 -18.97714 -49.22369 -65.1647 -67.90424 -54.6832 -38.24193 -15.96374 -0.4098501 7.494248
25.39347 45.74887 34.34372 9.187881 -21.36539 -50.77408 -67.06487 -68.61117 -53.06862 -31.04547 -12.18845 3.406743 13.96347
25.40889 46.77818 34.49423 10.93285 -19.64441 -53.10184 -66.13144 -66.88577 -49.03279 -25.34286 -6.76951 10.02704 17.72726
25.38499 47.12933 34.29369 10.56643 -22.28681 -55.12388 -68.67861 -68.13386 -47.54093 -23.46774 -3.028766 9.482397 16.8621
25.38442 48.70028 33.62667 11.80223 -25.65564 -53.64765 -69.23545 -63.91164 -43.98318 -17.46281 2.005727 12.89727 16.93959
25.38152 47.63571 32.48622 7.139426 -28.57889 -57.95905 -71.18561 -64.88127 -45.10395 -16.69648 3.557112 13.37648 14.39043
25.39383 46.94392 29.59289 2.017951 -32.50276 -61.90096 -73.95693 -67.6615 -45.44095 -18.55774 1.837688 16.74773 13.36994
25.40385 47.12545 28.84329 2.734576 -32.42643 -60.80044 -70.77826 -63.81884 -42.39577 -13.89173 7.771653 22.42514 22.48918
25.39873 46.49928 30.75968 1.002508 -33.17535 -57.35464 -71.59306 -57.37133 -37.99292 -9.59599 12.28184 27.56538 26.19395
25.40309 46.1997 27.98814 -1.953473 -38.02605 -64.15872 -72.99059 -58.27158 -35.4799 -8.580102 12.25856 24.28761 23.02804
25.38262 46.9842 27.27348 -2.016821 -40.08006 -65.59283 -70.50599 -51.54396 -27.56731 -1.209845 18.71492 26.53801 27.16768
25.39624 43.92025 25.47549 -6.607616 -43.6745 -74.41853 -68.53377 -54.78749 -23.31723 4.109671 18.09908 29.79642 24.47102
25.39575 44.05944 24.97052 -7.567508 -42.25732 -71.459 -67.41235 -48.84031 -15.07865 10.47071 29.33895 31.42866 33.07372
25.3924 44.85686 24.52086 -7.132763 -42.49384 -70.18423 -66.54608 -44.94352 -11.19412 18.67526 28.82976 33.72599 31.6594
25.38689 45.19783 23.86852 -11.37798 -47.38741 -70.82645 -67.34332 -41.82202 -8.254937 21.04181 29.64841 33.13927 24.62312
25.39572 44.31588 17.71649 -17.72659 -58.227 -78.43457 -72.16118 -45.4472 -9.346466 18.46966 28.31182 27.43386 16.1719
25.38815 45.35692 18.12014 -18.78281 -56.98325 -77.38441 -66.07838 -37.17476 0.8247437 26.81772 41.68108 31.41829 22.09493
25.39475 45.28372 19.76315 -20.83242 -55.95263 -77.14462 -61.23661 -28.61336 5.090013 38.48487 43.34359 39.2795 21.6187
25.39246 46.5613 18.25419 -19.54837 -63.68708 -75.5659 -61.67481 -24.63826 8.006861 36.0433 41.97263 33.14376 16.46943
25.39709 43.62916 15.8978 -27.10357 -70.68687 -84.75415 -64.51408 -27.46575 5.749561 30.79235 34.87248 24.47454 10.111
25.40476 44.53081 16.33788 -23.93591 -67.07205 -80.35097 -58.71431 -15.81992 19.03357 41.20985 40.68231 32.75274 15.57526
25.39227 43.71073 14.47089 -24.98154 -68.02865 -80.99158 -58.16587 -11.73452 25.01885 40.31492 39.22186 30.33737 12.84578
25.40465 42.27491 10.20313 -30.28809 -72.64389 -82.17522 -57.41895 -11.11493 24.0666 39.67345 32.96381 21.9115 2.886812
25.39269 41.65544 7.56587 -34.77799 -73.19456 -81.34157 -52.71831 -7.772728 27.66308 42.23445 32.30576 15.58152 -4.544934
25.40318 41.73953 8.699425 -34.04586 -67.95364 -70.96313 -40.60702 3.855407 40.10815 56.46727 44.87423 19.60626 -2.302107
25.39934 41.57477 5.091603 -37.54679 -75.64105 -70.10496 -40.19965 3.010088 36.95604 52.39346 38.24314 11.2449 -12.00633
25.39792 41.89267 4.785582 -41.50554 -80.99605 -70.77592 -34.63141 7.563346 39.55383 49.0803 32.80286 4.96469 -17.11308
25.38512 41.53328 3.819975 -44.26881 -81.09616 -73.24087 -30.16476 21.10081 43.9517 47.26962 30.95039 2.939907 -19.41306
25.39313 41.23391 3.738014 -45.56654 -80.16782 -71.62791 -19.88947 31.50323 53.3121 48.27503 28.9826 5.061873 -22.37913
25.39145 41.75009 2.177709 -47.12666 -81.33154 -71.12909 -16.19259 36.67249 55.31452 42.28677 27.32597 -3.512858 -22.33639
25.40082 41.06936 -0.6284239 -48.22186 -84.35454 -68.31895 -17.40651 39.94257 56.1668 44.18793 20.63585 -7.49818 -27.29285
25.39538 40.66327 -1.764934 -49.68883 -81.01283 -62.07708 -7.587211 46.25889 64.5416 53.44604 19.23967 -5.863674 -25.35489
25.39835 39.47585 -4.357427 -54.48608 -84.04228 -57.2988 -5.514464 39.56948 63.78515 45.67173 4.828966 -20.74485 -29.23663
25.3992 38.71435 -4.570549 -57.16107 -82.79095 -50.64725 -2.243274 41.17722 63.37231 43.62814 -2.56049 -28.37142 -31.75221
25.39301 38.46305 -7.096981 -54.25875 -80.5594 -45.11859 11.32789 51.13285 60.42904 39.24098 3.029696 -33.19812 -36.88346
25.38788 37.45215 -9.117052 -57.84779 -80.83936 -45.29604 18.14567 54.84932 56.31918 30.36193 -0.6390674 -37.45828 -36.51476
25.39198 36.1208 -10.67175 -60.95683 -80.21278 -44.96912 20.51619 57.68769 52.78998 26.93523 -7.832117 -39.61742 -35.60844
25.39025 36.40623 -12.04668 -60.3553 -77.62048 -34.51344 31.57949 68.11986 55.99387 26.94683 -8.569519 -35.13813 -28.99591
25.38977 35.91476 -14.14716 -66.91611 -82.72645 -32.77011 29.26745 68.1011 48.31393 10.71127 -23.13274 -40.55798 -28.392
25.39284 36.54054 -16.12559 -70.62806 -79.60706 -25.32921 35.12775 71.56234 56.57858 6.576417 -30.53827 -33.27609 -13.34043
25.38806 36.50312 -17.82539 -72.42554 -76.25536 -18.52526 43.82878 73.96272 55.58988 2.957646 -39.91483 -30.59143 -7.019096
25.39465 34.24603 -21.97831 -74.45589 -82.35056 -18.74894 43.74449 66.38205 42.42561 -2.104149 -45.67908 -42.69179 -2.577034
25.39323 33.55885 -21.96626 -72.39411 -74.90198 -8.426734 57.39182 72.87286 43.9062 1.08643 -38.50952 -33.45496 13.40423
25.39714 28.97935 -29.57898 -80.39554 -80.34734 -14.65935 50.51971 57.85884 25.43532 -22.0516 -52.47909 -43.26852 9.135543
25.38697 30.57876 -26.84743 -75.47865 -71.44372 -2.267755 55.87255 61.77605 18.57638 -23.35756 -52.27132 -34.99677 17.82942
25.39441 30.80536 -28.87331 -76.07663 -68.48928 8.553014 64.66566 64.47206 18.5043 -27.50596 -45.58059 -24.7173 27.06434
25.39055 30.89975 -28.87548 -80.79969 -62.29104 11.54676 63.10345 62.96517 14.60226 -42.07924 -42.97483 -16.24887 24.83545
25.39395 29.82885 -31.10229 -79.01083 -58.02502 21.2737 72.30254 63.44043 12.08938 -43.65895 -45.07901 0.3763475 38.42318
25.38966 27.01946 -38.17745 -85.37549 -64.12541 20.63371 69.38239 50.04358 -4.756351 -52.31731 -54.67303 2.962495 38.45331
25.39335 28.38519 -36.95552 -81.74364 -53.94786 35.79619 80.91443 55.62282 -4.373586 -45.00974 -42.9394 20.79662 51.94242
25.39293 26.57629 -40.62536 -86.22507 -55.04028 33.77809 76.24027 42.74701 -16.93307 -56.41526 -44.5168 15.96091 46.69937
25.39348 27.45793 -39.84925 -82.93864 -45.02178 43.84418 81.40533 42.55453 -20.41894 -50.90053 -31.04898 29.92536 52.12495
25.39263 23.15161 -47.10671 -89.712 -46.07175 37.93438 69.54885 29.11407 -37.84161 -60.93122 -27.73634 25.47708 44.98783
25.39305 25.607 -41.61295 -82.01257 -31.59099 53.97821 76.72733 33.20645 -35.78217 -52.53894 -5.644599 43.89986 48.76618
25.39201 21.56894 -48.11875 -84.18115 -37.25055 52.42075 67.03419 19.45025 -45.36713 -63.81252 -10.00664 49.67364 29.063
25.39616 20.91827 -47.94791 -82.21919 -28.82557 60.12504 69.43927 11.72181 -47.67692 -61.49944 3.008572 60.98768 21.06763
25.39701 21.26366 -50.45441 -80.77496 -22.92478 65.86505 70.91223 7.953492 -49.77809 -52.99021 9.201393 63.99749 16.09211
25.39065 21.35323 -52.22644 -80.18671 -12.8774 74.53942 73.42382 3.673998 -54.60684 -40.22329 26.29227 66.89787 13.4744
25.3955 21.62635 -56.51224 -82.01313 -8.142892 77.62218 68.73994 -6.19886 -60.15164 -33.72015 32.60169 61.24866 4.403368
25.38783 21.05478 -55.58424 -81.81695 0.5686536 81.15867 65.31736 -12.16074 -63.20033 -24.72629 40.87747 58.55075 -3.440407
25.39352 17.97288 -56.70938 -80.27528 2.817797 86.49545 56.30198 -18.44834 -57.97326 -20.73165 57.0295 54.98438 -14.85742
25.39195 15.92449 -58.46394 -77.895 4.769008 81.22737 47.02503 -34.30138 -65.31602 -18.40672 64.66216 30.96558 -31.02048
25.39283 14.5195 -59.76467 -70.74762 12.49147 87.48908 49.58151 -31.11487 -58.19484 3.319366 76.18122 38.96074 -28.88755
25.39704 13.99183 -62.6895 -68.70069 17.06756 86.64992 38.84517 -42.08774 -59.44818 11.25803 72.1925 26.07138 -39.29481
25.38908 15.497 -64.88111 -66.45351 25.50888 87.64637 34.62891 -47.06266 -58.24336 21.65615 70.16651 22.08021 -42.21913
25.39367 14.03758 -66.53736 -67.58192 36.16739 88.90174 26.96653 -54.35226 -52.04314 36.13499 68.2683 14.12985 -41.49809
25.39642 12.42562 -69.82998 -70.02515 39.60614 89.95936 11.42941 -59.08032 -50.56667 38.14076 69.12003 -8.965389 -42.51231
25.39367 9.420222 -70.65002 -67.22116 43.81015 91.21423 8.143895 -62.83793 -39.74917 51.74773 67.21016 -20.10557 -28.20558
25.39505 7.64828 -71.25418 -59.85993 48.9677 90.50465 5.415679 -62.491 -22.48333 71.95388 63.77154 -28.66063 -11.10132
25.39614 7.598652 -72.18832 -54.87945 49.52846 83.12215 -5.464087 -70.61362 -20.34403 70.31898 48.09004 -45.71234 -16.51912
25.39404 7.321666 -73.77635 -49.6713 54.0638 75.83004 -11.01439 -75.34438 -7.732321 70.89206 38.82542 -50.56324 -16.10443
25.39353 5.020178 -74.14969 -47.4748 64.83239 74.46149 -14.46084 -68.61813 5.367599 81.67761 34.40161 -45.47061 -5.726971
25.39484 5.448654 -75.20386 -44.9261 70.86498 71.91998 -27.32423 -65.71491 8.267384 84.16631 22.11889 -52.55405 4.966821
25.3952 3.330077 -79.94658 -44.86795 72.20807 67.819 -47.81216 -63.03672 9.537488 84.54942 -2.14252 -55.88456 10.48003
25.39439 -0.02191679 -83.02991 -40.85019 74.64166 68.41186 -52.83466 -62.34428 32.37473 85.31492 -17.59925 -40.52883 23.23387
25.39125 0.4396203 -81.06796 -30.3757 79.63792 67.92451 -52.36978 -53.92926 53.72873 81.73109 -22.3437 -30.56319 38.29645
25.3952 -2.830251 -81.48553 -26.72247 80.79588 58.96914 -53.07327 -47.65161 67.48645 75.05592 -29.16131 -25.1123 49.18814
25.39419 -2.401139 -78.90726 -22.14992 82.21037 45.20182 -60.55093 -45.12735 67.54078 61.42963 -46.68263 -29.02958 50.5612
25.39473 -3.594689 -79.59908 -21.65723 85.06188 29.82176 -67.49429 -41.63226 65.46536 52.99676 -63.33311 -23.68229 46.21498
25.39162 -3.813236 -80.05599 -17.21558 91.80138 23.1885 -70.6835 -29.94617 70.93774 47.79065 -67.76472 -9.80926 45.3359
25.39325 -7.276777 -82.80712 -11.98845 94.6318 23.04904 -75.88926 -17.6189 87.44342 30.70049 -60.68114 8.555294 39.32879
25.39149 -8.223318 -83.52549 -4.309903 98.07117 22.65733 -81.96045 -0.5579842 98.17782 15.5908 -54.95921 26.04539 34.94342
25.39344 -10.93268 -84.58604 0.4153341 97.63563 14.5156 -81.54991 10.77081 100.1797 2.49443 -53.05777 38.98166 32.88707
25.39408 -12.27234 -84.07152 5.945063 95.93524 5.130111 -74.39716 23.08295 100.8175 -7.270844 -48.56831 56.26474 31.97482
25.39202 -9.680718 -81.99494 10.95169 88.96925 -12.7099 -77.19251 21.06593 85.71529 -29.59091 -60.09066 59.81875 11.19004
25.39589 -11.815 -80.54845 17.76547 92.42519 -13.52795 -67.99579 40.74923 87.21462 -28.38006 -48.05539 75.88336 10.0998
25.39645 -15.02196 -80.23888 21.20151 92.62309 -19.03099 -68.09854 55.26164 82.50536 -40.44894 -29.94642 70.53014 10.56429
25.39591 -14.34652 -81.74042 28.07563 91.86987 -27.54577 -70.99956 69.26332 71.35625 -54.1442 -13.99096 65.78969 -0.3682785
25.3912 -17.66055 -83.1585 30.01328 86.20444 -43.3409 -75.76914 73.60836 52.79985 -72.1168 -5.751298 54.17259 -12.94435
25.39535 -20.28679 -84.41975 39.47444 88.11517 -41.01907 -58.25925 92.83658 54.58632 -65.74653 21.26763 63.91584 -11.92479
25.39261 -20.54185 -81.35249 39.93372 78.32833 -58.50629 -52.01899 79.96899 36.21083 -85.58276 30.30122 45.40194 -26.21034
25.39389 -19.46382 -77.01328 45.04123 74.14634 -64.67137 -36.78105 80.81107 28.04484 -87.24133 48.22746 39.9678 -28.85943
25.39423 -21.98055 -73.25692 47.66246 72.47936 -68.12497 -27.57518 90.71562 11.64535 -76.76437 57.68767 34.75176 -23.98895
25.39533 -25.90566 -74.21096 49.04252 68.69678 -74.57565 -23.55081 98.90106 -9.592892 -67.52109 64.00034 19.93253 -14.47118
25.39554 -32.18652 -79.76717 49.32306 58.67518 -87.027 -20.57716 95.30236 -30.88301 -66.72587 64.54459 0.4255028 -17.11223
25.39308 -29.71075 -75.55727 60.20243 55.10045 -90.82842 -7.644746 97.81853 -43.59582 -54.64902 73.24262 -11.86323 -13.1221
25.3965 -28.92496 -72.39767 71.34161 54.50336 -84.22205 14.50714 102.513 -44.21877 -33.34446 90.77426 -19.34229 1.841854
25.39348 -28.26368 -67.61832 73.06441 44.22544 -84.67479 23.96183 92.19052 -57.11995 -20.05127 87.64636 -29.45557 -0.05139913
25.39616 -32.10293 -68.56171 71.93924 33.85649 -86.38581 30.50137 81.86202 -72.36772 -5.93773 77.45951 -37.8978 -3.239398
25.39317 -31.34605 -67.11718 73.36967 23.79452 -89.55775 33.82771 75.24729 -90.83019 13.93961 56.23328 -36.17545 -12.04314
25.39464 -32.70353 -64.34343 83.96381 22.28917 -79.05529 50.60429 79.82584 -90.11343 42.27896 53.32549 -25.01945 -2.318132
25.39489 -35.42332 -65.10469 85.09756 8.425458 -88.03746 56.32113 59.19263 -101.9395 47.20057 35.8171 -36.72979 1.052083
25.39444 -38.78215 -64.03951 89.22057 0.286116 -86.35546 68.35411 45.32402 -99.97911 60.54477 26.17678 -41.63462 14.52861
25.39328 -38.22108 -58.18196 94.61192 -2.139997 -72.95795 83.41586 36.69357 -86.31156 81.94895 17.3798 -33.9303 32.42916
25.39395 -40.01531 -54.92315 93.10029 -7.37498 -66.57444 93.28197 19.90489 -75.42856 93.55279 2.036242 -24.93197 37.27777
25.39198 -41.84994 -52.4922 85.68872 -18.33885 -72.29888 91.81085 -7.271101 -74.82742 87.63387 -20.06664 -28.13238 26.55455
25.39395 -40.37983 -44.82071 91.27369 -20.40925 -61.91277 103.0631 -12.95668 -59.67254 97.33856 -22.57121 -15.51652 34.38859
25.39408 -43.25077 -46.11867 92.14365 -34.0641 -60.45708 103.2876 -32.41691 -54.4558 95.39404 -41.73148 -8.88547 24.89328
25.393 -44.60904 -40.46962 100.6379 -38.18752 -45.22467 114.6389 -39.3225 -32.0847 101.8648 -43.39056 9.565659 33.91491
25.39316 -53.61244 -44.69142 96.91879 -47.70626 -34.96313 112.8361 -52.64176 -14.21924 96.32195 -55.02893 23.29743 29.13971
25.39214 -46.51054 -33.65408 96.5901 -49.29968 -26.76014 113.0089 -69.34328 11.96164 79.19379 -54.53493 33.12771 29.79057 ]

View File

@ -2,107 +2,136 @@
# Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang)
import numpy as np
import soundfile as sf
from pathlib import Path
import torch
from utils import read_ark_txt, read_wave
import kaldifeat
def read_wave(filename) -> torch.Tensor:
    """Load a sound file into a 1-D tensor of float32 samples.

    The file must have a sampling rate of 16000; anything else trips the
    assertion below.

    Note:
      The samples are multiplied by 32768. This scaling is not required;
      it is done only to follow the approach in Kaldi.

    Args:
      filename:
        Filename of a sound file.
    Returns:
      Return a 1-D tensor containing audio samples.
    """
    with sf.SoundFile(filename) as sound_file:
        assert sound_file.samplerate == 16000
        samples = sound_file.read(dtype=np.float32, always_2d=False)

    # Scale in place, then wrap the numpy buffer as a tensor.
    samples *= 32768
    return torch.from_numpy(samples)
cur_dir = Path(__file__).resolve().parent
def test_fbank():
# Selects the device for this test: CPU is used, and the CUDA alternative
# is left commented out below.
# NOTE(review): in this view the function contains only the device
# selection; the statements that consume `device` are not contiguous with
# it here -- confirm against the full file.
device = torch.device("cpu")
# if torch.cuda.is_available():
# device = torch.device("cuda", 0)
def test_fbank_default():
    """Compare default-option Fbank features with ``test_data/test.txt``.

    Dither is set to 0, matching the ``--dither=0`` used when the reference
    file was generated. All paths are resolved relative to ``cur_dir`` so the
    test does not depend on the current working directory.
    """
    opts = kaldifeat.FbankOptions()
    opts.frame_opts.dither = 0
    fbank = kaldifeat.Fbank(opts)

    filename = cur_dir / "test_data/test.wav"
    wave = read_wave(filename)

    features = fbank(wave)
    gt = read_ark_txt(cur_dir / "test_data/test.txt")
    # The loose relative tolerance is kept as-is from the original check.
    assert torch.allclose(features, gt, rtol=1e-1)
def test_fbank_htk():
    """Fbank with ``use_energy`` and ``htk_compat`` enabled must match
    the reference stored in ``test_data/test-htk.txt``.
    """
    options = kaldifeat.FbankOptions()
    options.frame_opts.dither = 0
    options.use_energy = True
    options.htk_compat = True
    extractor = kaldifeat.Fbank(options)

    data_dir = cur_dir / "test_data"
    computed = extractor(read_wave(data_dir / "test.wav"))
    expected = read_ark_txt(data_dir / "test-htk.txt")

    assert torch.allclose(computed, expected, rtol=1e-1)
def test_fbank_with_energy():
    """Fbank with ``use_energy`` enabled must match the reference stored
    in ``test_data/test-with-energy.txt``.
    """
    options = kaldifeat.FbankOptions()
    options.frame_opts.dither = 0
    options.use_energy = True
    extractor = kaldifeat.Fbank(options)

    data_dir = cur_dir / "test_data"
    computed = extractor(read_wave(data_dir / "test.wav"))
    expected = read_ark_txt(data_dir / "test-with-energy.txt")

    assert torch.allclose(computed, expected, rtol=1e-1)
def test_fbank_40_bins():
    """Fbank with 40 mel bins must match the reference stored in
    ``test_data/test-40.txt``.
    """
    options = kaldifeat.FbankOptions()
    options.frame_opts.dither = 0
    options.mel_opts.num_bins = 40
    extractor = kaldifeat.Fbank(options)

    data_dir = cur_dir / "test_data"
    computed = extractor(read_wave(data_dir / "test.wav"))
    expected = read_ark_txt(data_dir / "test-40.txt")

    assert torch.allclose(computed, expected, rtol=1e-1)
def test_fbank_40_bins_no_snip_edges():
    """Fbank with 40 mel bins and ``snip_edges`` turned off must match the
    reference stored in ``test_data/test-40-no-snip-edges.txt``.
    """
    options = kaldifeat.FbankOptions()
    options.frame_opts.dither = 0
    options.mel_opts.num_bins = 40
    options.frame_opts.snip_edges = False
    extractor = kaldifeat.Fbank(options)

    data_dir = cur_dir / "test_data"
    computed = extractor(read_wave(data_dir / "test.wav"))
    expected = read_ark_txt(data_dir / "test-40-no-snip-edges.txt")

    assert torch.allclose(computed, expected, rtol=1e-1)
def test_fbank_chunk():
    """Compute fbank features for the one-hour recording chunk by chunk.

    The fixture ``test_data/test-1hour.wav`` is optional. When it is missing,
    a hint on how to generate it is printed and the test returns without
    doing anything else.
    """
    filename = cur_dir / "test_data/test-1hour.wav"
    if not filename.is_file():
        print(
            f"Please execute {cur_dir}/test_data/run.sh "
            f"to generate {filename} before running this test"
        )
        return

    opts = kaldifeat.FbankOptions()
    opts.frame_opts.dither = 0
    opts.mel_opts.num_bins = 40
    opts.frame_opts.snip_edges = False
    # `opts.device` is left at its default: the former `opts.device = device`
    # referred to a name that is not defined in this function.
    fbank = kaldifeat.Fbank(opts)
    wave = read_wave(filename)

    # The former batched call on `wave0`/`wave1` was removed for the same
    # reason (both names are undefined here); `test_fbank_batch` covers
    # batched input.

    # You can use
    #
    # $ watch -n 0.2 free -m
    #
    # to view memory consumption
    #
    # chunk_size is 100 * 10 = 1000 per chunk
    features = fbank(wave, chunk_size=100 * 10)
    print(features.shape)
def test_fbank_batch():
    """Batched extraction agrees with per-wave extraction, and computing a
    single frame via ``compute`` agrees with the matching feature row."""
    first = read_wave(cur_dir / "test_data/test.wav")
    second = read_wave(cur_dir / "test_data/test2.wav")

    options = kaldifeat.FbankOptions()
    options.frame_opts.dither = 0
    extractor = kaldifeat.Fbank(options)

    # A list of waves goes in, a list of feature tensors comes out.
    batched = extractor([first, second], chunk_size=10)

    # The same waves, one at a time.
    single_first = extractor(first)
    single_second = extractor(second)
    assert torch.allclose(batched[0], single_first)
    assert torch.allclose(batched[1], single_second)

    # Features for individual frames, computed from the framed samples.
    frames = extractor.convert_samples_to_frames(first)
    row_1 = extractor.compute(frames[1:2])
    row_10 = extractor.compute(frames[10:11])
    for index, row in ((1, row_1), (10, row_10)):
        assert torch.allclose(single_first[index], row)
def test_benchmark():
    """Compute 80-bin fbank features for the one-hour recording in chunks.

    The recording is framed once with ``convert_samples_to_frames``. The
    frames are then passed to ``compute`` in slices of ``chunk_size`` frames
    and the per-slice results are concatenated along dim 0.

    You have to run ./test_data/run.sh to generate test_data/test-1hour.wav
    before calling this function.
    """
    device = torch.device("cpu")
    # device = torch.device('cuda:0')

    # Resolve the fixture relative to this file, like the other tests, so
    # the benchmark does not depend on the current working directory.
    wave = read_wave(cur_dir / "test_data/test-1hour.wav").to(device)

    opts = kaldifeat.FbankOptions()
    opts.frame_opts.dither = 0
    opts.device = device
    opts.mel_opts.num_bins = 80
    fbank = kaldifeat.Fbank(opts)

    # 1 second has 100 frames
    chunk_size = 100 * 10  # 10 seconds

    audio_frames = fbank.convert_samples_to_frames(wave)
    num_frames = audio_frames.size(0)

    # Stepping by chunk_size yields every full chunk plus the shorter tail
    # (slicing clamps at the end). The previous version compared against an
    # `end` variable that was never bound when the input held fewer frames
    # than one chunk.
    features = [
        fbank.compute(audio_frames[start : start + chunk_size])
        for start in range(0, num_frames, chunk_size)
    ]
    features = torch.cat(features, dim=0)

    # watch -n 0.2 free -m

    # Optional cross-check against processing the whole wave at once:
    # features2 = fbank(wave)
    # assert torch.allclose(features, features2)
# Script entry point: call the test functions directly when this file is run
# as a program.
if __name__ == "__main__":
test_fbank()
# The benchmark stays disabled here; per the comment in test_benchmark it
# needs test_data/test-1hour.wav to be generated first.
# test_benchmark()
test_fbank_default()
test_fbank_htk()
test_fbank_with_energy()
test_fbank_40_bins()
test_fbank_40_bins_no_snip_edges()
test_fbank_chunk()
test_fbank_batch()

View File

@ -1,199 +0,0 @@
#!/usr/bin/env python3
#
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
import sys
from pathlib import Path
from typing import List
cur_dir = Path(__file__).resolve().parent
kaldi_feat_dir = cur_dir.parent.parent.parent
sys.path.insert(0, f"{kaldi_feat_dir}/build/lib")
import _kaldifeat
import numpy as np
import soundfile as sf
import torch
def read_ark_txt(filename) -> torch.Tensor:
    """Parse a text feature matrix stored under ``test_data``.

    Any line containing ``[`` is skipped entirely. On the remaining lines
    every whitespace-separated token except ``]`` is parsed as a float, and
    each line becomes one row.

    Lines that contribute no numbers (blank lines, or a line holding only the
    closing ``]``) are ignored. Previously such a line appended an empty row,
    which made ``torch.tensor`` fail on the ragged list.

    Args:
      filename:
        Name of the file, relative to ``<this directory>/test_data``.
    Returns:
      A tensor with one row per parsed line.
    """
    path = cur_dir / "test_data" / filename

    rows = []
    with open(path) as f:
        for line in f:
            if "[" in line:
                continue
            # str.split() without arguments already discards surrounding
            # whitespace, so no explicit strip is required.
            values = [float(tok) for tok in line.split() if tok != "]"]
            if values:
                rows.append(values)
    return torch.tensor(rows)
def read_wave() -> torch.Tensor:
    """Load ``test_data/test.wav`` as float32 samples multiplied by 32768.

    The file must have a sampling rate of 16000; otherwise the assertion
    below fails.
    """
    wav_path = cur_dir / "test_data" / "test.wav"
    with sf.SoundFile(wav_path) as reader:
        assert reader.samplerate == 16000
        samples = reader.read(dtype=np.float32, always_2d=False)
    samples *= 32768
    return torch.from_numpy(samples)
def test_and_benchmark_default_parameters():
    """On the CPU, and on CUDA device 0 when available, check default fbank
    features against ``test.txt`` and print the reported elapsed time."""
    devices = [torch.device("cpu")]
    devices += [torch.device("cuda", 0)] if torch.cuda.is_available() else []

    for device in devices:
        options = _kaldifeat.FbankOptions()
        options.device = device
        options.frame_opts.dither = 0
        computer = _kaldifeat.Fbank(options)

        samples = read_wave().to(device)
        timed = _kaldifeat._compute_with_elapsed_time(samples, computer)
        result, elapsed_seconds = timed

        reference = read_ark_txt("test.txt")
        assert torch.allclose(result.cpu(), reference, rtol=1e-2)
        print(f"elapsed seconds {device}:", elapsed_seconds)
def test_use_energy_htk_compat_true():
    """With ``use_energy`` and ``htk_compat`` both on, features match
    ``test-htk.txt`` on every available device."""
    devices = [torch.device("cpu")]
    devices += [torch.device("cuda", 0)] if torch.cuda.is_available() else []

    for device in devices:
        options = _kaldifeat.FbankOptions()
        options.htk_compat = True
        options.use_energy = True
        options.device = device
        options.frame_opts.dither = 0
        computer = _kaldifeat.Fbank(options)

        samples = read_wave().to(device)
        result = _kaldifeat.compute_fbank_feats(samples, computer)

        reference = read_ark_txt("test-htk.txt")
        assert torch.allclose(result.cpu(), reference, rtol=1e-2)
def test_use_energy_htk_compat_false():
    """With ``use_energy`` on and ``htk_compat`` off, features match
    ``test-with-energy.txt`` on every available device."""
    devices = [torch.device("cpu")]
    devices += [torch.device("cuda", 0)] if torch.cuda.is_available() else []

    for device in devices:
        options = _kaldifeat.FbankOptions()
        options.device = device
        options.htk_compat = False
        options.use_energy = True
        options.frame_opts.dither = 0
        computer = _kaldifeat.Fbank(options)

        samples = read_wave().to(device)
        result = _kaldifeat.compute_fbank_feats(samples, computer)

        reference = read_ark_txt("test-with-energy.txt")
        assert torch.allclose(result.cpu(), reference, rtol=1e-2)
def test_40_mel():
    """With 40 mel bins, features match ``test-40.txt`` on every available
    device."""
    devices = [torch.device("cpu")]
    devices += [torch.device("cuda", 0)] if torch.cuda.is_available() else []

    for device in devices:
        options = _kaldifeat.FbankOptions()
        options.device = device
        options.mel_opts.num_bins = 40
        options.frame_opts.dither = 0
        computer = _kaldifeat.Fbank(options)

        samples = read_wave().to(device)
        result = _kaldifeat.compute_fbank_feats(samples, computer)

        reference = read_ark_txt("test-40.txt")
        assert torch.allclose(result.cpu(), reference, rtol=1e-1)
def test_40_mel_no_snip_edges():
    """With 40 mel bins and ``snip_edges`` off, features match
    ``test-40-no-snip-edges.txt`` on every available device."""
    devices = [torch.device("cpu")]
    devices += [torch.device("cuda", 0)] if torch.cuda.is_available() else []

    for device in devices:
        options = _kaldifeat.FbankOptions()
        options.device = device
        options.mel_opts.num_bins = 40
        options.frame_opts.dither = 0
        options.frame_opts.snip_edges = False
        computer = _kaldifeat.Fbank(options)

        samples = read_wave().to(device)
        result = _kaldifeat.compute_fbank_feats(samples, computer)

        reference = read_ark_txt("test-40-no-snip-edges.txt")
        assert torch.allclose(result.cpu(), reference, rtol=1e-2)
def test_compute_batch():
    """Frame two copies of the same wave, compute their features in a single
    call, split the result per wave and check both halves agree."""
    devices = [torch.device("cpu")]
    devices += [torch.device("cuda", 0)] if torch.cuda.is_available() else []

    for device in devices:
        first = read_wave().to(device)
        second = read_wave().to(device)

        fbank_opts = _kaldifeat.FbankOptions()
        fbank_opts.device = device
        fbank_opts.frame_opts.snip_edges = False
        fbank_opts.frame_opts.dither = 0
        fbank = _kaldifeat.Fbank(fbank_opts)

        def impl(waves: List[torch.Tensor]) -> List[torch.Tensor]:
            # Frames contributed by each wave; used to split the result.
            frame_counts = []
            for w in waves:
                frame_counts.append(
                    _kaldifeat.num_frames(w.numel(), fbank_opts.frame_opts)
                )

            # Frame every wave and stack all frames along dim 0.
            framed = []
            for w in waves:
                framed.append(
                    _kaldifeat.get_strided(w, fbank_opts.frame_opts)
                )
            stacked = torch.cat(framed, dim=0)

            combined = _kaldifeat.compute_fbank_feats(stacked, fbank)
            return combined.split(frame_counts)

        out_first, out_second = impl([first, second])
        assert torch.allclose(out_first, out_second)
def main():
    """Run every test of this module, one after another."""
    cases = (
        test_and_benchmark_default_parameters,
        test_use_energy_htk_compat_true,
        test_use_energy_htk_compat_false,
        test_40_mel,
        test_40_mel_no_snip_edges,
        test_compute_batch,
    )
    for case in cases:
        case()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,43 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang)
from pathlib import Path
import torch
from utils import read_ark_txt, read_wave
import kaldifeat
cur_dir = Path(__file__).resolve().parent
def test_mfcc_default():
    """MFCC with default options (``dither`` set to 0) matches the reference
    stored in ``test_data/test-mfcc.txt``."""
    options = kaldifeat.MfccOptions()
    options.frame_opts.dither = 0
    extractor = kaldifeat.Mfcc(options)

    samples = read_wave(cur_dir / "test_data/test.wav")
    computed = extractor(samples)

    expected = read_ark_txt(cur_dir / "test_data/test-mfcc.txt")
    assert torch.allclose(computed, expected, rtol=1e-1)
def test_mfcc_no_snip_edges():
    """MFCC with ``snip_edges`` disabled matches the reference stored in
    ``test_data/test-mfcc-no-snip-edges.txt``."""
    options = kaldifeat.MfccOptions()
    options.frame_opts.snip_edges = False
    options.frame_opts.dither = 0
    extractor = kaldifeat.Mfcc(options)

    samples = read_wave(cur_dir / "test_data/test.wav")
    computed = extractor(samples)

    expected = read_ark_txt(cur_dir / "test_data/test-mfcc-no-snip-edges.txt")
    assert torch.allclose(computed, expected, rtol=1e-1)
if __name__ == "__main__":
test_mfcc_default()
test_mfcc_no_snip_edges()

View File

@ -9,13 +9,15 @@ cur_dir = Path(__file__).resolve().parent
kaldi_feat_dir = cur_dir.parent.parent.parent
import torch
sys.path.insert(0, f"{kaldi_feat_dir}/build/lib")
import _kaldifeat
import kaldifeat
def test_frame_extraction_options():
opts = _kaldifeat.FrameExtractionOptions()
opts = kaldifeat.FrameExtractionOptions()
opts.samp_freq = 220500
opts.frame_shift_ms = 15
opts.frame_length_ms = 40
@ -30,7 +32,7 @@ def test_frame_extraction_options():
def test_mel_banks_options():
opts = _kaldifeat.MelBanksOptions()
opts = kaldifeat.MelBanksOptions()
opts.num_bins = 23
opts.low_freq = 21
opts.high_freq = 8000
@ -42,7 +44,7 @@ def test_mel_banks_options():
def test_fbank_options():
opts = _kaldifeat.FbankOptions()
opts = kaldifeat.FbankOptions()
frame_opts = opts.frame_opts
mel_opts = opts.mel_opts
@ -52,7 +54,41 @@ def test_fbank_options():
opts.use_energy = False
opts.use_log_fbank = True
opts.use_power = True
opts.device = "cuda:0"
opts.device = torch.device("cuda", 0)
frame_opts.blackman_coeff = 0.42
frame_opts.dither = 1
frame_opts.frame_length_ms = 25
frame_opts.frame_shift_ms = 10
frame_opts.preemph_coeff = 0.97
frame_opts.remove_dc_offset = True
frame_opts.round_to_power_of_two = True
frame_opts.samp_freq = 16000
frame_opts.snip_edges = True
frame_opts.window_type = "povey"
mel_opts.debug_mel = True
mel_opts.high_freq = 0
mel_opts.low_freq = 20
mel_opts.num_bins = 23
mel_opts.vtln_high = -500
mel_opts.vtln_low = 100
print(opts)
def test_mfcc_options():
opts = kaldifeat.MfccOptions()
frame_opts = opts.frame_opts
mel_opts = opts.mel_opts
opts.num_ceps = 10
opts.use_energy = False
opts.energy_floor = 0.0
opts.raw_energy = True
opts.cepstral_lifter = 22.0
opts.htk_compat = False
opts.device = torch.device("cpu")
frame_opts.blackman_coeff = 0.42
frame_opts.dither = 1
@ -79,6 +115,7 @@ def main():
test_frame_extraction_options()
test_mel_banks_options()
test_fbank_options()
test_mfcc_options()
if __name__ == "__main__":

View File

@ -0,0 +1,41 @@
# Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang)
import numpy as np
import soundfile as sf
import torch
def read_wave(filename) -> torch.Tensor:
    """Load a sound file and return its samples as a 1-D tensor.

    Note:
      Scaling to [-32768, 32767] is not required.
      The samples are scaled here only to follow the approach in Kaldi.

    Args:
      filename:
        Path of the sound file to read. Its sampling rate must be 16000.
    Returns:
      A 1-D tensor holding the audio samples.
    """
    with sf.SoundFile(filename) as reader:
        assert reader.samplerate == 16000
        samples = reader.read(dtype=np.float32, always_2d=False)
    samples *= 32768
    return torch.from_numpy(samples)
def read_ark_txt(filename) -> torch.Tensor:
    """Parse a text feature matrix into a tensor.

    Any line containing ``[`` is skipped entirely. On the remaining lines
    every whitespace-separated token except ``]`` is parsed as a float, and
    each line becomes one row.

    Lines that contribute no numbers (blank lines, or a line holding only the
    closing ``]``) are ignored. Previously such a line appended an empty row,
    which made ``torch.tensor`` fail on the ragged list.

    Args:
      filename:
        Path of the text file to read.
    Returns:
      A tensor with one row per parsed line.
    """
    rows = []
    with open(filename) as f:
        for line in f:
            if "[" in line:
                continue
            # str.split() without arguments already discards surrounding
            # whitespace, so no explicit strip is required.
            values = [float(tok) for tok in line.split() if tok != "]"]
            if values:
                rows.append(values)
    return torch.tensor(rows)