diff --git a/kaldifeat/csrc/online-feature-itf.h b/kaldifeat/csrc/online-feature-itf.h
index 60af265..835e182 100644
--- a/kaldifeat/csrc/online-feature-itf.h
+++ b/kaldifeat/csrc/online-feature-itf.h
@@ -66,8 +66,9 @@ class OnlineFeatureInterface {
       features.push_back(std::move(f));
     }
     return features;
-
-    // return torch::cat(features, [>dim<] 0);
+#if 0
+    return torch::cat(features, /*dim*/ 0);
+#endif
   }
 
   /// This would be called from the application, when you get more wave data.
diff --git a/kaldifeat/python/csrc/online-feature.cc b/kaldifeat/python/csrc/online-feature.cc
index 9592a96..13e4a4f 100644
--- a/kaldifeat/python/csrc/online-feature.cc
+++ b/kaldifeat/python/csrc/online-feature.cc
@@ -4,8 +4,9 @@
 
 #include "kaldifeat/python/csrc/online-feature.h"
 
-#include "kaldifeat/csrc/online-feature.h"
+#include <string>
+#include "kaldifeat/csrc/online-feature.h"
 
 namespace kaldifeat {
 
 template <typename C>
diff --git a/kaldifeat/python/kaldifeat/__init__.py b/kaldifeat/python/kaldifeat/__init__.py
index 57004b7..4b4fd50 100644
--- a/kaldifeat/python/kaldifeat/__init__.py
+++ b/kaldifeat/python/kaldifeat/__init__.py
@@ -9,6 +9,6 @@ from _kaldifeat import (
 )
 
 from .fbank import Fbank, OnlineFbank
-from .mfcc import Mfcc
+from .mfcc import Mfcc, OnlineMfcc
 from .plp import Plp
 from .spectrogram import Spectrogram
diff --git a/kaldifeat/python/kaldifeat/fbank.py b/kaldifeat/python/kaldifeat/fbank.py
index 275d1cf..45bc3ef 100644
--- a/kaldifeat/python/kaldifeat/fbank.py
+++ b/kaldifeat/python/kaldifeat/fbank.py
@@ -20,4 +20,4 @@ class OnlineFbank(OnlineFeature):
 
     def __setstate__(self, state):
         self.opts = _kaldifeat.FbankOptions.from_dict(state)
-        self.computer = _kaldifeat.Fbank(self.opts)
+        self.computer = _kaldifeat.OnlineFbank(self.opts)
diff --git a/kaldifeat/python/kaldifeat/mfcc.py b/kaldifeat/python/kaldifeat/mfcc.py
index fa1e225..f76f2f4 100644
--- a/kaldifeat/python/kaldifeat/mfcc.py
+++ b/kaldifeat/python/kaldifeat/mfcc.py
@@ -4,9 +4,20 @@
 import _kaldifeat
 
 from .offline_feature import OfflineFeature
+from .online_feature import OnlineFeature
 
 
 class Mfcc(OfflineFeature):
     def __init__(self, opts: _kaldifeat.MfccOptions):
         super().__init__(opts)
         self.computer = _kaldifeat.Mfcc(opts)
+
+
+class OnlineMfcc(OnlineFeature):
+    def __init__(self, opts: _kaldifeat.MfccOptions):
+        super().__init__(opts)
+        self.computer = _kaldifeat.OnlineMfcc(opts)
+
+    def __setstate__(self, state):
+        self.opts = _kaldifeat.MfccOptions.from_dict(state)
+        self.computer = _kaldifeat.OnlineMfcc(self.opts)
diff --git a/kaldifeat/python/tests/test_mfcc.py b/kaldifeat/python/tests/test_mfcc.py
index 33407b5..5665da4 100755
--- a/kaldifeat/python/tests/test_mfcc.py
+++ b/kaldifeat/python/tests/test_mfcc.py
@@ -13,24 +13,82 @@ import kaldifeat
 cur_dir = Path(__file__).resolve().parent
 
 
+def test_online_mfcc(
+    opts: kaldifeat.MfccOptions,
+    wave: torch.Tensor,
+    cpu_features: torch.Tensor,
+):
+    """
+    Args:
+      opts:
+        The options to create the online mfcc extractor.
+      wave:
+        The input 1-D waveform.
+      cpu_features:
+        The ground truth features that are computed offline.
+    """
+    online_mfcc = kaldifeat.OnlineMfcc(opts)
+
+    num_processed_frames = 0
+    i = 0  # current sample index to feed
+    while not online_mfcc.is_last_frame(num_processed_frames - 1):
+        while num_processed_frames < online_mfcc.num_frames_ready:
+            # There are new frames to be processed
+            frame = online_mfcc.get_frame(num_processed_frames)
+            assert torch.allclose(
+                frame.squeeze(0), cpu_features[num_processed_frames], atol=1e-3
+            )
+            num_processed_frames += 1
+
+        # Simulate streaming. Send a random number of audio samples
+        # to the extractor.
+        num_samples = torch.randint(300, 1000, (1,)).item()
+
+        samples = wave[i : (i + num_samples)]  # noqa
+        i += num_samples
+        if len(samples) == 0:
+            online_mfcc.input_finished()
+            continue
+
+        online_mfcc.accept_waveform(16000, samples)
+
+    assert num_processed_frames == online_mfcc.num_frames_ready
+    assert num_processed_frames == cpu_features.size(0)
+
+
 def test_mfcc_default():
     print("=====test_mfcc_default=====")
+    filename = cur_dir / "test_data/test.wav"
+    wave = read_wave(filename)
+    gt = read_ark_txt(cur_dir / "test_data/test-mfcc.txt")
+
+    cpu_features = None
     for device in get_devices():
         print("device", device)
         opts = kaldifeat.MfccOptions()
         opts.device = device
         opts.frame_opts.dither = 0
 
         mfcc = kaldifeat.Mfcc(opts)
-        filename = cur_dir / "test_data/test.wav"
-        wave = read_wave(filename).to(device)
-        features = mfcc(wave)
-        gt = read_ark_txt(cur_dir / "test_data/test-mfcc.txt")
+        features = mfcc(wave.to(device))
+        if device.type == "cpu":
+            cpu_features = features
+
         assert torch.allclose(features.cpu(), gt, atol=1e-1)
 
+    opts = kaldifeat.MfccOptions()
+    opts.frame_opts.dither = 0
+
+    test_online_mfcc(opts, wave, cpu_features)
+
 
 def test_mfcc_no_snip_edges():
     print("=====test_mfcc_no_snip_edges=====")
+    filename = cur_dir / "test_data/test.wav"
+    wave = read_wave(filename)
+    gt = read_ark_txt(cur_dir / "test_data/test-mfcc-no-snip-edges.txt")
+
+    cpu_features = None
     for device in get_devices():
         print("device", device)
         opts = kaldifeat.MfccOptions()
@@ -39,13 +97,19 @@ def test_mfcc_no_snip_edges():
         opts.frame_opts.snip_edges = False
 
         mfcc = kaldifeat.Mfcc(opts)
-        filename = cur_dir / "test_data/test.wav"
-        wave = read_wave(filename).to(device)
-        features = mfcc(wave)
-        gt = read_ark_txt(cur_dir / "test_data/test-mfcc-no-snip-edges.txt")
+        features = mfcc(wave.to(device))
+        if device.type == "cpu":
+            cpu_features = features
+
         assert torch.allclose(features.cpu(), gt, rtol=1e-1)
 
+    opts = kaldifeat.MfccOptions()
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+
+    test_online_mfcc(opts, wave, cpu_features)
+
 
 def test_pickle():
     for device in get_devices():
@@ -60,6 +124,16 @@ def test_pickle():
 
         assert str(mfcc.opts) == str(mfcc2.opts)
 
+    opts = kaldifeat.MfccOptions()
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+
+    mfcc = kaldifeat.OnlineMfcc(opts)
+    data = pickle.dumps(mfcc)
+    mfcc2 = pickle.loads(data)
+
+    assert str(mfcc.opts) == str(mfcc2.opts)
+
 
 if __name__ == "__main__":
     test_mfcc_default()
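
For reference, a minimal usage sketch of the streaming OnlineMfcc API exercised by test_online_mfcc above. The 16 kHz sample rate matches the test; the chunk size and the random tensor standing in for real audio are illustrative assumptions, not part of this patch:

    import torch

    import kaldifeat

    opts = kaldifeat.MfccOptions()
    opts.frame_opts.dither = 0

    online_mfcc = kaldifeat.OnlineMfcc(opts)

    # Stand-in for real 16 kHz mono audio; any 1-D float32 tensor works.
    wave = torch.rand(16000)

    # Feed the waveform in fixed-size chunks, as if it arrived from a stream.
    for chunk in torch.split(wave, 1600):
        online_mfcc.accept_waveform(16000, chunk)
    online_mfcc.input_finished()  # signal that no more audio will arrive

    # Collect every frame that is ready; get_frame() returns a (1, dim) tensor.
    frames = [
        online_mfcc.get_frame(i) for i in range(online_mfcc.num_frames_ready)
    ]
    features = torch.cat(frames, dim=0)  # shape: (num_frames, dim)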