// kaldifeat/csrc/feature-window.cc
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)

// This file is copied/modified from kaldi/src/feat/feature-window.cc

#include "kaldifeat/csrc/feature-window.h"

#include <cmath>
#include <vector>

#ifndef M_2PI
#define M_2PI 6.283185307179586476925286766559005
#endif

namespace kaldifeat {

std::ostream &operator<<(std::ostream &os,
                         const FrameExtractionOptions &opts) {
  os << opts.ToString();
  return os;
}

FeatureWindowFunction::FeatureWindowFunction(const FrameExtractionOptions &opts,
                                             torch::Device device) {
  int32_t frame_length = opts.WindowSize();
  KALDIFEAT_ASSERT(frame_length > 0);

  window = torch::empty({frame_length}, torch::kFloat32);
  float *window_data = window.data_ptr<float>();

  double a = M_2PI / (frame_length - 1);
  for (int32_t i = 0; i < frame_length; i++) {
    double i_fl = static_cast<double>(i);
    if (opts.window_type == "hanning") {
      window_data[i] = 0.5 - 0.5 * cos(a * i_fl);
    } else if (opts.window_type == "sine") {
      // When comparing with the Wikipedia formula, note that
      // 0.5 * a = M_PI / (frame_length - 1).
      window_data[i] = sin(0.5 * a * i_fl);
    } else if (opts.window_type == "hamming") {
      window_data[i] = 0.54 - 0.46 * cos(a * i_fl);
    } else if (opts.window_type ==
               "povey") {  // like hamming but goes to zero at edges.
      window_data[i] = pow(0.5 - 0.5 * cos(a * i_fl), 0.85);
    } else if (opts.window_type == "rectangular") {
      window_data[i] = 1.0;
    } else if (opts.window_type == "blackman") {
      window_data[i] = opts.blackman_coeff - 0.5 * cos(a * i_fl) +
                       (0.5 - opts.blackman_coeff) * cos(2 * a * i_fl);
    } else {
      KALDIFEAT_ERR << "Invalid window type " << opts.window_type;
    }
  }

  window = window.unsqueeze(0);

  if (window.device() != device) {
    window = window.to(device);
  }
}
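
// Note (illustrative, added for clarity; the numbers are assumptions based on
// a common 16 kHz / 25 ms configuration, i.e. frame_length = 400): the loop
// above evaluates each window at the points a * i = 2 * pi * i / 399. The
// "povey" window, for example, is simply the Hann window raised to the power
// 0.85, so it still reaches exactly zero at i = 0 and i = frame_length - 1
// but lies slightly above "hanning" everywhere in between.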

torch::Tensor FeatureWindowFunction::Apply(const torch::Tensor &wave) const {
  KALDIFEAT_ASSERT(wave.dim() == 2);
  KALDIFEAT_ASSERT(wave.size(1) == window.size(1));
  return wave.mul(window);
}
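
// A minimal usage sketch (illustrative only; not part of the original file,
// and the option values are assumptions, e.g. 16 kHz audio with 25 ms frames
// and a 10 ms shift):
//
//   FrameExtractionOptions opts;
//   FeatureWindowFunction window_fn(opts, torch::Device(torch::kCPU));
//   torch::Tensor frames = GetStrided(wave, opts);      // (num_frames, 400)
//   torch::Tensor windowed = window_fn.Apply(frames);   // same shape
//
// Apply() works because `window` was unsqueezed to shape (1, frame_length),
// so the element-wise multiply broadcasts over the frame dimension.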

int64_t FirstSampleOfFrame(int32_t frame, const FrameExtractionOptions &opts) {
  int64_t frame_shift = opts.WindowShift();

  if (opts.snip_edges) {
    return frame * frame_shift;
  } else {
    int64_t midpoint_of_frame = frame_shift * frame + frame_shift / 2,
            beginning_of_frame = midpoint_of_frame - opts.WindowSize() / 2;
    return beginning_of_frame;
  }
}
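
// Worked example (illustrative; assumes 16 kHz audio with 25 ms frames and a
// 10 ms shift, i.e. WindowSize() = 400 and WindowShift() = 160):
//   snip_edges == true : frame f starts at sample f * 160 (frame 0 at 0).
//   snip_edges == false: frame f is centered near sample f * 160 + 80, so
//                        frame 0 starts at 80 - 200 = -120, i.e. before the
//                        first sample; the callers below compensate for such
//                        negative starts by reflection padding.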

int32_t NumFrames(int64_t num_samples, const FrameExtractionOptions &opts,
                  bool flush /*= true*/) {
  int64_t frame_shift = opts.WindowShift();
  int64_t frame_length = opts.WindowSize();

  if (opts.snip_edges) {
    // With --snip-edges=true (the default), we use an HTK-like approach to
    // determining the number of frames -- all frames have to fit completely
    // into the waveform, and the first frame begins at sample zero.
    if (num_samples < frame_length)
      return 0;
    else
      return (1 + ((num_samples - frame_length) / frame_shift));
    // You can understand the expression above as follows: 'num_samples -
    // frame_length' is how much room we have to shift the frame within the
    // waveform; 'frame_shift' is how much we shift it each time; and the
    // ratio is how many times we can shift it (integer arithmetic rounds
    // down).
  } else {
    // If --snip-edges=false, the number of frames is determined by rounding
    // (file-length / frame-shift) to the nearest integer. The point of this
    // formula is to make the number of frames an obvious and predictable
    // function of the frame shift and signal length, which makes many
    // segmentation-related questions simpler.
    //
    // Because integer division in C++ rounds toward zero, we add (half the
    // frame-shift minus epsilon) before dividing, to have the effect of
    // rounding towards the closest integer.
    int32_t num_frames = (num_samples + (frame_shift / 2)) / frame_shift;

    if (flush) return num_frames;

    // note: 'end' always means the last plus one, i.e. one past the last.
    int64_t end_sample_of_last_frame =
        FirstSampleOfFrame(num_frames - 1, opts) + frame_length;

    // The following code is optimized more for clarity than efficiency.
    // If flush == false, we can't output frames that extend past the end
    // of the signal.
    while (num_frames > 0 && end_sample_of_last_frame > num_samples) {
      num_frames--;
      end_sample_of_last_frame -= frame_shift;
    }
    return num_frames;
  }
}
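
// Worked example (illustrative; num_samples = 16000, frame_length = 400 and
// frame_shift = 160 are assumed values, e.g. 1 second of 16 kHz audio):
//   snip_edges == true : 1 + (16000 - 400) / 160 = 1 + 97 = 98 frames.
//   snip_edges == false: (16000 + 80) / 160 = 100 frames (with flush == false
//                        the trailing while loop may then drop frames that
//                        extend past the end of the signal).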

torch::Tensor GetStrided(const torch::Tensor &wave,
                         const FrameExtractionOptions &opts) {
  KALDIFEAT_ASSERT(wave.dim() == 1);

  std::vector<int64_t> strides = {opts.WindowShift() * wave.strides()[0],
                                  wave.strides()[0]};
  int64_t num_samples = wave.size(0);
  int32_t num_frames = NumFrames(num_samples, opts);
  std::vector<int64_t> sizes = {num_frames, opts.WindowSize()};

  if (opts.snip_edges) {
    return wave.as_strided(sizes, strides);
  }

  int32_t frame_length = opts.samp_freq / 1000 * opts.frame_length_ms;
  int32_t frame_shift = opts.samp_freq / 1000 * opts.frame_shift_ms;

  int64_t num_new_samples = (num_frames - 1) * frame_shift + frame_length;
  int32_t num_padding = num_new_samples - num_samples;
  int32_t num_left_padding = (frame_length - frame_shift) / 2;
  int32_t num_right_padding = num_padding - num_left_padding;

  // left_padding = wave[:num_left_padding].flip(dims=(0,))
  torch::Tensor left_padding =
      wave.index({torch::indexing::Slice(0, num_left_padding, 1)}).flip({0});

  // right_padding = wave[-num_right_padding:].flip(dims=(0,))
  torch::Tensor right_padding =
      wave.index({torch::indexing::Slice(-num_right_padding,
                                         torch::indexing::None, 1)})
          .flip({0});

  torch::Tensor new_wave = torch::cat({left_padding, wave, right_padding}, 0);

  return new_wave.as_strided(sizes, strides);
}
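
// Sketch of the result (illustrative; same assumed 400-sample frames with a
// 160-sample shift as above): with snip_edges == true and a contiguous wave,
// as_strided() returns a (num_frames, 400) view whose rows overlap -- row f
// starts at sample f * 160 and no data is copied. With snip_edges == false
// and num_samples = 16000:
//   num_frames        = 100
//   num_new_samples   = 99 * 160 + 400 = 16240
//   num_padding       = 240
//   num_left_padding  = (400 - 160) / 2 = 120
//   num_right_padding = 120
// i.e. 120 reflected samples are prepended and 120 appended before striding.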

torch::Tensor Dither(const torch::Tensor &wave, float dither_value) {
  if (dither_value == 0.0f) return wave;

  torch::Tensor rand_gauss = torch::randn_like(wave);
#if 1
  return wave + rand_gauss * dither_value;
#else
  // An alternative, in-place version; it would require changing `wave`
  // from a const reference to a pointer.
  wave_->add_(rand_gauss, dither_value);
#endif
}
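
// For example (illustrative): with dither_value = 1.0f this adds zero-mean
// Gaussian noise of unit standard deviation to every sample, which helps
// avoid exact zeros in the signal (and hence log-of-zero issues later).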

torch::Tensor Preemphasize(float preemph_coeff, const torch::Tensor &wave) {
  if (preemph_coeff == 0.0f) return wave;

  KALDIFEAT_ASSERT(preemph_coeff >= 0.0f && preemph_coeff <= 1.0f);

  torch::Tensor ans = torch::empty_like(wave);

  using torch::indexing::None;
  using torch::indexing::Slice;

  // right = wave[:, 1:]
  torch::Tensor right = wave.index({"...", Slice(1, None, None)});

  // current = wave[:, 0:-1]
  torch::Tensor current = wave.index({"...", Slice(0, -1, None)});

  // ans[:, 1:] = wave[:, 1:] - preemph_coeff * wave[:, 0:-1]
  ans.index({"...", Slice(1, None, None)}) = right - preemph_coeff * current;

  // ans[:, 0] = wave[:, 0] * (1 - preemph_coeff)
  ans.index({"...", 0}) = wave.index({"...", 0}) * (1 - preemph_coeff);

  return ans;
}
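
// In formula form (illustrative): for each row x of `wave`,
//   ans[n] = x[n] - preemph_coeff * x[n - 1]   for n >= 1
//   ans[0] = x[0] * (1 - preemph_coeff)
// e.g. with preemph_coeff = 0.97 and x = {1, 1, 1}, ans = {0.03, 0.03, 0.03}.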

torch::Tensor ExtractWindow(int64_t sample_offset, const torch::Tensor &wave,
                            int32_t f, const FrameExtractionOptions &opts) {
  KALDIFEAT_ASSERT(sample_offset >= 0 && wave.numel() != 0);

  int32_t frame_length = opts.WindowSize();
  int64_t num_samples = sample_offset + wave.numel();
  int64_t start_sample = FirstSampleOfFrame(f, opts);
  int64_t end_sample = start_sample + frame_length;

  if (opts.snip_edges) {
    KALDIFEAT_ASSERT(start_sample >= sample_offset &&
                     end_sample <= num_samples);
  } else {
    KALDIFEAT_ASSERT(sample_offset == 0 || start_sample >= sample_offset);
  }

  // wave_start and wave_end are start and end indexes into 'wave', for the
  // piece of wave that we're trying to extract.
  int32_t wave_start = static_cast<int32_t>(start_sample - sample_offset);
  int32_t wave_end = wave_start + frame_length;

  if (wave_start >= 0 && wave_end <= wave.numel()) {
    // the normal case -- no edge effects to consider.
    // return wave[wave_start:wave_end]
    return wave.index({torch::indexing::Slice(wave_start, wave_end)});
  } else {
    torch::Tensor window = torch::empty({frame_length}, torch::kFloat);
    auto p_window = window.accessor<float, 1>();
    auto p_wave = wave.accessor<float, 1>();

    // Deal with any end effects by reflection, if needed. This code will only
    // be reached for about two frames per utterance, so we don't concern
    // ourselves excessively with efficiency.
    int32_t wave_dim = wave.numel();
    for (int32_t s = 0; s != frame_length; ++s) {
      int32_t s_in_wave = s + wave_start;
      while (s_in_wave < 0 || s_in_wave >= wave_dim) {
        // Reflect around the beginning or end of the wave,
        // e.g. -1 -> 0, -2 -> 1, dim -> dim - 1, dim + 1 -> dim - 2.
        // The code supports repeated reflections, although this
        // would only be needed in pathological cases.
        if (s_in_wave < 0) {
          s_in_wave = -s_in_wave - 1;
        } else {
          s_in_wave = 2 * wave_dim - 1 - s_in_wave;
        }
      }
      p_window[s] = p_wave[s_in_wave];
    }
    return window;
  }
}
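
// Reflection example (illustrative; assumes the 400/160 setup with
// snip_edges == false): frame 0 starts at sample -120, so wave_start = -120
// and window samples s = 0..119 map to s_in_wave = 119 - s (e.g. the virtual
// index -120 reflects to 119 and -1 reflects to 0), mirroring the first 120
// samples of the wave. The analogous mirroring happens past the end of the
// signal for the last frame or two.
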
} // namespace kaldifeat