mirror of
https://github.com/csukuangfj/kaldifeat.git
synced 2025-08-10 18:42:17 +00:00
Add OnlinePlp Python APIs.
This commit is contained in:
parent
8f03b654fc
commit
e793159cc7
@ -10,5 +10,5 @@ from _kaldifeat import (
|
|||||||
|
|
||||||
from .fbank import Fbank, OnlineFbank
|
from .fbank import Fbank, OnlineFbank
|
||||||
from .mfcc import Mfcc, OnlineMfcc
|
from .mfcc import Mfcc, OnlineMfcc
|
||||||
from .plp import Plp
|
from .plp import OnlinePlp, Plp
|
||||||
from .spectrogram import Spectrogram
|
from .spectrogram import Spectrogram
|
||||||
|
@ -4,9 +4,20 @@
|
|||||||
import _kaldifeat
|
import _kaldifeat
|
||||||
|
|
||||||
from .offline_feature import OfflineFeature
|
from .offline_feature import OfflineFeature
|
||||||
|
from .online_feature import OnlineFeature
|
||||||
|
|
||||||
|
|
||||||
class Plp(OfflineFeature):
|
class Plp(OfflineFeature):
|
||||||
def __init__(self, opts: _kaldifeat.PlpOptions):
|
def __init__(self, opts: _kaldifeat.PlpOptions):
|
||||||
super().__init__(opts)
|
super().__init__(opts)
|
||||||
self.computer = _kaldifeat.Plp(opts)
|
self.computer = _kaldifeat.Plp(opts)
|
||||||
|
|
||||||
|
|
||||||
|
class OnlinePlp(OnlineFeature):
    """Online (streaming) counterpart of ``Plp``.

    Wraps ``_kaldifeat.OnlinePlp`` and stores it in ``self.computer``.
    """

    def __init__(self, opts: _kaldifeat.PlpOptions):
        """
        Args:
          opts:
            The options used to construct the underlying
            ``_kaldifeat.OnlinePlp`` computer.
        """
        super().__init__(opts)
        self.computer = _kaldifeat.OnlinePlp(opts)

    def __setstate__(self, state):
        """Rebuild the options and the computer when unpickling.

        Args:
          state:
            Passed to ``_kaldifeat.PlpOptions.from_dict``.
            NOTE(review): presumably the dict form of the options saved
            by the base class when pickling -- confirm against
            ``OnlineFeature``.
        """
        restored_opts = _kaldifeat.PlpOptions.from_dict(state)
        self.opts = restored_opts
        # The computer itself is not taken from ``state``; it is created
        # anew from the restored options.
        self.computer = _kaldifeat.OnlinePlp(restored_opts)
|
||||||
|
@ -13,24 +13,82 @@ import kaldifeat
|
|||||||
cur_dir = Path(__file__).resolve().parent
|
cur_dir = Path(__file__).resolve().parent
|
||||||
|
|
||||||
|
|
||||||
|
def test_online_plp(
    opts: kaldifeat.PlpOptions,
    wave: torch.Tensor,
    cpu_features: torch.Tensor,
):
    """Feed ``wave`` to an online PLP extractor chunk by chunk and check
    every frame it produces against the offline features.

    Args:
      opts:
        The options to create the online plp extractor.
      wave:
        The input 1-D waveform.
      cpu_features:
        The ground truth features that are computed offline.
    """
    extractor = kaldifeat.OnlinePlp(opts)

    frames_checked = 0
    offset = 0  # index of the next sample of `wave` to feed
    while not extractor.is_last_frame(frames_checked - 1):
        # Drain every frame that has become available so far.
        while frames_checked < extractor.num_frames_ready:
            actual = extractor.get_frame(frames_checked).squeeze(0)
            expected = cpu_features[frames_checked]
            assert torch.allclose(actual, expected, atol=1e-3)
            frames_checked += 1

        # Simulate streaming: send a random number of audio samples
        # to the extractor.
        chunk_size = torch.randint(300, 1000, (1,)).item()

        chunk = wave[offset : offset + chunk_size]  # noqa
        offset += chunk_size

        if len(chunk) > 0:
            # NOTE(review): 16000 is presumably the sampling rate of the
            # test wave -- confirm against the test data.
            extractor.accept_waveform(16000, chunk)
        else:
            # The waveform is exhausted; tell the extractor that no more
            # samples will arrive.
            extractor.input_finished()

    assert frames_checked == extractor.num_frames_ready
    assert frames_checked == cpu_features.size(0)
|
||||||
|
|
||||||
|
|
||||||
def test_plp_default():
|
def test_plp_default():
|
||||||
print("=====test_plp_default=====")
|
print("=====test_plp_default=====")
|
||||||
|
filename = cur_dir / "test_data/test.wav"
|
||||||
|
wave = read_wave(filename)
|
||||||
|
gt = read_ark_txt(cur_dir / "test_data/test-plp.txt")
|
||||||
|
|
||||||
|
cpu_features = None
|
||||||
for device in get_devices():
|
for device in get_devices():
|
||||||
print("device", device)
|
print("device", device)
|
||||||
opts = kaldifeat.PlpOptions()
|
opts = kaldifeat.PlpOptions()
|
||||||
opts.frame_opts.dither = 0
|
opts.frame_opts.dither = 0
|
||||||
opts.device = device
|
opts.device = device
|
||||||
plp = kaldifeat.Plp(opts)
|
plp = kaldifeat.Plp(opts)
|
||||||
filename = cur_dir / "test_data/test.wav"
|
|
||||||
wave = read_wave(filename).to(device)
|
|
||||||
|
|
||||||
features = plp(wave)
|
features = plp(wave.to(device))
|
||||||
gt = read_ark_txt(cur_dir / "test_data/test-plp.txt")
|
if device.type == "cpu":
|
||||||
|
cpu_features = features
|
||||||
|
|
||||||
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
||||||
|
|
||||||
|
opts = kaldifeat.PlpOptions()
|
||||||
|
opts.frame_opts.dither = 0
|
||||||
|
|
||||||
|
test_online_plp(opts, wave, cpu_features)
|
||||||
|
|
||||||
|
|
||||||
def test_plp_no_snip_edges():
|
def test_plp_no_snip_edges():
|
||||||
print("=====test_plp_no_snip_edges=====")
|
print("=====test_plp_no_snip_edges=====")
|
||||||
|
filename = cur_dir / "test_data/test.wav"
|
||||||
|
wave = read_wave(filename)
|
||||||
|
gt = read_ark_txt(cur_dir / "test_data/test-plp-no-snip-edges.txt")
|
||||||
|
|
||||||
|
cpu_features = None
|
||||||
for device in get_devices():
|
for device in get_devices():
|
||||||
print("device", device)
|
print("device", device)
|
||||||
opts = kaldifeat.PlpOptions()
|
opts = kaldifeat.PlpOptions()
|
||||||
@ -39,16 +97,26 @@ def test_plp_no_snip_edges():
|
|||||||
opts.frame_opts.snip_edges = False
|
opts.frame_opts.snip_edges = False
|
||||||
|
|
||||||
plp = kaldifeat.Plp(opts)
|
plp = kaldifeat.Plp(opts)
|
||||||
filename = cur_dir / "test_data/test.wav"
|
|
||||||
wave = read_wave(filename).to(device)
|
|
||||||
|
|
||||||
features = plp(wave)
|
features = plp(wave.to(device))
|
||||||
gt = read_ark_txt(cur_dir / "test_data/test-plp-no-snip-edges.txt")
|
if device.type == "cpu":
|
||||||
|
cpu_features = features
|
||||||
assert torch.allclose(features.cpu(), gt, atol=1e-1)
|
assert torch.allclose(features.cpu(), gt, atol=1e-1)
|
||||||
|
|
||||||
|
opts = kaldifeat.PlpOptions()
|
||||||
|
opts.frame_opts.dither = 0
|
||||||
|
opts.frame_opts.snip_edges = False
|
||||||
|
|
||||||
|
test_online_plp(opts, wave, cpu_features)
|
||||||
|
|
||||||
|
|
||||||
def test_plp_htk_10_ceps():
|
def test_plp_htk_10_ceps():
|
||||||
print("=====test_plp_htk_10_ceps=====")
|
print("=====test_plp_htk_10_ceps=====")
|
||||||
|
filename = cur_dir / "test_data/test.wav"
|
||||||
|
wave = read_wave(filename)
|
||||||
|
gt = read_ark_txt(cur_dir / "test_data/test-plp-htk-10-ceps.txt")
|
||||||
|
|
||||||
|
cpu_features = None
|
||||||
for device in get_devices():
|
for device in get_devices():
|
||||||
print("device", device)
|
print("device", device)
|
||||||
opts = kaldifeat.PlpOptions()
|
opts = kaldifeat.PlpOptions()
|
||||||
@ -58,13 +126,19 @@ def test_plp_htk_10_ceps():
|
|||||||
opts.frame_opts.dither = 0
|
opts.frame_opts.dither = 0
|
||||||
|
|
||||||
plp = kaldifeat.Plp(opts)
|
plp = kaldifeat.Plp(opts)
|
||||||
filename = cur_dir / "test_data/test.wav"
|
|
||||||
wave = read_wave(filename).to(device)
|
|
||||||
|
|
||||||
features = plp(wave)
|
features = plp(wave.to(device))
|
||||||
gt = read_ark_txt(cur_dir / "test_data/test-plp-htk-10-ceps.txt")
|
if device.type == "cpu":
|
||||||
|
cpu_features = features
|
||||||
assert torch.allclose(features.cpu(), gt, atol=1e-1)
|
assert torch.allclose(features.cpu(), gt, atol=1e-1)
|
||||||
|
|
||||||
|
opts = kaldifeat.PlpOptions()
|
||||||
|
opts.htk_compat = True
|
||||||
|
opts.num_ceps = 10
|
||||||
|
opts.frame_opts.dither = 0
|
||||||
|
|
||||||
|
test_online_plp(opts, wave, cpu_features)
|
||||||
|
|
||||||
|
|
||||||
def test_pickle():
|
def test_pickle():
|
||||||
for device in get_devices():
|
for device in get_devices():
|
||||||
@ -79,6 +153,16 @@ def test_pickle():
|
|||||||
|
|
||||||
assert str(plp.opts) == str(plp2.opts)
|
assert str(plp.opts) == str(plp2.opts)
|
||||||
|
|
||||||
|
opts = kaldifeat.PlpOptions()
|
||||||
|
opts.frame_opts.dither = 0
|
||||||
|
opts.frame_opts.snip_edges = False
|
||||||
|
|
||||||
|
plp = kaldifeat.OnlinePlp(opts)
|
||||||
|
data = pickle.dumps(plp)
|
||||||
|
plp2 = pickle.loads(data)
|
||||||
|
|
||||||
|
assert str(plp.opts) == str(plp2.opts)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_plp_default()
|
test_plp_default()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user