Add OnlineMfcc Python APIs.

This commit is contained in:
Fangjun Kuang 2022-04-02 20:18:47 +08:00
parent e59d05a45a
commit 8f03b654fc
6 changed files with 100 additions and 13 deletions

View File

@ -66,8 +66,9 @@ class OnlineFeatureInterface {
features.push_back(std::move(f)); features.push_back(std::move(f));
} }
return features; return features;
#if 0
// return torch::cat(features, [>dim<] 0); return torch::cat(features, /*dim*/ 0);
#endif
} }
/// This would be called from the application, when you get more wave data. /// This would be called from the application, when you get more wave data.

View File

@ -4,8 +4,9 @@
#include "kaldifeat/python/csrc/online-feature.h" #include "kaldifeat/python/csrc/online-feature.h"
#include "kaldifeat/csrc/online-feature.h" #include <string>
#include "kaldifeat/csrc/online-feature.h"
namespace kaldifeat { namespace kaldifeat {
template <typename C> template <typename C>

View File

@ -9,6 +9,6 @@ from _kaldifeat import (
) )
from .fbank import Fbank, OnlineFbank from .fbank import Fbank, OnlineFbank
from .mfcc import Mfcc from .mfcc import Mfcc, OnlineMfcc
from .plp import Plp from .plp import Plp
from .spectrogram import Spectrogram from .spectrogram import Spectrogram

View File

@ -20,4 +20,4 @@ class OnlineFbank(OnlineFeature):
def __setstate__(self, state): def __setstate__(self, state):
self.opts = _kaldifeat.FbankOptions.from_dict(state) self.opts = _kaldifeat.FbankOptions.from_dict(state)
self.computer = _kaldifeat.Fbank(self.opts) self.computer = _kaldifeat.OnlineFbank(self.opts)

View File

@ -4,9 +4,20 @@
import _kaldifeat import _kaldifeat
from .offline_feature import OfflineFeature from .offline_feature import OfflineFeature
from .online_feature import OnlineFeature
class Mfcc(OfflineFeature): class Mfcc(OfflineFeature):
def __init__(self, opts: _kaldifeat.MfccOptions): def __init__(self, opts: _kaldifeat.MfccOptions):
super().__init__(opts) super().__init__(opts)
self.computer = _kaldifeat.Mfcc(opts) self.computer = _kaldifeat.Mfcc(opts)
class OnlineMfcc(OnlineFeature):
    """Streaming MFCC extractor backed by ``_kaldifeat.OnlineMfcc``."""

    def __init__(self, opts: _kaldifeat.MfccOptions):
        """Create the extractor.

        Args:
          opts:
            The MFCC options; passed both to the base class and to the
            underlying ``_kaldifeat.OnlineMfcc`` object.
        """
        super().__init__(opts)
        self.computer = _kaldifeat.OnlineMfcc(opts)

    def __setstate__(self, state):
        """Restore the extractor when unpickling.

        Args:
          state:
            A dict accepted by ``_kaldifeat.MfccOptions.from_dict``
            (presumably produced by the base class when pickling --
            TODO confirm).
        """
        restored_opts = _kaldifeat.MfccOptions.from_dict(state)
        self.opts = restored_opts
        # The underlying computer is rebuilt from the restored options.
        self.computer = _kaldifeat.OnlineMfcc(restored_opts)

View File

@ -13,24 +13,82 @@ import kaldifeat
cur_dir = Path(__file__).resolve().parent cur_dir = Path(__file__).resolve().parent
def test_online_mfcc(
    opts: kaldifeat.MfccOptions,
    wave: torch.Tensor,
    cpu_features: torch.Tensor,
):
    """Compare streaming MFCC frames against features computed offline.

    The waveform is fed to the online extractor in randomly sized chunks.
    Every frame that becomes ready is compared with the row of the same
    index in ``cpu_features``.

    Args:
      opts:
        The options to create the online MFCC extractor.
      wave:
        The input 1-D waveform.
      cpu_features:
        The ground-truth features that are computed offline.
    """
    extractor = kaldifeat.OnlineMfcc(opts)

    frames_done = 0
    offset = 0  # index of the next sample to feed

    # Nothing has been processed yet, so the first check is made with -1.
    while not extractor.is_last_frame(frames_done - 1):
        # Drain every frame that is ready but not yet verified.
        while frames_done < extractor.num_frames_ready:
            actual = extractor.get_frame(frames_done).squeeze(0)
            expected = cpu_features[frames_done]
            assert torch.allclose(actual, expected, atol=1e-3)
            frames_done += 1

        # Simulate streaming: send a random number of audio samples
        # to the extractor.
        chunk_size = torch.randint(300, 1000, (1,)).item()
        chunk = wave[offset : (offset + chunk_size)]  # noqa
        offset += chunk_size

        if len(chunk) > 0:
            # 16000 is presumably the sampling rate in Hz -- TODO confirm
            extractor.accept_waveform(16000, chunk)
        else:
            # The waveform is exhausted; tell the extractor no more input
            # is coming.
            extractor.input_finished()

    assert frames_done == extractor.num_frames_ready
    assert frames_done == cpu_features.size(0)
def test_mfcc_default(): def test_mfcc_default():
print("=====test_mfcc_default=====") print("=====test_mfcc_default=====")
filename = cur_dir / "test_data/test.wav"
wave = read_wave(filename)
gt = read_ark_txt(cur_dir / "test_data/test-mfcc.txt")
cpu_features = None
for device in get_devices(): for device in get_devices():
print("device", device) print("device", device)
opts = kaldifeat.MfccOptions() opts = kaldifeat.MfccOptions()
opts.device = device opts.device = device
opts.frame_opts.dither = 0 opts.frame_opts.dither = 0
mfcc = kaldifeat.Mfcc(opts) mfcc = kaldifeat.Mfcc(opts)
filename = cur_dir / "test_data/test.wav"
wave = read_wave(filename).to(device)
features = mfcc(wave) features = mfcc(wave.to(device))
gt = read_ark_txt(cur_dir / "test_data/test-mfcc.txt") if device.type == "cpu":
cpu_features = features
assert torch.allclose(features.cpu(), gt, atol=1e-1) assert torch.allclose(features.cpu(), gt, atol=1e-1)
opts = kaldifeat.MfccOptions()
opts.frame_opts.dither = 0
test_online_mfcc(opts, wave, cpu_features)
def test_mfcc_no_snip_edges(): def test_mfcc_no_snip_edges():
print("=====test_mfcc_no_snip_edges=====") print("=====test_mfcc_no_snip_edges=====")
filename = cur_dir / "test_data/test.wav"
wave = read_wave(filename)
gt = read_ark_txt(cur_dir / "test_data/test-mfcc-no-snip-edges.txt")
cpu_features = None
for device in get_devices(): for device in get_devices():
print("device", device) print("device", device)
opts = kaldifeat.MfccOptions() opts = kaldifeat.MfccOptions()
@ -39,13 +97,19 @@ def test_mfcc_no_snip_edges():
opts.frame_opts.snip_edges = False opts.frame_opts.snip_edges = False
mfcc = kaldifeat.Mfcc(opts) mfcc = kaldifeat.Mfcc(opts)
filename = cur_dir / "test_data/test.wav"
wave = read_wave(filename).to(device)
features = mfcc(wave) features = mfcc(wave.to(device))
gt = read_ark_txt(cur_dir / "test_data/test-mfcc-no-snip-edges.txt") if device.type == "cpu":
cpu_features = features
assert torch.allclose(features.cpu(), gt, rtol=1e-1) assert torch.allclose(features.cpu(), gt, rtol=1e-1)
opts = kaldifeat.MfccOptions()
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
test_online_mfcc(opts, wave, cpu_features)
def test_pickle(): def test_pickle():
for device in get_devices(): for device in get_devices():
@ -60,6 +124,16 @@ def test_pickle():
assert str(mfcc.opts) == str(mfcc2.opts) assert str(mfcc.opts) == str(mfcc2.opts)
opts = kaldifeat.MfccOptions()
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
mfcc = kaldifeat.OnlineMfcc(opts)
data = pickle.dumps(mfcc)
mfcc2 = pickle.loads(data)
assert str(mfcc.opts) == str(mfcc2.opts)
if __name__ == "__main__": if __name__ == "__main__":
test_mfcc_default() test_mfcc_default()