From 7ae06d78ebe80490ba0cc3b6b6f532899cad554b Mon Sep 17 00:00:00 2001 From: Feiteng Date: Wed, 24 Jan 2024 15:35:57 +0800 Subject: [PATCH] Init pitch test --- kaldifeat/python/tests/CMakeLists.txt | 1 + kaldifeat/python/tests/test_pitch.py | 41 +++++++++++++++++++++++++++ 2 files changed, 42 insertions(+) create mode 100755 kaldifeat/python/tests/test_pitch.py diff --git a/kaldifeat/python/tests/CMakeLists.txt b/kaldifeat/python/tests/CMakeLists.txt index 4ccc891..2d6151a 100644 --- a/kaldifeat/python/tests/CMakeLists.txt +++ b/kaldifeat/python/tests/CMakeLists.txt @@ -23,6 +23,7 @@ set(py_test_files test_mel_bank_options.py test_mfcc.py test_mfcc_options.py + test_pitch.py test_plp.py test_plp_options.py test_spectrogram.py diff --git a/kaldifeat/python/tests/test_pitch.py b/kaldifeat/python/tests/test_pitch.py new file mode 100755 index 0000000..56d60a8 --- /dev/null +++ b/kaldifeat/python/tests/test_pitch.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 + +# Copyright 2021-2022 Xiaomi Corporation (authors: Fangjun Kuang) + +import pickle +from pathlib import Path + +import torch +from utils import get_devices, read_ark_txt, read_wave + +import kaldifeat + +cur_dir = Path(__file__).resolve().parent + + +def test_pitch_default(): + print("=====test_pitch_default=====") + filename = cur_dir / "test_data/test.wav" + wave = read_wave(filename) + gt = read_ark_txt(cur_dir / "test_data/test-pitch.txt") + + cpu_features = None + for device in get_devices(): + print("device", device) + opts = kaldifeat.PitchOptions() + opts.device = device + opts.frame_opts.dither = 0 + pitch = kaldifeat.Pitch(opts) + + features = pitch(wave) + assert features.device.type == "cpu" + assert torch.allclose(features, gt, rtol=1e-4) + if cpu_features is None: + cpu_features = features + + features = pitch(wave.to(device)) + assert features.device == device + assert torch.allclose(features.cpu(), gt, rtol=1e-4) + +if __name__ == "__main__": + test_pitch_default()