diff --git a/kaldifeat/csrc/feature-common-inl.h b/kaldifeat/csrc/feature-common-inl.h index e5fa801..03566bb 100644 --- a/kaldifeat/csrc/feature-common-inl.h +++ b/kaldifeat/csrc/feature-common-inl.h @@ -42,30 +42,10 @@ torch::Tensor OfflineFeatureTpl::ComputeFeatures(const torch::Tensor &wave, torch::clamp_min(strided_input.pow(2).sum(1), kEps).log(); } - if (frame_opts.preemph_coeff != 0.0f) { - KALDIFEAT_ASSERT(frame_opts.preemph_coeff >= 0.0f && - frame_opts.preemph_coeff <= 1.0f); + if (frame_opts.preemph_coeff != 0.0f) + Preemphasize(frame_opts.preemph_coeff, &strided_input); - // right = strided_input[:, 1:] - torch::Tensor right = strided_input.index( - {"...", torch::indexing::Slice(1, torch::indexing::None, - torch::indexing::None)}); - - // current = strided_input[:, 0:-1] - torch::Tensor current = - strided_input.index({"...", torch::indexing::Slice(0, -1, 1)}); - - // strided_input[1:, :] = - // strided_input[:, 1:] - preemph_coeff * strided_input[:, 0:-1] - strided_input.index( - {"...", torch::indexing::Slice(1, torch::indexing::None, - torch::indexing::None)}) = - right - frame_opts.preemph_coeff * current; - - strided_input.index({"...", 0}) *= 1 - frame_opts.preemph_coeff; - } - - strided_input = feature_window_function_.Apply(strided_input); + feature_window_function_.Apply(&strided_input); int32_t padding = frame_opts.PaddedWindowSize() - strided_input.sizes()[1]; diff --git a/kaldifeat/csrc/feature-fbank.cc b/kaldifeat/csrc/feature-fbank.cc index 6569eba..b21083c 100644 --- a/kaldifeat/csrc/feature-fbank.cc +++ b/kaldifeat/csrc/feature-fbank.cc @@ -71,8 +71,12 @@ torch::Tensor FbankComputer::Compute(torch::Tensor signal_raw_log_energy, // note spectrum is in magnitude, not power, because of `abs()` torch::Tensor spectrum = torch::fft::rfft(signal_frame).abs(); + // remove the last column, i.e., the highest fft bin + spectrum = spectrum.index( + {"...", torch::indexing::Slice(0, -1, torch::indexing::None)}); + // Use power instead of magnitude 
if requested. - if (opts_.use_power) spectrum = spectrum.pow(2); + if (opts_.use_power) spectrum.pow_(2); #if 0 int32_t mel_offset = ((opts_.use_energy && !opts_.htk_compat) ? 1 : 0); diff --git a/kaldifeat/csrc/feature-window.cc b/kaldifeat/csrc/feature-window.cc index 00ace47..4a511c5 100644 --- a/kaldifeat/csrc/feature-window.cc +++ b/kaldifeat/csrc/feature-window.cc @@ -51,10 +51,12 @@ FeatureWindowFunction::FeatureWindowFunction( window = window.unsqueeze(0); } -torch::Tensor FeatureWindowFunction::Apply(const torch::Tensor &input) const { - KALDIFEAT_ASSERT(input.dim() == 2); - KALDIFEAT_ASSERT(input.sizes()[1] == window.sizes()[1]); - return input * window; +// Multiplies `*wave` in place by the window. `*wave` must be 2-D and its +// second dimension must match that of the window. +void FeatureWindowFunction::Apply(torch::Tensor *wave) const { + KALDIFEAT_ASSERT(wave->dim() == 2); + KALDIFEAT_ASSERT(wave->sizes()[1] == window.sizes()[1]); + wave->mul_(window); } static int64_t FirstSampleOfFrame(int32_t frame, @@ -137,4 +139,28 @@ torch::Tensor Dither(const torch::Tensor &wave, float dither_value) { return wave + rand_gauss * dither_value; } +// Pre-emphasizes `*wave` in place along its last dimension. It is a no-op +// when preemph_coeff is 0; otherwise preemph_coeff must be in [0, 1]. +void Preemphasize(float preemph_coeff, torch::Tensor *wave) { + if (preemph_coeff == 0.0f) return; + + KALDIFEAT_ASSERT(preemph_coeff >= 0.0f && preemph_coeff <= 1.0f); + + // right = wave[:, 1:] + torch::Tensor right = + wave->index({"...", torch::indexing::Slice(1, torch::indexing::None, + torch::indexing::None)}); + + // current = wave[:, 0:-1] + torch::Tensor current = wave->index( + {"...", torch::indexing::Slice(0, -1, torch::indexing::None)}); + + // wave[:, 1:] = wave[:, 1:] - preemph_coeff * wave[:, 0:-1] + wave->index({"...", torch::indexing::Slice(1, torch::indexing::None, + torch::indexing::None)}) = + right - preemph_coeff * current; + + wave->index({"...", 0}) *= 1 - preemph_coeff; +} + } // namespace kaldifeat diff --git a/kaldifeat/csrc/feature-window.h b/kaldifeat/csrc/feature-window.h index 586b43c..a4a3ed8 100644 --- a/kaldifeat/csrc/feature-window.h +++ b/kaldifeat/csrc/feature-window.h @@ -59,7 +59,7 @@ class 
FeatureWindowFunction { public: FeatureWindowFunction() = default; explicit FeatureWindowFunction(const FrameExtractionOptions &opts); - torch::Tensor Apply(const torch::Tensor &input) const; + void Apply(torch::Tensor *wave) const; private: torch::Tensor window; @@ -90,6 +90,8 @@ torch::Tensor GetStrided(const torch::Tensor &wave, torch::Tensor Dither(const torch::Tensor &wave, float dither_value); +void Preemphasize(float preemph_coeff, torch::Tensor *wave); + } // namespace kaldifeat #endif // KALDIFEAT_CSRC_FEATURE_WINDOW_H_ diff --git a/kaldifeat/csrc/mel-computations.cc b/kaldifeat/csrc/mel-computations.cc index 557811e..1c711a8 100644 --- a/kaldifeat/csrc/mel-computations.cc +++ b/kaldifeat/csrc/mel-computations.cc @@ -131,9 +131,7 @@ MelBanks::MelBanks(const MelBanksOptions &opts, << " and vtln-high " << vtln_high << ", versus " << "low-freq " << low_freq << " and high-freq " << high_freq; - // TODO(fangjun): remove the last column of the power spectrum - // and set the number of columns to num_fft_bins instead of num_fft_bins + 1 - bins_mat_ = torch::zeros({num_bins, num_fft_bins + 1}, torch::kFloat); + bins_mat_ = torch::zeros({num_bins, num_fft_bins}, torch::kFloat); int32_t stride = bins_mat_.strides()[0]; for (int32_t bin = 0; bin < num_bins; ++bin) { @@ -177,11 +175,12 @@ MelBanks::MelBanks(const MelBanksOptions &opts, } if (debug_) KALDIFEAT_LOG << bins_mat_; + + bins_mat_.t_(); } torch::Tensor MelBanks::Compute(const torch::Tensor &spectrum) const { - // TODO(fangjun): save a transposed version of `bins_mat_`. 
- return torch::mm(spectrum, bins_mat_.t()); + return torch::mm(spectrum, bins_mat_); } } // namespace kaldifeat diff --git a/kaldifeat/python/csrc/kaldifeat.cc b/kaldifeat/python/csrc/kaldifeat.cc index 2ebc982..b7c8ac9 100644 --- a/kaldifeat/python/csrc/kaldifeat.cc +++ b/kaldifeat/python/csrc/kaldifeat.cc @@ -13,7 +13,6 @@ PYBIND11_MODULE(_kaldifeat, m) { m.doc() = "Python wrapper for kaldifeat"; m.def("test", [](const torch::Tensor &tensor) -> torch::Tensor { - std::cout << "size: " << tensor.sizes() << "\n"; FbankOptions fbank_opts; fbank_opts.frame_opts.dither = 0.0f;