mirror of
https://github.com/csukuangfj/kaldifeat.git
synced 2025-08-12 11:32:17 +00:00
add test for the default parameters.
This commit is contained in:
parent
917df77946
commit
1ccc03a781
@ -20,15 +20,14 @@ torch::Tensor OfflineFeatureTpl<F>::ComputeFeatures(const torch::Tensor &wave,
|
|||||||
|
|
||||||
const FrameExtractionOptions &frame_opts = computer_.GetFrameOptions();
|
const FrameExtractionOptions &frame_opts = computer_.GetFrameOptions();
|
||||||
|
|
||||||
// TODO(fangjun): avoid clone
|
torch::Tensor strided_input = GetStrided(wave, frame_opts);
|
||||||
torch::Tensor strided_input = GetStrided(wave, frame_opts).clone();
|
|
||||||
|
|
||||||
if (frame_opts.dither != 0)
|
if (frame_opts.dither != 0)
|
||||||
strided_input = Dither(strided_input, frame_opts.dither);
|
strided_input = Dither(strided_input, frame_opts.dither);
|
||||||
|
|
||||||
if (frame_opts.remove_dc_offset) {
|
if (frame_opts.remove_dc_offset) {
|
||||||
torch::Tensor row_means = strided_input.mean(1).unsqueeze(1);
|
torch::Tensor row_means = strided_input.mean(1).unsqueeze(1);
|
||||||
strided_input -= row_means;
|
strided_input = strided_input - row_means;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool use_raw_log_energy = computer_.NeedRawLogEnergy();
|
bool use_raw_log_energy = computer_.NeedRawLogEnergy();
|
||||||
|
@ -36,7 +36,34 @@ static void TestPad() {
|
|||||||
std::cout << b << "\n";
|
std::cout << b << "\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void TestGetStrided() {
|
||||||
|
// 0 1 2 3 4 5
|
||||||
|
//
|
||||||
|
//
|
||||||
|
// 0 1 2 3
|
||||||
|
// 2 3 4 5
|
||||||
|
|
||||||
|
torch::Tensor a = torch::arange(0, 6).to(torch::kFloat);
|
||||||
|
torch::Tensor b = a.as_strided({2, 4}, {2, 1});
|
||||||
|
// b = b.clone();
|
||||||
|
std::cout << a << "\n";
|
||||||
|
std::cout << b << "\n";
|
||||||
|
std::cout << b.mean(1).unsqueeze(1) << "\n";
|
||||||
|
b = b - b.mean(1).unsqueeze(1);
|
||||||
|
std::cout << a << "\n";
|
||||||
|
std::cout << b << "\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
static void TestDither() {
|
||||||
|
torch::Tensor a = torch::arange(0, 6).reshape({2, 3}).to(torch::kFloat);
|
||||||
|
torch::Tensor b = torch::arange(0, 6).reshape({2, 3}).to(torch::kFloat) * 0.1;
|
||||||
|
std::cout << a << "\n";
|
||||||
|
std::cout << b << "\n";
|
||||||
|
std::cout << (a + b * 2) << "\n";
|
||||||
|
}
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
TestPad();
|
// TestDither();
|
||||||
|
TestGetStrided();
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
@ -14,28 +14,29 @@ namespace kaldifeat {
|
|||||||
PYBIND11_MODULE(_kaldifeat, m) {
|
PYBIND11_MODULE(_kaldifeat, m) {
|
||||||
m.doc() = "Python wrapper for kaldifeat";
|
m.doc() = "Python wrapper for kaldifeat";
|
||||||
|
|
||||||
m.def("test", [](const torch::Tensor &tensor) -> torch::Tensor {
|
// It verifies that the reimplementation produces the same output
|
||||||
FbankOptions fbank_opts;
|
// as kaldi using default paremters with dither disabled.
|
||||||
fbank_opts.frame_opts.dither = 0.0f;
|
m.def("test_default_parameters",
|
||||||
|
[](const torch::Tensor &tensor) -> std::pair<torch::Tensor, double> {
|
||||||
|
FbankOptions fbank_opts;
|
||||||
|
fbank_opts.frame_opts.dither = 0.0f;
|
||||||
|
|
||||||
Fbank fbank(fbank_opts);
|
Fbank fbank(fbank_opts);
|
||||||
float vtln_warp = 1.0f;
|
float vtln_warp = 1.0f;
|
||||||
|
|
||||||
std::chrono::steady_clock::time_point begin =
|
std::chrono::steady_clock::time_point begin =
|
||||||
std::chrono::steady_clock::now();
|
std::chrono::steady_clock::now();
|
||||||
|
|
||||||
torch::Tensor ans = fbank.ComputeFeatures(tensor, vtln_warp);
|
torch::Tensor ans = fbank.ComputeFeatures(tensor, vtln_warp);
|
||||||
std::chrono::steady_clock::time_point end =
|
std::chrono::steady_clock::time_point end =
|
||||||
std::chrono::steady_clock::now();
|
std::chrono::steady_clock::now();
|
||||||
std::cout << "Time difference = "
|
double elapsed_seconds =
|
||||||
<< std::chrono::duration_cast<std::chrono::microseconds>(end -
|
std::chrono::duration_cast<std::chrono::microseconds>(end - begin)
|
||||||
begin)
|
.count() /
|
||||||
.count() /
|
1000000.;
|
||||||
1000000.
|
|
||||||
<< "[s]" << std::endl;
|
|
||||||
|
|
||||||
return ans;
|
return std::make_pair(ans, elapsed_seconds);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace kaldifeat
|
} // namespace kaldifeat
|
||||||
|
@ -1,31 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
#
|
|
||||||
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
|
||||||
|
|
||||||
import sys
|
|
||||||
|
|
||||||
sys.path.insert(0, '/root/fangjun/open-source/kaldifeat/build/lib')
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
import soundfile as sf
|
|
||||||
import _kaldifeat
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
# sox -n -r 16000 -b 16 abc.wav synth 1 sine 100
|
|
||||||
filename = '/root/fangjun/open-source/kaldi/src/featbin/abc.wav'
|
|
||||||
with sf.SoundFile(filename) as sf_desc:
|
|
||||||
sampling_rate = sf_desc.samplerate
|
|
||||||
assert sampling_rate == 16000
|
|
||||||
a = sf_desc.read(dtype=np.float32, always_2d=False)
|
|
||||||
a *= 32768
|
|
||||||
tensor = torch.from_numpy(a)
|
|
||||||
ans = _kaldifeat.test(tensor)
|
|
||||||
# torch.set_printoptions(profile="full")
|
|
||||||
# print(ans.shape)
|
|
||||||
# print(ans)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
23
kaldifeat/python/tests/test_data/README.md
Normal file
23
kaldifeat/python/tests/test_data/README.md
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
|
||||||
|
# File descriptions
|
||||||
|
|
||||||
|
## abc.wav
|
||||||
|
|
||||||
|
It is generated by
|
||||||
|
|
||||||
|
```bash
|
||||||
|
sox -n -r 16000 -b 16 abc.wav synth 320 sine 300
|
||||||
|
```
|
||||||
|
|
||||||
|
## abc.scp and abc.txt
|
||||||
|
|
||||||
|
It is generated by
|
||||||
|
|
||||||
|
```
|
||||||
|
echo "1 abc.wav" > abc.scp
|
||||||
|
kaldi/src/featbin$ ./compute-fbank-feats --dither=0 scp:abc.scp ark,t:abc.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
## abc.diff
|
||||||
|
|
||||||
|
It's the change that measures only the feature computation time, excluding the I/O time.
|
30
kaldifeat/python/tests/test_data/abc.diff
Normal file
30
kaldifeat/python/tests/test_data/abc.diff
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
diff --git a/src/featbin/compute-fbank-feats.cc b/src/featbin/compute-fbank-feats.cc
|
||||||
|
index e52b30baf..63735c985 100644
|
||||||
|
--- a/src/featbin/compute-fbank-feats.cc
|
||||||
|
+++ b/src/featbin/compute-fbank-feats.cc
|
||||||
|
@@ -18,6 +18,8 @@
|
||||||
|
// See the Apache 2 License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
+#include <chrono>
|
||||||
|
+
|
||||||
|
#include "base/kaldi-common.h"
|
||||||
|
#include "feat/feature-fbank.h"
|
||||||
|
#include "feat/wave-reader.h"
|
||||||
|
@@ -145,8 +147,16 @@ int main(int argc, char *argv[]) {
|
||||||
|
SubVector<BaseFloat> waveform(wave_data.Data(), this_chan);
|
||||||
|
Matrix<BaseFloat> features;
|
||||||
|
try {
|
||||||
|
+ std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now();
|
||||||
|
fbank.ComputeFeatures(waveform, wave_data.SampFreq(),
|
||||||
|
vtln_warp_local, &features);
|
||||||
|
+ std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();
|
||||||
|
+ std::cout << "Time difference = "
|
||||||
|
+ << std::chrono::duration_cast<std::chrono::microseconds>(
|
||||||
|
+ end - begin)
|
||||||
|
+ .count() /
|
||||||
|
+ 1000000.
|
||||||
|
+ << "[s]" << std::endl;
|
||||||
|
} catch (...) {
|
||||||
|
KALDI_WARN << "Failed to compute features for utterance " << utt;
|
||||||
|
continue;
|
1
kaldifeat/python/tests/test_data/abc.scp
Normal file
1
kaldifeat/python/tests/test_data/abc.scp
Normal file
@ -0,0 +1 @@
|
|||||||
|
1 abc.wav
|
31999
kaldifeat/python/tests/test_data/abc.txt
Normal file
31999
kaldifeat/python/tests/test_data/abc.txt
Normal file
File diff suppressed because it is too large
Load Diff
BIN
kaldifeat/python/tests/test_data/abc.wav
Normal file
BIN
kaldifeat/python/tests/test_data/abc.wav
Normal file
Binary file not shown.
16
kaldifeat/python/tests/test_data/run.sh
Executable file
16
kaldifeat/python/tests/test_data/run.sh
Executable file
@ -0,0 +1,16 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
#
|
||||||
|
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
KALDI_ROOT=/root/fangjun/open-source/kaldi
|
||||||
|
export PATH=${KALDI_ROOT}/src/featbin:$PATH
|
||||||
|
|
||||||
|
sox -n -r 16000 -b 16 abc.wav synth 320 sine 300
|
||||||
|
echo "1 abc.wav" > abc.scp
|
||||||
|
compute-fbank-feats --dither=0 scp:abc.scp ark,t:abc.txt
|
||||||
|
|
||||||
|
# Output
|
||||||
|
#
|
||||||
|
# compute-fbank-feats --dither=0 scp:abc.scp ark,t:abc.txt
|
||||||
|
# Time difference = 0.304916[s]
|
||||||
|
# LOG (compute-fbank-feats[5.5.880~4-3e446]:main():compute-fbank-feats.cc:195) Done 1 out of 1 utterances.
|
48
kaldifeat/python/tests/test_default_parameters.py
Executable file
48
kaldifeat/python/tests/test_default_parameters.py
Executable file
@ -0,0 +1,48 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
#
|
||||||
|
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
cur_dir = Path(__file__).resolve().parent
|
||||||
|
kaldi_feat_dir = cur_dir.parent.parent.parent
|
||||||
|
|
||||||
|
import sys
|
||||||
|
sys.path.insert(0, f'{kaldi_feat_dir}/build/lib')
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import soundfile as sf
|
||||||
|
import _kaldifeat
|
||||||
|
|
||||||
|
|
||||||
|
def read_ark_txt() -> torch.Tensor:
|
||||||
|
test_data_dir = cur_dir / 'test_data'
|
||||||
|
filename = test_data_dir / 'abc.txt'
|
||||||
|
features = []
|
||||||
|
with open(filename) as f:
|
||||||
|
for line in f:
|
||||||
|
if '[' in line: continue
|
||||||
|
line = line.strip('').split()
|
||||||
|
data = [float(d) for d in line if d != ']']
|
||||||
|
features.append(data)
|
||||||
|
ans = torch.tensor(features)
|
||||||
|
return ans
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
test_data_dir = cur_dir / 'test_data'
|
||||||
|
filename = test_data_dir / 'abc.wav'
|
||||||
|
with sf.SoundFile(filename) as sf_desc:
|
||||||
|
sampling_rate = sf_desc.samplerate
|
||||||
|
assert sampling_rate == 16000
|
||||||
|
data = sf_desc.read(dtype=np.float32, always_2d=False)
|
||||||
|
data *= 32768
|
||||||
|
tensor = torch.from_numpy(data)
|
||||||
|
ans, elapsed_seconds = _kaldifeat.test_default_parameters(tensor)
|
||||||
|
expected = read_ark_txt()
|
||||||
|
assert torch.allclose(ans, expected, rtol=1e-3)
|
||||||
|
print('elapsed seconds:', elapsed_seconds)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
Loading…
x
Reference in New Issue
Block a user