mirror of
https://github.com/csukuangfj/kaldifeat.git
synced 2025-08-09 18:12:17 +00:00
Add OnlineMfcc Python APIs.
This commit is contained in:
parent
e59d05a45a
commit
8f03b654fc
@ -66,8 +66,9 @@ class OnlineFeatureInterface {
|
||||
features.push_back(std::move(f));
|
||||
}
|
||||
return features;
|
||||
|
||||
// return torch::cat(features, [>dim<] 0);
|
||||
#if 0
|
||||
return torch::cat(features, /*dim*/ 0);
|
||||
#endif
|
||||
}
|
||||
|
||||
/// This would be called from the application, when you get more wave data.
|
||||
|
@ -4,8 +4,9 @@
|
||||
|
||||
#include "kaldifeat/python/csrc/online-feature.h"
|
||||
|
||||
#include "kaldifeat/csrc/online-feature.h"
|
||||
#include <string>
|
||||
|
||||
#include "kaldifeat/csrc/online-feature.h"
|
||||
namespace kaldifeat {
|
||||
|
||||
template <typename C>
|
||||
|
@ -9,6 +9,6 @@ from _kaldifeat import (
|
||||
)
|
||||
|
||||
from .fbank import Fbank, OnlineFbank
|
||||
from .mfcc import Mfcc
|
||||
from .mfcc import Mfcc, OnlineMfcc
|
||||
from .plp import Plp
|
||||
from .spectrogram import Spectrogram
|
||||
|
@ -20,4 +20,4 @@ class OnlineFbank(OnlineFeature):
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.opts = _kaldifeat.FbankOptions.from_dict(state)
|
||||
self.computer = _kaldifeat.Fbank(self.opts)
|
||||
self.computer = _kaldifeat.OnlineFbank(self.opts)
|
||||
|
@ -4,9 +4,20 @@
|
||||
import _kaldifeat
|
||||
|
||||
from .offline_feature import OfflineFeature
|
||||
from .online_feature import OnlineFeature
|
||||
|
||||
|
||||
class Mfcc(OfflineFeature):
|
||||
def __init__(self, opts: _kaldifeat.MfccOptions):
|
||||
super().__init__(opts)
|
||||
self.computer = _kaldifeat.Mfcc(opts)
|
||||
|
||||
|
||||
class OnlineMfcc(OnlineFeature):
|
||||
def __init__(self, opts: _kaldifeat.MfccOptions):
|
||||
super().__init__(opts)
|
||||
self.computer = _kaldifeat.OnlineMfcc(opts)
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.opts = _kaldifeat.MfccOptions.from_dict(state)
|
||||
self.computer = _kaldifeat.OnlineMfcc(self.opts)
|
||||
|
@ -13,24 +13,82 @@ import kaldifeat
|
||||
cur_dir = Path(__file__).resolve().parent
|
||||
|
||||
|
||||
def test_online_mfcc(
|
||||
opts: kaldifeat.MfccOptions,
|
||||
wave: torch.Tensor,
|
||||
cpu_features: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
opts:
|
||||
The options to create the online mfcc extractor.
|
||||
wave:
|
||||
The input 1-D waveform.
|
||||
cpu_features:
|
||||
The groud truth features that are computed offline
|
||||
"""
|
||||
online_mfcc = kaldifeat.OnlineMfcc(opts)
|
||||
|
||||
num_processed_frames = 0
|
||||
i = 0 # current sample index to feed
|
||||
while not online_mfcc.is_last_frame(num_processed_frames - 1):
|
||||
while num_processed_frames < online_mfcc.num_frames_ready:
|
||||
# There are new frames to be processed
|
||||
frame = online_mfcc.get_frame(num_processed_frames)
|
||||
assert torch.allclose(
|
||||
frame.squeeze(0), cpu_features[num_processed_frames], atol=1e-3
|
||||
)
|
||||
num_processed_frames += 1
|
||||
|
||||
# Simulate streaming . Send a random number of audio samples
|
||||
# to the extractor
|
||||
num_samples = torch.randint(300, 1000, (1,)).item()
|
||||
|
||||
samples = wave[i : (i + num_samples)] # noqa
|
||||
i += num_samples
|
||||
if len(samples) == 0:
|
||||
online_mfcc.input_finished()
|
||||
continue
|
||||
|
||||
online_mfcc.accept_waveform(16000, samples)
|
||||
|
||||
assert num_processed_frames == online_mfcc.num_frames_ready
|
||||
assert num_processed_frames == cpu_features.size(0)
|
||||
|
||||
|
||||
def test_mfcc_default():
|
||||
print("=====test_mfcc_default=====")
|
||||
filename = cur_dir / "test_data/test.wav"
|
||||
wave = read_wave(filename)
|
||||
gt = read_ark_txt(cur_dir / "test_data/test-mfcc.txt")
|
||||
|
||||
cpu_features = None
|
||||
for device in get_devices():
|
||||
print("device", device)
|
||||
opts = kaldifeat.MfccOptions()
|
||||
opts.device = device
|
||||
opts.frame_opts.dither = 0
|
||||
mfcc = kaldifeat.Mfcc(opts)
|
||||
filename = cur_dir / "test_data/test.wav"
|
||||
wave = read_wave(filename).to(device)
|
||||
|
||||
features = mfcc(wave)
|
||||
gt = read_ark_txt(cur_dir / "test_data/test-mfcc.txt")
|
||||
features = mfcc(wave.to(device))
|
||||
if device.type == "cpu":
|
||||
cpu_features = features
|
||||
|
||||
assert torch.allclose(features.cpu(), gt, atol=1e-1)
|
||||
|
||||
opts = kaldifeat.MfccOptions()
|
||||
opts.frame_opts.dither = 0
|
||||
|
||||
test_online_mfcc(opts, wave, cpu_features)
|
||||
|
||||
|
||||
def test_mfcc_no_snip_edges():
|
||||
print("=====test_mfcc_no_snip_edges=====")
|
||||
filename = cur_dir / "test_data/test.wav"
|
||||
wave = read_wave(filename)
|
||||
gt = read_ark_txt(cur_dir / "test_data/test-mfcc-no-snip-edges.txt")
|
||||
|
||||
cpu_features = None
|
||||
for device in get_devices():
|
||||
print("device", device)
|
||||
opts = kaldifeat.MfccOptions()
|
||||
@ -39,13 +97,19 @@ def test_mfcc_no_snip_edges():
|
||||
opts.frame_opts.snip_edges = False
|
||||
|
||||
mfcc = kaldifeat.Mfcc(opts)
|
||||
filename = cur_dir / "test_data/test.wav"
|
||||
wave = read_wave(filename).to(device)
|
||||
|
||||
features = mfcc(wave)
|
||||
gt = read_ark_txt(cur_dir / "test_data/test-mfcc-no-snip-edges.txt")
|
||||
features = mfcc(wave.to(device))
|
||||
if device.type == "cpu":
|
||||
cpu_features = features
|
||||
|
||||
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
||||
|
||||
opts = kaldifeat.MfccOptions()
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
|
||||
test_online_mfcc(opts, wave, cpu_features)
|
||||
|
||||
|
||||
def test_pickle():
|
||||
for device in get_devices():
|
||||
@ -60,6 +124,16 @@ def test_pickle():
|
||||
|
||||
assert str(mfcc.opts) == str(mfcc2.opts)
|
||||
|
||||
opts = kaldifeat.MfccOptions()
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
|
||||
mfcc = kaldifeat.OnlineMfcc(opts)
|
||||
data = pickle.dumps(mfcc)
|
||||
mfcc2 = pickle.loads(data)
|
||||
|
||||
assert str(mfcc.opts) == str(mfcc2.opts)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_mfcc_default()
|
||||
|
Loading…
x
Reference in New Issue
Block a user