mirror of
https://github.com/csukuangfj/kaldifeat.git
synced 2025-08-10 02:22:16 +00:00
add test for the default parameters.
This commit is contained in:
parent
917df77946
commit
1ccc03a781
@ -20,15 +20,14 @@ torch::Tensor OfflineFeatureTpl<F>::ComputeFeatures(const torch::Tensor &wave,
|
||||
|
||||
const FrameExtractionOptions &frame_opts = computer_.GetFrameOptions();
|
||||
|
||||
// TODO(fangjun): avoid clone
|
||||
torch::Tensor strided_input = GetStrided(wave, frame_opts).clone();
|
||||
torch::Tensor strided_input = GetStrided(wave, frame_opts);
|
||||
|
||||
if (frame_opts.dither != 0)
|
||||
strided_input = Dither(strided_input, frame_opts.dither);
|
||||
|
||||
if (frame_opts.remove_dc_offset) {
|
||||
torch::Tensor row_means = strided_input.mean(1).unsqueeze(1);
|
||||
strided_input -= row_means;
|
||||
strided_input = strided_input - row_means;
|
||||
}
|
||||
|
||||
bool use_raw_log_energy = computer_.NeedRawLogEnergy();
|
||||
|
@ -36,7 +36,34 @@ static void TestPad() {
|
||||
std::cout << b << "\n";
|
||||
}
|
||||
|
||||
static void TestGetStrided() {
|
||||
// 0 1 2 3 4 5
|
||||
//
|
||||
//
|
||||
// 0 1 2 3
|
||||
// 2 3 4 5
|
||||
|
||||
torch::Tensor a = torch::arange(0, 6).to(torch::kFloat);
|
||||
torch::Tensor b = a.as_strided({2, 4}, {2, 1});
|
||||
// b = b.clone();
|
||||
std::cout << a << "\n";
|
||||
std::cout << b << "\n";
|
||||
std::cout << b.mean(1).unsqueeze(1) << "\n";
|
||||
b = b - b.mean(1).unsqueeze(1);
|
||||
std::cout << a << "\n";
|
||||
std::cout << b << "\n";
|
||||
}
|
||||
|
||||
static void TestDither() {
|
||||
torch::Tensor a = torch::arange(0, 6).reshape({2, 3}).to(torch::kFloat);
|
||||
torch::Tensor b = torch::arange(0, 6).reshape({2, 3}).to(torch::kFloat) * 0.1;
|
||||
std::cout << a << "\n";
|
||||
std::cout << b << "\n";
|
||||
std::cout << (a + b * 2) << "\n";
|
||||
}
|
||||
|
||||
int main() {
|
||||
TestPad();
|
||||
// TestDither();
|
||||
TestGetStrided();
|
||||
return 0;
|
||||
}
|
||||
|
@ -14,28 +14,29 @@ namespace kaldifeat {
|
||||
PYBIND11_MODULE(_kaldifeat, m) {
|
||||
m.doc() = "Python wrapper for kaldifeat";
|
||||
|
||||
m.def("test", [](const torch::Tensor &tensor) -> torch::Tensor {
|
||||
FbankOptions fbank_opts;
|
||||
fbank_opts.frame_opts.dither = 0.0f;
|
||||
// It verifies that the reimplementation produces the same output
|
||||
// as kaldi using default paremters with dither disabled.
|
||||
m.def("test_default_parameters",
|
||||
[](const torch::Tensor &tensor) -> std::pair<torch::Tensor, double> {
|
||||
FbankOptions fbank_opts;
|
||||
fbank_opts.frame_opts.dither = 0.0f;
|
||||
|
||||
Fbank fbank(fbank_opts);
|
||||
float vtln_warp = 1.0f;
|
||||
Fbank fbank(fbank_opts);
|
||||
float vtln_warp = 1.0f;
|
||||
|
||||
std::chrono::steady_clock::time_point begin =
|
||||
std::chrono::steady_clock::now();
|
||||
std::chrono::steady_clock::time_point begin =
|
||||
std::chrono::steady_clock::now();
|
||||
|
||||
torch::Tensor ans = fbank.ComputeFeatures(tensor, vtln_warp);
|
||||
std::chrono::steady_clock::time_point end =
|
||||
std::chrono::steady_clock::now();
|
||||
std::cout << "Time difference = "
|
||||
<< std::chrono::duration_cast<std::chrono::microseconds>(end -
|
||||
begin)
|
||||
.count() /
|
||||
1000000.
|
||||
<< "[s]" << std::endl;
|
||||
torch::Tensor ans = fbank.ComputeFeatures(tensor, vtln_warp);
|
||||
std::chrono::steady_clock::time_point end =
|
||||
std::chrono::steady_clock::now();
|
||||
double elapsed_seconds =
|
||||
std::chrono::duration_cast<std::chrono::microseconds>(end - begin)
|
||||
.count() /
|
||||
1000000.;
|
||||
|
||||
return ans;
|
||||
});
|
||||
return std::make_pair(ans, elapsed_seconds);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace kaldifeat
|
||||
|
@ -1,31 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, '/root/fangjun/open-source/kaldifeat/build/lib')
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import _kaldifeat
|
||||
|
||||
|
||||
def main():
|
||||
# sox -n -r 16000 -b 16 abc.wav synth 1 sine 100
|
||||
filename = '/root/fangjun/open-source/kaldi/src/featbin/abc.wav'
|
||||
with sf.SoundFile(filename) as sf_desc:
|
||||
sampling_rate = sf_desc.samplerate
|
||||
assert sampling_rate == 16000
|
||||
a = sf_desc.read(dtype=np.float32, always_2d=False)
|
||||
a *= 32768
|
||||
tensor = torch.from_numpy(a)
|
||||
ans = _kaldifeat.test(tensor)
|
||||
# torch.set_printoptions(profile="full")
|
||||
# print(ans.shape)
|
||||
# print(ans)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
23
kaldifeat/python/tests/test_data/README.md
Normal file
23
kaldifeat/python/tests/test_data/README.md
Normal file
@ -0,0 +1,23 @@
|
||||
|
||||
# File descriptions
|
||||
|
||||
## abc.wav
|
||||
|
||||
It is generated by
|
||||
|
||||
```bash
|
||||
sox -n -r 16000 -b 16 abc.wav synth 320 sine 300
|
||||
```
|
||||
|
||||
## abc.scp and abc.txt
|
||||
|
||||
It is generated by
|
||||
|
||||
```
|
||||
echo "1 abc.wav" > abc.scp
|
||||
kaldi/src/featbin$ ./compute-fbank-feats --dither=0 scp:abc.scp ark,t:abc.txt
|
||||
```
|
||||
|
||||
## abc.diff
|
||||
|
||||
It's the change that measures only the feature computation time, excluding the I/O time.
|
30
kaldifeat/python/tests/test_data/abc.diff
Normal file
30
kaldifeat/python/tests/test_data/abc.diff
Normal file
@ -0,0 +1,30 @@
|
||||
diff --git a/src/featbin/compute-fbank-feats.cc b/src/featbin/compute-fbank-feats.cc
|
||||
index e52b30baf..63735c985 100644
|
||||
--- a/src/featbin/compute-fbank-feats.cc
|
||||
+++ b/src/featbin/compute-fbank-feats.cc
|
||||
@@ -18,6 +18,8 @@
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
+#include <chrono>
|
||||
+
|
||||
#include "base/kaldi-common.h"
|
||||
#include "feat/feature-fbank.h"
|
||||
#include "feat/wave-reader.h"
|
||||
@@ -145,8 +147,16 @@ int main(int argc, char *argv[]) {
|
||||
SubVector<BaseFloat> waveform(wave_data.Data(), this_chan);
|
||||
Matrix<BaseFloat> features;
|
||||
try {
|
||||
+ std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now();
|
||||
fbank.ComputeFeatures(waveform, wave_data.SampFreq(),
|
||||
vtln_warp_local, &features);
|
||||
+ std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();
|
||||
+ std::cout << "Time difference = "
|
||||
+ << std::chrono::duration_cast<std::chrono::microseconds>(
|
||||
+ end - begin)
|
||||
+ .count() /
|
||||
+ 1000000.
|
||||
+ << "[s]" << std::endl;
|
||||
} catch (...) {
|
||||
KALDI_WARN << "Failed to compute features for utterance " << utt;
|
||||
continue;
|
1
kaldifeat/python/tests/test_data/abc.scp
Normal file
1
kaldifeat/python/tests/test_data/abc.scp
Normal file
@ -0,0 +1 @@
|
||||
1 abc.wav
|
31999
kaldifeat/python/tests/test_data/abc.txt
Normal file
31999
kaldifeat/python/tests/test_data/abc.txt
Normal file
File diff suppressed because it is too large
Load Diff
BIN
kaldifeat/python/tests/test_data/abc.wav
Normal file
BIN
kaldifeat/python/tests/test_data/abc.wav
Normal file
Binary file not shown.
16
kaldifeat/python/tests/test_data/run.sh
Executable file
16
kaldifeat/python/tests/test_data/run.sh
Executable file
@ -0,0 +1,16 @@
|
||||
#!/usr/bin/env bash
|
||||
#
|
||||
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
|
||||
KALDI_ROOT=/root/fangjun/open-source/kaldi
|
||||
export PATH=${KALDI_ROOT}/src/featbin:$PATH
|
||||
|
||||
sox -n -r 16000 -b 16 abc.wav synth 320 sine 300
|
||||
echo "1 abc.wav" > abc.scp
|
||||
compute-fbank-feats --dither=0 scp:abc.scp ark,t:abc.txt
|
||||
|
||||
# Output
|
||||
#
|
||||
# compute-fbank-feats --dither=0 scp:abc.scp ark,t:abc.txt
|
||||
# Time difference = 0.304916[s]
|
||||
# LOG (compute-fbank-feats[5.5.880~4-3e446]:main():compute-fbank-feats.cc:195) Done 1 out of 1 utterances.
|
48
kaldifeat/python/tests/test_default_parameters.py
Executable file
48
kaldifeat/python/tests/test_default_parameters.py
Executable file
@ -0,0 +1,48 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
|
||||
from pathlib import Path
|
||||
cur_dir = Path(__file__).resolve().parent
|
||||
kaldi_feat_dir = cur_dir.parent.parent.parent
|
||||
|
||||
import sys
|
||||
sys.path.insert(0, f'{kaldi_feat_dir}/build/lib')
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import _kaldifeat
|
||||
|
||||
|
||||
def read_ark_txt() -> torch.Tensor:
|
||||
test_data_dir = cur_dir / 'test_data'
|
||||
filename = test_data_dir / 'abc.txt'
|
||||
features = []
|
||||
with open(filename) as f:
|
||||
for line in f:
|
||||
if '[' in line: continue
|
||||
line = line.strip('').split()
|
||||
data = [float(d) for d in line if d != ']']
|
||||
features.append(data)
|
||||
ans = torch.tensor(features)
|
||||
return ans
|
||||
|
||||
|
||||
def main():
|
||||
test_data_dir = cur_dir / 'test_data'
|
||||
filename = test_data_dir / 'abc.wav'
|
||||
with sf.SoundFile(filename) as sf_desc:
|
||||
sampling_rate = sf_desc.samplerate
|
||||
assert sampling_rate == 16000
|
||||
data = sf_desc.read(dtype=np.float32, always_2d=False)
|
||||
data *= 32768
|
||||
tensor = torch.from_numpy(data)
|
||||
ans, elapsed_seconds = _kaldifeat.test_default_parameters(tensor)
|
||||
expected = read_ark_txt()
|
||||
assert torch.allclose(ans, expected, rtol=1e-3)
|
||||
print('elapsed seconds:', elapsed_seconds)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Loading…
x
Reference in New Issue
Block a user