Add OnlinePlp Python APIs.

This commit is contained in:
Fangjun Kuang 2022-04-02 20:30:05 +08:00
parent 8f03b654fc
commit e793159cc7
3 changed files with 108 additions and 13 deletions

View File

@ -10,5 +10,5 @@ from _kaldifeat import (
from .fbank import Fbank, OnlineFbank from .fbank import Fbank, OnlineFbank
from .mfcc import Mfcc, OnlineMfcc from .mfcc import Mfcc, OnlineMfcc
from .plp import Plp from .plp import OnlinePlp, Plp
from .spectrogram import Spectrogram from .spectrogram import Spectrogram

View File

@ -4,9 +4,20 @@
import _kaldifeat import _kaldifeat
from .offline_feature import OfflineFeature from .offline_feature import OfflineFeature
from .online_feature import OnlineFeature
class Plp(OfflineFeature): class Plp(OfflineFeature):
def __init__(self, opts: _kaldifeat.PlpOptions): def __init__(self, opts: _kaldifeat.PlpOptions):
super().__init__(opts) super().__init__(opts)
self.computer = _kaldifeat.Plp(opts) self.computer = _kaldifeat.Plp(opts)
class OnlinePlp(OnlineFeature):
    """Online (streaming) PLP extractor backed by ``_kaldifeat.OnlinePlp``."""

    def __init__(self, opts: _kaldifeat.PlpOptions):
        """Create the extractor.

        Args:
          opts:
            PLP options; forwarded to the base class and used to build
            the underlying ``_kaldifeat.OnlinePlp`` computer.
        """
        super().__init__(opts)
        self.computer = _kaldifeat.OnlinePlp(opts)

    def __setstate__(self, state):
        """Restore the extractor when unpickling.

        Args:
          state:
            The pickled state; it is handed to
            ``_kaldifeat.PlpOptions.from_dict`` (presumably a dict of
            options produced by the base class -- confirm there).
        """
        restored_opts = _kaldifeat.PlpOptions.from_dict(state)
        self.opts = restored_opts
        # The computer is not part of the pickled state; rebuild it from
        # the restored options.
        self.computer = _kaldifeat.OnlinePlp(restored_opts)

View File

@ -13,24 +13,82 @@ import kaldifeat
cur_dir = Path(__file__).resolve().parent cur_dir = Path(__file__).resolve().parent
def test_online_plp(
    opts: kaldifeat.PlpOptions,
    wave: torch.Tensor,
    cpu_features: torch.Tensor,
):
    """Check that streaming PLP extraction matches offline features.

    Args:
      opts:
        The options used to create the online PLP extractor.
      wave:
        The input 1-D waveform.
      cpu_features:
        The ground truth features that are computed offline. Row ``k``
        is compared with the k-th frame returned by the online extractor.
    """
    extractor = kaldifeat.OnlinePlp(opts)

    frames_done = 0
    offset = 0  # index of the next sample of `wave` to feed

    while not extractor.is_last_frame(frames_done - 1):
        # Drain every frame that has become available so far and compare
        # it with the corresponding offline frame.
        while frames_done < extractor.num_frames_ready:
            got = extractor.get_frame(frames_done).squeeze(0)
            expected = cpu_features[frames_done]
            assert torch.allclose(got, expected, atol=1e-3)
            frames_done += 1

        # Simulate streaming: send a random number of audio samples
        # to the extractor.
        chunk_size = torch.randint(300, 1000, (1,)).item()
        chunk = wave[offset : (offset + chunk_size)]  # noqa
        offset += chunk_size

        if len(chunk) == 0:
            # The waveform is exhausted; tell the extractor that no more
            # audio is coming.
            extractor.input_finished()
        else:
            # NOTE(review): 16000 is presumably the sampling rate in Hz of
            # the test wave -- confirm against the test data.
            extractor.accept_waveform(16000, chunk)

    # Every frame must have been consumed, and the total must agree with
    # the number of offline frames.
    assert frames_done == extractor.num_frames_ready
    assert frames_done == cpu_features.size(0)
def test_plp_default(): def test_plp_default():
print("=====test_plp_default=====") print("=====test_plp_default=====")
filename = cur_dir / "test_data/test.wav"
wave = read_wave(filename)
gt = read_ark_txt(cur_dir / "test_data/test-plp.txt")
cpu_features = None
for device in get_devices(): for device in get_devices():
print("device", device) print("device", device)
opts = kaldifeat.PlpOptions() opts = kaldifeat.PlpOptions()
opts.frame_opts.dither = 0 opts.frame_opts.dither = 0
opts.device = device opts.device = device
plp = kaldifeat.Plp(opts) plp = kaldifeat.Plp(opts)
filename = cur_dir / "test_data/test.wav"
wave = read_wave(filename).to(device)
features = plp(wave) features = plp(wave.to(device))
gt = read_ark_txt(cur_dir / "test_data/test-plp.txt") if device.type == "cpu":
cpu_features = features
assert torch.allclose(features.cpu(), gt, rtol=1e-1) assert torch.allclose(features.cpu(), gt, rtol=1e-1)
opts = kaldifeat.PlpOptions()
opts.frame_opts.dither = 0
test_online_plp(opts, wave, cpu_features)
def test_plp_no_snip_edges(): def test_plp_no_snip_edges():
print("=====test_plp_no_snip_edges=====") print("=====test_plp_no_snip_edges=====")
filename = cur_dir / "test_data/test.wav"
wave = read_wave(filename)
gt = read_ark_txt(cur_dir / "test_data/test-plp-no-snip-edges.txt")
cpu_features = None
for device in get_devices(): for device in get_devices():
print("device", device) print("device", device)
opts = kaldifeat.PlpOptions() opts = kaldifeat.PlpOptions()
@ -39,16 +97,26 @@ def test_plp_no_snip_edges():
opts.frame_opts.snip_edges = False opts.frame_opts.snip_edges = False
plp = kaldifeat.Plp(opts) plp = kaldifeat.Plp(opts)
filename = cur_dir / "test_data/test.wav"
wave = read_wave(filename).to(device)
features = plp(wave) features = plp(wave.to(device))
gt = read_ark_txt(cur_dir / "test_data/test-plp-no-snip-edges.txt") if device.type == "cpu":
cpu_features = features
assert torch.allclose(features.cpu(), gt, atol=1e-1) assert torch.allclose(features.cpu(), gt, atol=1e-1)
opts = kaldifeat.PlpOptions()
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
test_online_plp(opts, wave, cpu_features)
def test_plp_htk_10_ceps(): def test_plp_htk_10_ceps():
print("=====test_plp_htk_10_ceps=====") print("=====test_plp_htk_10_ceps=====")
filename = cur_dir / "test_data/test.wav"
wave = read_wave(filename)
gt = read_ark_txt(cur_dir / "test_data/test-plp-htk-10-ceps.txt")
cpu_features = None
for device in get_devices(): for device in get_devices():
print("device", device) print("device", device)
opts = kaldifeat.PlpOptions() opts = kaldifeat.PlpOptions()
@ -58,13 +126,19 @@ def test_plp_htk_10_ceps():
opts.frame_opts.dither = 0 opts.frame_opts.dither = 0
plp = kaldifeat.Plp(opts) plp = kaldifeat.Plp(opts)
filename = cur_dir / "test_data/test.wav"
wave = read_wave(filename).to(device)
features = plp(wave) features = plp(wave.to(device))
gt = read_ark_txt(cur_dir / "test_data/test-plp-htk-10-ceps.txt") if device.type == "cpu":
cpu_features = features
assert torch.allclose(features.cpu(), gt, atol=1e-1) assert torch.allclose(features.cpu(), gt, atol=1e-1)
opts = kaldifeat.PlpOptions()
opts.htk_compat = True
opts.num_ceps = 10
opts.frame_opts.dither = 0
test_online_plp(opts, wave, cpu_features)
def test_pickle(): def test_pickle():
for device in get_devices(): for device in get_devices():
@ -79,6 +153,16 @@ def test_pickle():
assert str(plp.opts) == str(plp2.opts) assert str(plp.opts) == str(plp2.opts)
opts = kaldifeat.PlpOptions()
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
plp = kaldifeat.OnlinePlp(opts)
data = pickle.dumps(plp)
plp2 = pickle.loads(data)
assert str(plp.opts) == str(plp2.opts)
if __name__ == "__main__": if __name__ == "__main__":
test_plp_default() test_plp_default()