Refactoring.

This commit is contained in:
Fangjun Kuang 2021-02-27 00:01:00 +08:00
parent e930dc176f
commit ded479fc10
6 changed files with 41 additions and 35 deletions

View File

@ -42,30 +42,10 @@ torch::Tensor OfflineFeatureTpl<F>::ComputeFeatures(const torch::Tensor &wave,
torch::clamp_min(strided_input.pow(2).sum(1), kEps).log(); torch::clamp_min(strided_input.pow(2).sum(1), kEps).log();
} }
if (frame_opts.preemph_coeff != 0.0f) { if (frame_opts.preemph_coeff != 0.0f)
KALDIFEAT_ASSERT(frame_opts.preemph_coeff >= 0.0f && Preemphasize(frame_opts.preemph_coeff, &strided_input);
frame_opts.preemph_coeff <= 1.0f);
// right = strided_input[:, 1:] feature_window_function_.Apply(&strided_input);
torch::Tensor right = strided_input.index(
{"...", torch::indexing::Slice(1, torch::indexing::None,
torch::indexing::None)});
// current = strided_input[:, 0:-1]
torch::Tensor current =
strided_input.index({"...", torch::indexing::Slice(0, -1, 1)});
// strided_input[1:, :] =
// strided_input[:, 1:] - preemph_coeff * strided_input[:, 0:-1]
strided_input.index(
{"...", torch::indexing::Slice(1, torch::indexing::None,
torch::indexing::None)}) =
right - frame_opts.preemph_coeff * current;
strided_input.index({"...", 0}) *= 1 - frame_opts.preemph_coeff;
}
strided_input = feature_window_function_.Apply(strided_input);
int32_t padding = frame_opts.PaddedWindowSize() - strided_input.sizes()[1]; int32_t padding = frame_opts.PaddedWindowSize() - strided_input.sizes()[1];

View File

@ -71,8 +71,12 @@ torch::Tensor FbankComputer::Compute(torch::Tensor signal_raw_log_energy,
// note spectrum is in magnitude, not power, because of `abs()` // note spectrum is in magnitude, not power, because of `abs()`
torch::Tensor spectrum = torch::fft::rfft(signal_frame).abs(); torch::Tensor spectrum = torch::fft::rfft(signal_frame).abs();
// remove the last column, i.e., the highest fft bin
spectrum = spectrum.index(
{"...", torch::indexing::Slice(0, -1, torch::indexing::None)});
// Use power instead of magnitude if requested. // Use power instead of magnitude if requested.
if (opts_.use_power) spectrum = spectrum.pow(2); if (opts_.use_power) spectrum.pow_(2);
#if 0 #if 0
int32_t mel_offset = ((opts_.use_energy && !opts_.htk_compat) ? 1 : 0); int32_t mel_offset = ((opts_.use_energy && !opts_.htk_compat) ? 1 : 0);

View File

@ -51,10 +51,10 @@ FeatureWindowFunction::FeatureWindowFunction(
window = window.unsqueeze(0); window = window.unsqueeze(0);
} }
torch::Tensor FeatureWindowFunction::Apply(const torch::Tensor &input) const { void FeatureWindowFunction::Apply(torch::Tensor *wave) const {
KALDIFEAT_ASSERT(input.dim() == 2); KALDIFEAT_ASSERT(wave->dim() == 2);
KALDIFEAT_ASSERT(input.sizes()[1] == window.sizes()[1]); KALDIFEAT_ASSERT(wave->sizes()[1] == window.sizes()[1]);
return input * window; wave->mul_(window);
} }
static int64_t FirstSampleOfFrame(int32_t frame, static int64_t FirstSampleOfFrame(int32_t frame,
@ -137,4 +137,26 @@ torch::Tensor Dither(const torch::Tensor &wave, float dither_value) {
return wave + rand_gauss * dither_value; return wave + rand_gauss * dither_value;
} }
void Preemphasize(float preemph_coeff, torch::Tensor *wave) {
if (preemph_coeff == 0.0f) return;
KALDIFEAT_ASSERT(preemph_coeff >= 0.0f && preemph_coeff <= 1.0f);
// right = wave[:, 1:]
torch::Tensor right =
wave->index({"...", torch::indexing::Slice(1, torch::indexing::None,
torch::indexing::None)});
// current = wave[:, 0:-1]
torch::Tensor current = wave->index(
{"...", torch::indexing::Slice(0, -1, torch::indexing::None)});
// wave[1:, :] = wave[:, 1:] - preemph_coeff * wave[:, 0:-1]
wave->index({"...", torch::indexing::Slice(1, torch::indexing::None,
torch::indexing::None)}) =
right - preemph_coeff * current;
wave->index({"...", 0}) *= 1 - preemph_coeff;
}
} // namespace kaldifeat } // namespace kaldifeat

View File

@ -59,7 +59,7 @@ class FeatureWindowFunction {
public: public:
FeatureWindowFunction() = default; FeatureWindowFunction() = default;
explicit FeatureWindowFunction(const FrameExtractionOptions &opts); explicit FeatureWindowFunction(const FrameExtractionOptions &opts);
torch::Tensor Apply(const torch::Tensor &input) const; void Apply(torch::Tensor *wave) const;
private: private:
torch::Tensor window; torch::Tensor window;
@ -90,6 +90,8 @@ torch::Tensor GetStrided(const torch::Tensor &wave,
torch::Tensor Dither(const torch::Tensor &wave, float dither_value); torch::Tensor Dither(const torch::Tensor &wave, float dither_value);
void Preemphasize(float preemph_coeff, torch::Tensor *wave);
} // namespace kaldifeat } // namespace kaldifeat
#endif // KALDIFEAT_CSRC_FEATURE_WINDOW_H_ #endif // KALDIFEAT_CSRC_FEATURE_WINDOW_H_

View File

@ -131,9 +131,7 @@ MelBanks::MelBanks(const MelBanksOptions &opts,
<< " and vtln-high " << vtln_high << ", versus " << " and vtln-high " << vtln_high << ", versus "
<< "low-freq " << low_freq << " and high-freq " << high_freq; << "low-freq " << low_freq << " and high-freq " << high_freq;
// TODO(fangjun): remove the last column of the power spectrum bins_mat_ = torch::zeros({num_bins, num_fft_bins}, torch::kFloat);
// and set the number of columns to num_fft_bins instead of num_fft_bins + 1
bins_mat_ = torch::zeros({num_bins, num_fft_bins + 1}, torch::kFloat);
int32_t stride = bins_mat_.strides()[0]; int32_t stride = bins_mat_.strides()[0];
for (int32_t bin = 0; bin < num_bins; ++bin) { for (int32_t bin = 0; bin < num_bins; ++bin) {
@ -177,11 +175,12 @@ MelBanks::MelBanks(const MelBanksOptions &opts,
} }
if (debug_) KALDIFEAT_LOG << bins_mat_; if (debug_) KALDIFEAT_LOG << bins_mat_;
bins_mat_.t_();
} }
torch::Tensor MelBanks::Compute(const torch::Tensor &spectrum) const { torch::Tensor MelBanks::Compute(const torch::Tensor &spectrum) const {
// TODO(fangjun): save a transposed version of `bins_mat_`. return torch::mm(spectrum, bins_mat_);
return torch::mm(spectrum, bins_mat_.t());
} }
} // namespace kaldifeat } // namespace kaldifeat

View File

@ -13,7 +13,6 @@ PYBIND11_MODULE(_kaldifeat, m) {
m.doc() = "Python wrapper for kaldifeat"; m.doc() = "Python wrapper for kaldifeat";
m.def("test", [](const torch::Tensor &tensor) -> torch::Tensor { m.def("test", [](const torch::Tensor &tensor) -> torch::Tensor {
std::cout << "size: " << tensor.sizes() << "\n";
FbankOptions fbank_opts; FbankOptions fbank_opts;
fbank_opts.frame_opts.dither = 0.0f; fbank_opts.frame_opts.dither = 0.0f;