mirror of
https://github.com/csukuangfj/kaldifeat.git
synced 2025-08-09 01:52:39 +00:00
142 lines
3.8 KiB
Python
Executable File
142 lines
3.8 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
|
|
# Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
|
|
|
import pickle
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
from utils import get_devices, read_ark_txt, read_wave
|
|
|
|
import kaldifeat
|
|
|
|
cur_dir = Path(__file__).resolve().parent
|
|
|
|
|
|
def test_online_mfcc(
|
|
opts: kaldifeat.MfccOptions,
|
|
wave: torch.Tensor,
|
|
cpu_features: torch.Tensor,
|
|
):
|
|
"""
|
|
Args:
|
|
opts:
|
|
The options to create the online mfcc extractor.
|
|
wave:
|
|
The input 1-D waveform.
|
|
cpu_features:
|
|
The groud truth features that are computed offline
|
|
"""
|
|
online_mfcc = kaldifeat.OnlineMfcc(opts)
|
|
|
|
num_processed_frames = 0
|
|
i = 0 # current sample index to feed
|
|
while not online_mfcc.is_last_frame(num_processed_frames - 1):
|
|
while num_processed_frames < online_mfcc.num_frames_ready:
|
|
# There are new frames to be processed
|
|
frame = online_mfcc.get_frame(num_processed_frames)
|
|
assert torch.allclose(
|
|
frame.squeeze(0), cpu_features[num_processed_frames], atol=1e-3
|
|
)
|
|
num_processed_frames += 1
|
|
|
|
# Simulate streaming . Send a random number of audio samples
|
|
# to the extractor
|
|
num_samples = torch.randint(300, 1000, (1,)).item()
|
|
|
|
samples = wave[i : (i + num_samples)] # noqa
|
|
i += num_samples
|
|
if len(samples) == 0:
|
|
online_mfcc.input_finished()
|
|
continue
|
|
|
|
online_mfcc.accept_waveform(16000, samples)
|
|
|
|
assert num_processed_frames == online_mfcc.num_frames_ready
|
|
assert num_processed_frames == cpu_features.size(0)
|
|
|
|
|
|
def test_mfcc_default():
|
|
print("=====test_mfcc_default=====")
|
|
filename = cur_dir / "test_data/test.wav"
|
|
wave = read_wave(filename)
|
|
gt = read_ark_txt(cur_dir / "test_data/test-mfcc.txt")
|
|
|
|
cpu_features = None
|
|
for device in get_devices():
|
|
print("device", device)
|
|
opts = kaldifeat.MfccOptions()
|
|
opts.device = device
|
|
opts.frame_opts.dither = 0
|
|
mfcc = kaldifeat.Mfcc(opts)
|
|
|
|
features = mfcc(wave.to(device))
|
|
if device.type == "cpu":
|
|
cpu_features = features
|
|
|
|
assert torch.allclose(features.cpu(), gt, atol=1e-1)
|
|
|
|
opts = kaldifeat.MfccOptions()
|
|
opts.frame_opts.dither = 0
|
|
|
|
test_online_mfcc(opts, wave, cpu_features)
|
|
|
|
|
|
def test_mfcc_no_snip_edges():
|
|
print("=====test_mfcc_no_snip_edges=====")
|
|
filename = cur_dir / "test_data/test.wav"
|
|
wave = read_wave(filename)
|
|
gt = read_ark_txt(cur_dir / "test_data/test-mfcc-no-snip-edges.txt")
|
|
|
|
cpu_features = None
|
|
for device in get_devices():
|
|
print("device", device)
|
|
opts = kaldifeat.MfccOptions()
|
|
opts.device = device
|
|
opts.frame_opts.dither = 0
|
|
opts.frame_opts.snip_edges = False
|
|
|
|
mfcc = kaldifeat.Mfcc(opts)
|
|
|
|
features = mfcc(wave.to(device))
|
|
if device.type == "cpu":
|
|
cpu_features = features
|
|
|
|
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
|
|
|
opts = kaldifeat.MfccOptions()
|
|
opts.frame_opts.dither = 0
|
|
opts.frame_opts.snip_edges = False
|
|
|
|
test_online_mfcc(opts, wave, cpu_features)
|
|
|
|
|
|
def test_pickle():
|
|
for device in get_devices():
|
|
opts = kaldifeat.MfccOptions()
|
|
opts.device = device
|
|
opts.frame_opts.dither = 0
|
|
opts.frame_opts.snip_edges = False
|
|
|
|
mfcc = kaldifeat.Mfcc(opts)
|
|
data = pickle.dumps(mfcc)
|
|
mfcc2 = pickle.loads(data)
|
|
|
|
assert str(mfcc.opts) == str(mfcc2.opts)
|
|
|
|
opts = kaldifeat.MfccOptions()
|
|
opts.frame_opts.dither = 0
|
|
opts.frame_opts.snip_edges = False
|
|
|
|
mfcc = kaldifeat.OnlineMfcc(opts)
|
|
data = pickle.dumps(mfcc)
|
|
mfcc2 = pickle.loads(data)
|
|
|
|
assert str(mfcc.opts) == str(mfcc2.opts)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
test_mfcc_default()
|
|
test_mfcc_no_snip_edges()
|
|
test_pickle()
|