From e793159cc7119eeb3244e90a8a6a7eeeffaa4492 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 2 Apr 2022 20:30:05 +0800 Subject: [PATCH] Add OnlinePlp Python APIs. --- kaldifeat/python/kaldifeat/__init__.py | 2 +- kaldifeat/python/kaldifeat/plp.py | 11 +++ kaldifeat/python/tests/test_plp.py | 108 ++++++++++++++++++++++--- 3 files changed, 108 insertions(+), 13 deletions(-) diff --git a/kaldifeat/python/kaldifeat/__init__.py b/kaldifeat/python/kaldifeat/__init__.py index 4b4fd50..60c6443 100644 --- a/kaldifeat/python/kaldifeat/__init__.py +++ b/kaldifeat/python/kaldifeat/__init__.py @@ -10,5 +10,5 @@ from _kaldifeat import ( from .fbank import Fbank, OnlineFbank from .mfcc import Mfcc, OnlineMfcc -from .plp import Plp +from .plp import OnlinePlp, Plp from .spectrogram import Spectrogram diff --git a/kaldifeat/python/kaldifeat/plp.py b/kaldifeat/python/kaldifeat/plp.py index 219e2d4..d99dbc2 100644 --- a/kaldifeat/python/kaldifeat/plp.py +++ b/kaldifeat/python/kaldifeat/plp.py @@ -4,9 +4,20 @@ import _kaldifeat from .offline_feature import OfflineFeature +from .online_feature import OnlineFeature class Plp(OfflineFeature): def __init__(self, opts: _kaldifeat.PlpOptions): super().__init__(opts) self.computer = _kaldifeat.Plp(opts) + + +class OnlinePlp(OnlineFeature): + def __init__(self, opts: _kaldifeat.PlpOptions): + super().__init__(opts) + self.computer = _kaldifeat.OnlinePlp(opts) + + def __setstate__(self, state): + self.opts = _kaldifeat.PlpOptions.from_dict(state) + self.computer = _kaldifeat.OnlinePlp(self.opts) diff --git a/kaldifeat/python/tests/test_plp.py b/kaldifeat/python/tests/test_plp.py index 4f20452..cf56d41 100755 --- a/kaldifeat/python/tests/test_plp.py +++ b/kaldifeat/python/tests/test_plp.py @@ -13,24 +13,82 @@ import kaldifeat cur_dir = Path(__file__).resolve().parent +def test_online_plp( + opts: kaldifeat.PlpOptions, + wave: torch.Tensor, + cpu_features: torch.Tensor, +): + """ + Args: + opts: + The options to create the online plp extractor. + wave: + The input 1-D waveform. + cpu_features: + The groud truth features that are computed offline + """ + online_plp = kaldifeat.OnlinePlp(opts) + + num_processed_frames = 0 + i = 0 # current sample index to feed + while not online_plp.is_last_frame(num_processed_frames - 1): + while num_processed_frames < online_plp.num_frames_ready: + # There are new frames to be processed + frame = online_plp.get_frame(num_processed_frames) + assert torch.allclose( + frame.squeeze(0), cpu_features[num_processed_frames], atol=1e-3 + ) + num_processed_frames += 1 + + # Simulate streaming . Send a random number of audio samples + # to the extractor + num_samples = torch.randint(300, 1000, (1,)).item() + + samples = wave[i : (i + num_samples)] # noqa + i += num_samples + if len(samples) == 0: + online_plp.input_finished() + continue + + online_plp.accept_waveform(16000, samples) + + assert num_processed_frames == online_plp.num_frames_ready + assert num_processed_frames == cpu_features.size(0) + + def test_plp_default(): print("=====test_plp_default=====") + filename = cur_dir / "test_data/test.wav" + wave = read_wave(filename) + gt = read_ark_txt(cur_dir / "test_data/test-plp.txt") + + cpu_features = None for device in get_devices(): print("device", device) opts = kaldifeat.PlpOptions() opts.frame_opts.dither = 0 opts.device = device plp = kaldifeat.Plp(opts) - filename = cur_dir / "test_data/test.wav" - wave = read_wave(filename).to(device) - features = plp(wave) - gt = read_ark_txt(cur_dir / "test_data/test-plp.txt") + features = plp(wave.to(device)) + if device.type == "cpu": + cpu_features = features + assert torch.allclose(features.cpu(), gt, rtol=1e-1) + opts = kaldifeat.PlpOptions() + opts.frame_opts.dither = 0 + + test_online_plp(opts, wave, cpu_features) + def test_plp_no_snip_edges(): print("=====test_plp_no_snip_edges=====") + filename = cur_dir / "test_data/test.wav" + wave = read_wave(filename) + gt = read_ark_txt(cur_dir / "test_data/test-plp-no-snip-edges.txt") + + cpu_features = None for device in get_devices(): print("device", device) opts = kaldifeat.PlpOptions() @@ -39,16 +97,26 @@ def test_plp_no_snip_edges(): opts.frame_opts.snip_edges = False plp = kaldifeat.Plp(opts) - filename = cur_dir / "test_data/test.wav" - wave = read_wave(filename).to(device) - features = plp(wave) - gt = read_ark_txt(cur_dir / "test_data/test-plp-no-snip-edges.txt") + features = plp(wave.to(device)) + if device.type == "cpu": + cpu_features = features assert torch.allclose(features.cpu(), gt, atol=1e-1) + opts = kaldifeat.PlpOptions() + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + + test_online_plp(opts, wave, cpu_features) + def test_plp_htk_10_ceps(): print("=====test_plp_htk_10_ceps=====") + filename = cur_dir / "test_data/test.wav" + wave = read_wave(filename) + gt = read_ark_txt(cur_dir / "test_data/test-plp-htk-10-ceps.txt") + + cpu_features = None for device in get_devices(): print("device", device) opts = kaldifeat.PlpOptions() @@ -58,13 +126,19 @@ def test_plp_htk_10_ceps(): opts.frame_opts.dither = 0 plp = kaldifeat.Plp(opts) - filename = cur_dir / "test_data/test.wav" - wave = read_wave(filename).to(device) - features = plp(wave) - gt = read_ark_txt(cur_dir / "test_data/test-plp-htk-10-ceps.txt") + features = plp(wave.to(device)) + if device.type == "cpu": + cpu_features = features assert torch.allclose(features.cpu(), gt, atol=1e-1) + opts = kaldifeat.PlpOptions() + opts.htk_compat = True + opts.num_ceps = 10 + opts.frame_opts.dither = 0 + + test_online_plp(opts, wave, cpu_features) + def test_pickle(): for device in get_devices(): @@ -79,6 +153,16 @@ def test_pickle(): assert str(plp.opts) == str(plp2.opts) + opts = kaldifeat.PlpOptions() + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + + plp = kaldifeat.OnlinePlp(opts) + data = pickle.dumps(plp) + plp2 = pickle.loads(data) + + assert str(plp.opts) == str(plp2.opts) + if __name__ == "__main__": test_plp_default()