mirror of
https://github.com/csukuangfj/kaldifeat.git
synced 2025-08-10 10:32:16 +00:00
Refactoring.
This commit is contained in:
parent
e930dc176f
commit
ded479fc10
@ -42,30 +42,10 @@ torch::Tensor OfflineFeatureTpl<F>::ComputeFeatures(const torch::Tensor &wave,
|
||||
torch::clamp_min(strided_input.pow(2).sum(1), kEps).log();
|
||||
}
|
||||
|
||||
if (frame_opts.preemph_coeff != 0.0f) {
|
||||
KALDIFEAT_ASSERT(frame_opts.preemph_coeff >= 0.0f &&
|
||||
frame_opts.preemph_coeff <= 1.0f);
|
||||
if (frame_opts.preemph_coeff != 0.0f)
|
||||
Preemphasize(frame_opts.preemph_coeff, &strided_input);
|
||||
|
||||
// right = strided_input[:, 1:]
|
||||
torch::Tensor right = strided_input.index(
|
||||
{"...", torch::indexing::Slice(1, torch::indexing::None,
|
||||
torch::indexing::None)});
|
||||
|
||||
// current = strided_input[:, 0:-1]
|
||||
torch::Tensor current =
|
||||
strided_input.index({"...", torch::indexing::Slice(0, -1, 1)});
|
||||
|
||||
// strided_input[1:, :] =
|
||||
// strided_input[:, 1:] - preemph_coeff * strided_input[:, 0:-1]
|
||||
strided_input.index(
|
||||
{"...", torch::indexing::Slice(1, torch::indexing::None,
|
||||
torch::indexing::None)}) =
|
||||
right - frame_opts.preemph_coeff * current;
|
||||
|
||||
strided_input.index({"...", 0}) *= 1 - frame_opts.preemph_coeff;
|
||||
}
|
||||
|
||||
strided_input = feature_window_function_.Apply(strided_input);
|
||||
feature_window_function_.Apply(&strided_input);
|
||||
|
||||
int32_t padding = frame_opts.PaddedWindowSize() - strided_input.sizes()[1];
|
||||
|
||||
|
@ -71,8 +71,12 @@ torch::Tensor FbankComputer::Compute(torch::Tensor signal_raw_log_energy,
|
||||
// note spectrum is in magnitude, not power, because of `abs()`
|
||||
torch::Tensor spectrum = torch::fft::rfft(signal_frame).abs();
|
||||
|
||||
// remove the last column, i.e., the highest fft bin
|
||||
spectrum = spectrum.index(
|
||||
{"...", torch::indexing::Slice(0, -1, torch::indexing::None)});
|
||||
|
||||
// Use power instead of magnitude if requested.
|
||||
if (opts_.use_power) spectrum = spectrum.pow(2);
|
||||
if (opts_.use_power) spectrum.pow_(2);
|
||||
|
||||
#if 0
|
||||
int32_t mel_offset = ((opts_.use_energy && !opts_.htk_compat) ? 1 : 0);
|
||||
|
@ -51,10 +51,10 @@ FeatureWindowFunction::FeatureWindowFunction(
|
||||
window = window.unsqueeze(0);
|
||||
}
|
||||
|
||||
torch::Tensor FeatureWindowFunction::Apply(const torch::Tensor &input) const {
|
||||
KALDIFEAT_ASSERT(input.dim() == 2);
|
||||
KALDIFEAT_ASSERT(input.sizes()[1] == window.sizes()[1]);
|
||||
return input * window;
|
||||
// Multiplies `*wave` element-wise by the window, in place.
//
// @param wave  Pointer to a 2-D tensor whose second dimension must equal
//              the second dimension of `window`. Modified in place.
void FeatureWindowFunction::Apply(torch::Tensor *wave) const {
  KALDIFEAT_ASSERT(wave->dim() == 2);
  // Check the frame length against the window length. The previous check
  // compared `wave->sizes()[1]` with itself and therefore could never fail.
  KALDIFEAT_ASSERT(wave->sizes()[1] == window.sizes()[1]);
  wave->mul_(window);
}
|
||||
|
||||
static int64_t FirstSampleOfFrame(int32_t frame,
|
||||
@ -137,4 +137,26 @@ torch::Tensor Dither(const torch::Tensor &wave, float dither_value) {
|
||||
return wave + rand_gauss * dither_value;
|
||||
}
|
||||
|
||||
void Preemphasize(float preemph_coeff, torch::Tensor *wave) {
|
||||
if (preemph_coeff == 0.0f) return;
|
||||
|
||||
KALDIFEAT_ASSERT(preemph_coeff >= 0.0f && preemph_coeff <= 1.0f);
|
||||
|
||||
// right = wave[:, 1:]
|
||||
torch::Tensor right =
|
||||
wave->index({"...", torch::indexing::Slice(1, torch::indexing::None,
|
||||
torch::indexing::None)});
|
||||
|
||||
// current = wave[:, 0:-1]
|
||||
torch::Tensor current = wave->index(
|
||||
{"...", torch::indexing::Slice(0, -1, torch::indexing::None)});
|
||||
|
||||
// wave[1:, :] = wave[:, 1:] - preemph_coeff * wave[:, 0:-1]
|
||||
wave->index({"...", torch::indexing::Slice(1, torch::indexing::None,
|
||||
torch::indexing::None)}) =
|
||||
right - preemph_coeff * current;
|
||||
|
||||
wave->index({"...", 0}) *= 1 - preemph_coeff;
|
||||
}
|
||||
|
||||
} // namespace kaldifeat
|
||||
|
@ -59,7 +59,7 @@ class FeatureWindowFunction {
|
||||
public:
|
||||
FeatureWindowFunction() = default;
|
||||
explicit FeatureWindowFunction(const FrameExtractionOptions &opts);
|
||||
torch::Tensor Apply(const torch::Tensor &input) const;
|
||||
void Apply(torch::Tensor *wave) const;
|
||||
|
||||
private:
|
||||
torch::Tensor window;
|
||||
@ -90,6 +90,8 @@ torch::Tensor GetStrided(const torch::Tensor &wave,
|
||||
|
||||
torch::Tensor Dither(const torch::Tensor &wave, float dither_value);
|
||||
|
||||
// Applies pre-emphasis to `*wave` in place using `preemph_coeff`.
// A coefficient of 0 leaves `*wave` unchanged; any other value is asserted
// to lie in [0, 1].
void Preemphasize(float preemph_coeff, torch::Tensor *wave);
|
||||
|
||||
} // namespace kaldifeat
|
||||
|
||||
#endif // KALDIFEAT_CSRC_FEATURE_WINDOW_H_
|
||||
|
@ -131,9 +131,7 @@ MelBanks::MelBanks(const MelBanksOptions &opts,
|
||||
<< " and vtln-high " << vtln_high << ", versus "
|
||||
<< "low-freq " << low_freq << " and high-freq " << high_freq;
|
||||
|
||||
// TODO(fangjun): remove the last column of the power spectrum
|
||||
// and set the number of columns to num_fft_bins instead of num_fft_bins + 1
|
||||
bins_mat_ = torch::zeros({num_bins, num_fft_bins + 1}, torch::kFloat);
|
||||
bins_mat_ = torch::zeros({num_bins, num_fft_bins}, torch::kFloat);
|
||||
int32_t stride = bins_mat_.strides()[0];
|
||||
|
||||
for (int32_t bin = 0; bin < num_bins; ++bin) {
|
||||
@ -177,11 +175,12 @@ MelBanks::MelBanks(const MelBanksOptions &opts,
|
||||
}
|
||||
|
||||
if (debug_) KALDIFEAT_LOG << bins_mat_;
|
||||
|
||||
bins_mat_.t_();
|
||||
}
|
||||
|
||||
// Applies the mel filter bank to `spectrum` via matrix multiplication.
//
// `bins_mat_` is transposed in place at the end of the constructor, so it is
// used here directly; transposing it again would undo that and make the
// shapes incompatible with `spectrum`.
//
// @param spectrum  2-D tensor (required by torch::mm) whose number of columns
//                  must match the number of rows of `bins_mat_`.
// @return  The matrix product `spectrum * bins_mat_`.
torch::Tensor MelBanks::Compute(const torch::Tensor &spectrum) const {
  return torch::mm(spectrum, bins_mat_);
}
|
||||
|
||||
} // namespace kaldifeat
|
||||
|
@ -13,7 +13,6 @@ PYBIND11_MODULE(_kaldifeat, m) {
|
||||
m.doc() = "Python wrapper for kaldifeat";
|
||||
|
||||
m.def("test", [](const torch::Tensor &tensor) -> torch::Tensor {
|
||||
std::cout << "size: " << tensor.sizes() << "\n";
|
||||
FbankOptions fbank_opts;
|
||||
fbank_opts.frame_opts.dither = 0.0f;
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user