2022-04-02 20:03:42 +08:00

337 lines
9.6 KiB
Python
Executable File

#!/usr/bin/env python3
# Copyright 2021-2022 Xiaomi Corporation (authors: Fangjun Kuang)
import pickle
from pathlib import Path
import torch
from utils import get_devices, read_ark_txt, read_wave
import kaldifeat
cur_dir = Path(__file__).resolve().parent
def test_online_fbank(
opts: kaldifeat.FbankOptions,
wave: torch.Tensor,
cpu_features: torch.Tensor,
):
"""
Args:
opts:
The options to create the online fbank extractor.
wave:
The input 1-D waveform.
cpu_features:
The groud truth features that are computed offline
"""
online_fbank = kaldifeat.OnlineFbank(opts)
num_processed_frames = 0
i = 0 # current sample index to feed
while not online_fbank.is_last_frame(num_processed_frames - 1):
while num_processed_frames < online_fbank.num_frames_ready:
# There are new frames to be processed
frame = online_fbank.get_frame(num_processed_frames)
assert torch.allclose(
frame.squeeze(0), cpu_features[num_processed_frames]
)
num_processed_frames += 1
# Simulate streaming . Send a random number of audio samples
# to the extractor
num_samples = torch.randint(300, 1000, (1,)).item()
samples = wave[i : (i + num_samples)] # noqa
i += num_samples
if len(samples) == 0:
online_fbank.input_finished()
continue
online_fbank.accept_waveform(16000, samples)
assert num_processed_frames == online_fbank.num_frames_ready
assert num_processed_frames == cpu_features.size(0)
def test_fbank_default():
print("=====test_fbank_default=====")
filename = cur_dir / "test_data/test.wav"
wave = read_wave(filename)
gt = read_ark_txt(cur_dir / "test_data/test.txt")
cpu_features = None
for device in get_devices():
print("device", device)
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
fbank = kaldifeat.Fbank(opts)
features = fbank(wave)
assert features.device.type == "cpu"
assert torch.allclose(features, gt, rtol=1e-1)
if cpu_features is None:
cpu_features = features
features = fbank(wave.to(device))
assert features.device == device
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
# Now for online fbank
opts = kaldifeat.FbankOptions()
opts.frame_opts.dither = 0
opts.frame_opts.max_feature_vectors = 100
test_online_fbank(opts, wave, cpu_features)
def test_fbank_htk():
print("=====test_fbank_htk=====")
filename = cur_dir / "test_data/test.wav"
wave = read_wave(filename)
gt = read_ark_txt(cur_dir / "test_data/test-htk.txt")
cpu_features = None
for device in get_devices():
print("device", device)
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.use_energy = True
opts.htk_compat = True
fbank = kaldifeat.Fbank(opts)
features = fbank(wave)
assert features.device.type == "cpu"
assert torch.allclose(features, gt, rtol=1e-1)
if cpu_features is None:
cpu_features = features
features = fbank(wave.to(device))
assert features.device == device
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
opts = kaldifeat.FbankOptions()
opts.frame_opts.dither = 0
opts.use_energy = True
opts.htk_compat = True
test_online_fbank(opts, wave, cpu_features)
def test_fbank_with_energy():
print("=====test_fbank_with_energy=====")
filename = cur_dir / "test_data/test.wav"
wave = read_wave(filename)
gt = read_ark_txt(cur_dir / "test_data/test-with-energy.txt")
cpu_features = None
for device in get_devices():
print("device", device)
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.use_energy = True
fbank = kaldifeat.Fbank(opts)
features = fbank(wave)
assert torch.allclose(features, gt, rtol=1e-1)
assert features.device.type == "cpu"
if cpu_features is None:
cpu_features = features
features = fbank(wave.to(device))
assert features.device == device
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
opts = kaldifeat.FbankOptions()
opts.frame_opts.dither = 0
opts.use_energy = True
test_online_fbank(opts, wave, cpu_features)
def test_fbank_40_bins():
print("=====test_fbank_40_bins=====")
filename = cur_dir / "test_data/test.wav"
wave = read_wave(filename)
gt = read_ark_txt(cur_dir / "test_data/test-40.txt")
cpu_features = None
for device in get_devices():
print("device", device)
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.mel_opts.num_bins = 40
fbank = kaldifeat.Fbank(opts)
features = fbank(wave)
assert features.device.type == "cpu"
assert torch.allclose(features, gt, rtol=1e-1)
if cpu_features is None:
cpu_features = features
features = fbank(wave.to(device))
assert features.device == device
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
opts = kaldifeat.FbankOptions()
opts.frame_opts.dither = 0
opts.mel_opts.num_bins = 40
test_online_fbank(opts, wave, cpu_features)
def test_fbank_40_bins_no_snip_edges():
print("=====test_fbank_40_bins_no_snip_edges=====")
filename = cur_dir / "test_data/test.wav"
wave = read_wave(filename)
gt = read_ark_txt(cur_dir / "test_data/test-40-no-snip-edges.txt")
cpu_features = None
for device in get_devices():
print("device", device)
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.mel_opts.num_bins = 40
opts.frame_opts.snip_edges = False
fbank = kaldifeat.Fbank(opts)
features = fbank(wave)
assert features.device.type == "cpu"
assert torch.allclose(features, gt, rtol=1e-1)
if cpu_features is None:
cpu_features = features
features = fbank(wave.to(device))
assert features.device == device
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
opts = kaldifeat.FbankOptions()
opts.frame_opts.dither = 0
opts.mel_opts.num_bins = 40
opts.frame_opts.snip_edges = False
test_online_fbank(opts, wave, cpu_features)
def test_fbank_chunk():
print("=====test_fbank_chunk=====")
filename = cur_dir / "test_data/test-1hour.wav"
if filename.is_file() is False:
print(
f"Please execute {cur_dir}/test_data/run.sh "
f"to generate {filename} before running tis test"
)
return
for device in get_devices():
print("device", device)
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.mel_opts.num_bins = 40
opts.frame_opts.snip_edges = False
fbank = kaldifeat.Fbank(opts)
wave = read_wave(filename)
# You can use
#
# $ watch -n 0.2 free -m
#
# to view memory consumption
#
# 100 frames per chunk
features1 = fbank(wave, chunk_size=100 * 10)
features2 = fbank(wave)
assert torch.allclose(features1, features2)
assert features1.device == features2.device
assert features1.device.type == "cpu"
if device.type == "cuda":
wave = wave.to(device)
features1 = fbank(wave, chunk_size=100 * 10)
features2 = fbank(wave)
assert torch.allclose(features1, features2)
assert features1.device == features2.device
assert features1.device == device
def test_fbank_batch():
print("=====test_fbank_batch=====")
for device in get_devices():
print("device", device)
wave0 = read_wave(cur_dir / "test_data/test.wav")
wave1 = read_wave(cur_dir / "test_data/test2.wav")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
fbank = kaldifeat.Fbank(opts)
features = fbank([wave0, wave1], chunk_size=10)
features0 = fbank(wave0)
features1 = fbank(wave1)
assert torch.allclose(features[0], features0)
assert torch.allclose(features[1], features1)
if device.type == "cuda":
wave0 = wave0.to(device)
wave1 = wave1.to(device)
features = fbank([wave0, wave1], chunk_size=10)
features0 = fbank(wave0)
features1 = fbank(wave1)
assert torch.allclose(features[0], features0)
assert torch.allclose(features[1], features1)
def test_pickle():
for device in get_devices():
opts = kaldifeat.FbankOptions()
opts.use_energy = True
opts.use_power = False
opts.device = device
fbank = kaldifeat.Fbank(opts)
data = pickle.dumps(fbank)
fbank2 = pickle.loads(data)
assert str(fbank.opts) == str(fbank2.opts)
opts = kaldifeat.FbankOptions()
opts.use_energy = True
opts.use_power = False
fbank = kaldifeat.OnlineFbank(opts)
data = pickle.dumps(fbank)
fbank2 = pickle.loads(data)
assert str(fbank.opts) == str(fbank2.opts)
if __name__ == "__main__":
test_fbank_default()
test_fbank_htk()
test_fbank_with_energy()
test_fbank_40_bins()
test_fbank_40_bins_no_snip_edges()
test_fbank_chunk()
test_fbank_batch()
test_pickle()