Add kaldifeat.Fbank

This commit is contained in:
Fangjun Kuang 2021-07-16 17:35:23 +08:00
parent be8442fd68
commit e043afe3d6
13 changed files with 207 additions and 30 deletions

View File

@ -1,5 +1,9 @@
[flake8]
max-line-length = 80
exclude =
.git,
kaldifeat/python/kaldifeat/__init__.py
ignore =
E402

1
.gitignore vendored
View File

@ -2,3 +2,4 @@ build/
build*/
*.egg-info*/
dist/
__pycache__/

View File

@ -9,6 +9,12 @@ set(kaldifeat_srcs
add_library(kaldifeat_core SHARED ${kaldifeat_srcs})
target_link_libraries(kaldifeat_core PUBLIC ${TORCH_LIBRARIES})
# PYTHON_INCLUDE_DIRS is set by pybind11
target_include_directories(kaldifeat_core PUBLIC ${PYTHON_INCLUDE_DIRS})
# PYTHON_LIBRARY is set by pybind11
target_link_libraries(kaldifeat_core PUBLIC ${PYTHON_LIBRARY})
add_executable(test_kaldifeat test_kaldifeat.cc)
target_link_libraries(test_kaldifeat PRIVATE kaldifeat_core)

View File

@ -8,12 +8,16 @@
#define KALDIFEAT_CSRC_FEATURE_FBANK_H_
#include <map>
#include <string>
#include "kaldifeat/csrc/feature-common.h"
#include "kaldifeat/csrc/feature-window.h"
#include "kaldifeat/csrc/mel-computations.h"
#include "pybind11/pybind11.h"
#include "torch/torch.h"
namespace py = pybind11;
namespace kaldifeat {
struct FbankOptions {
@ -42,6 +46,16 @@ struct FbankOptions {
FbankOptions() : device("cpu") { mel_opts.num_bins = 23; }
// Get/Set methods are for implementing properties in Python
py::object GetDevice() const {
py::object ans = py::module_::import("torch").attr("device");
return ans(device.str());
}
void SetDevice(py::object obj) {
std::string s = static_cast<py::str>(obj);
device = torch::Device(s);
}
std::string ToString() const {
std::ostringstream os;
os << "frame_opts: \n";

View File

@ -19,22 +19,8 @@ void PybindFbankOptions(py::module &m) {
.def_readwrite("htk_compat", &FbankOptions::htk_compat)
.def_readwrite("use_log_fbank", &FbankOptions::use_log_fbank)
.def_readwrite("use_power", &FbankOptions::use_power)
.def("set_device",
[](FbankOptions *fbank_opts, py::object device) {
std::string device_type =
static_cast<py::str>(device.attr("type"));
KALDIFEAT_ASSERT(device_type == "cpu" || device_type == "cuda")
<< "Unsupported device type: " << device_type;
auto index_attr = static_cast<py::object>(device.attr("index"));
int32_t device_index = 0;
if (!index_attr.is_none())
device_index = static_cast<py::int_>(index_attr);
if (device_type == "cpu")
fbank_opts->device = torch::Device("cpu");
else
fbank_opts->device = torch::Device(torch::kCUDA, device_index);
})
.def_property("device", &FbankOptions::GetDevice,
&FbankOptions::SetDevice)
.def("__str__", [](const FbankOptions &self) -> std::string {
return self.ToString();
});

View File

@ -27,7 +27,7 @@ PYBIND11_MODULE(_kaldifeat, m) {
PybindMelBanksOptions(m);
PybindFbankOptions(m);
m.def("compute", &Compute, py::arg("wave"), py::arg("fbank"));
m.def("compute_fbank_feats", &Compute, py::arg("wave"), py::arg("fbank"));
// It verifies that the reimplementation produces the same output
// as kaldi using default parameters with dither disabled.

View File

@ -0,0 +1,3 @@
from _kaldifeat import FbankOptions, FrameExtractionOptions, MelBanksOptions
from .fbank import Fbank

View File

@ -0,0 +1,82 @@
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
from typing import List, Union
import _kaldifeat
import torch
import torch.nn as nn
class Fbank(nn.Module):
def __init__(self, opts: _kaldifeat.FbankOptions):
super().__init__()
self.opts = opts
self.computer = _kaldifeat.Fbank(opts)
def forward(
self, waves: Union[torch.Tensor, List[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Compute the fbank features of a single waveform or
a list of waveforms.
Args:
waves:
A single 1-D tensor or a list of 1-D tensors. Each tensor contains
audio samples of a soundfile. To get a result compatible with Kaldi,
you should scale the samples to [-32768, 32767] before calling this
function. Note: You are not required to scale them if you don't care
about the compatibility with Kaldi.
Returns:
Return a list of 2-D tensors containing the fbank features if the
input is a list of 1-D tensors. The returned list has as many elements
as the input list.
Return a single 2-D tensor if the input is a single tensor.
"""
if isinstance(waves, list):
is_list = True
else:
waves = [waves]
is_list = False
num_frames_per_wave = [
_kaldifeat.num_frames(w.numel(), self.opts.frame_opts)
for w in waves
]
strided = [self.convert_samples_to_frames(w) for w in waves]
strided = torch.cat(strided, dim=0)
features = self.compute(strided)
if is_list:
return list(features.split(num_frames_per_wave))
else:
return features
def compute(self, x: torch.Tensor) -> torch.Tensor:
"""Compute fbank features given a 2-D tensor containing
frames data. Each row is a frame of size frame_lens, specified
in the fbank options.
Args:
x:
A 2-D tensor.
Returns:
Return a 2-D tensor with as many rows as the input tensor. Its
number of columns is the number mel bins.
"""
features = _kaldifeat.compute_fbank_feats(x, self.computer)
return features
def convert_samples_to_frames(self, wave: torch.Tensor) -> torch.Tensor:
"""Convert a 1-D tensor containing audio samples to a 2-D
tensor where each row is a frame of samples of size frame length
specified in the fbank options.
Args:
waves:
A 1-D tensor.
Returns:
Return a 2-D tensor.
"""
return _kaldifeat.get_strided(wave, self.opts.frame_opts)

View File

@ -11,6 +11,12 @@ if [ ! -f test.wav ]; then
sox -n -r 16000 -b 16 test.wav synth 1.2 sine 300-3300
fi
if [ ! -f test2.wav ]; then
# generate a wav of 0.5 seconds, containing a sine-wave
# swept from 300 Hz to 3300 Hz
sox -n -r 16000 -b 16 test2.wav synth 0.5 sine 300-3300
fi
echo "1 test.wav" > test.scp
# We disable dither for testing

Binary file not shown.

View File

@ -0,0 +1,72 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang)
import numpy as np
import soundfile as sf
import torch
import kaldifeat
def read_wave(filename) -> torch.Tensor:
"""Read a wave file and return it as a 1-D tensor.
Note:
You don't need to scale it to [-32768, 32767].
We use scaling here to follow the approach in Kaldi.
Args:
filename:
Filename of a sound file.
Returns:
Return a 1-D tensor containing audio samples.
"""
with sf.SoundFile(filename) as sf_desc:
sampling_rate = sf_desc.samplerate
assert sampling_rate == 16000
data = sf_desc.read(dtype=np.float32, always_2d=False)
data *= 32768
return torch.from_numpy(data)
def test_fbank():
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
wave0 = read_wave("test_data/test.wav")
wave1 = read_wave("test_data/test2.wav")
wave0 = wave0.to(device)
wave1 = wave1.to(device)
opts = kaldifeat.FbankOptions()
opts.frame_opts.dither = 0
opts.device = device
fbank = kaldifeat.Fbank(opts)
# We can compute fbank features in batches
features = fbank([wave0, wave1])
assert isinstance(features, list), f"{type(features)}"
assert len(features) == 2
# We can also compute fbank features for a single wave
features0 = fbank(wave0)
features1 = fbank(wave1)
assert torch.allclose(features[0], features0)
assert torch.allclose(features[1], features1)
# To compute fbank features for only a specified frame
audio_frames = fbank.convert_samples_to_frames(wave0)
feature_frame_1 = fbank.compute(audio_frames[1])
feature_frame_10 = fbank.compute(audio_frames[10])
assert torch.allclose(features0[1], feature_frame_1)
assert torch.allclose(features0[10], feature_frame_10)
if __name__ == "__main__":
test_fbank()

View File

@ -52,7 +52,7 @@ def test_and_benchmark_default_parameters():
for device in devices:
fbank_opts = _kaldifeat.FbankOptions()
fbank_opts.frame_opts.dither = 0
fbank_opts.set_device(device)
fbank_opts.device = device
fbank = _kaldifeat.Fbank(fbank_opts)
data = read_wave().to(device)
@ -74,14 +74,14 @@ def test_use_energy_htk_compat_true():
for device in devices:
fbank_opts = _kaldifeat.FbankOptions()
fbank_opts.frame_opts.dither = 0
fbank_opts.set_device(device)
fbank_opts.device = device
fbank_opts.use_energy = True
fbank_opts.htk_compat = True
fbank = _kaldifeat.Fbank(fbank_opts)
data = read_wave().to(device)
ans = _kaldifeat.compute(data, fbank)
ans = _kaldifeat.compute_fbank_feats(data, fbank)
expected = read_ark_txt("test-htk.txt")
assert torch.allclose(ans.cpu(), expected, rtol=1e-2)
@ -97,12 +97,12 @@ def test_use_energy_htk_compat_false():
fbank_opts.frame_opts.dither = 0
fbank_opts.use_energy = True
fbank_opts.htk_compat = False
fbank_opts.set_device(device)
fbank_opts.device = device
fbank = _kaldifeat.Fbank(fbank_opts)
data = read_wave().to(device)
ans = _kaldifeat.compute(data, fbank)
ans = _kaldifeat.compute_fbank_feats(data, fbank)
expected = read_ark_txt("test-with-energy.txt")
assert torch.allclose(ans.cpu(), expected, rtol=1e-2)
@ -117,12 +117,12 @@ def test_40_mel():
fbank_opts = _kaldifeat.FbankOptions()
fbank_opts.frame_opts.dither = 0
fbank_opts.mel_opts.num_bins = 40
fbank_opts.set_device(device)
fbank_opts.device = device
fbank = _kaldifeat.Fbank(fbank_opts)
data = read_wave().to(device)
ans = _kaldifeat.compute(data, fbank)
ans = _kaldifeat.compute_fbank_feats(data, fbank)
expected = read_ark_txt("test-40.txt")
assert torch.allclose(ans.cpu(), expected, rtol=1e-1)
@ -138,12 +138,12 @@ def test_40_mel_no_snip_edges():
fbank_opts.frame_opts.snip_edges = False
fbank_opts.frame_opts.dither = 0
fbank_opts.mel_opts.num_bins = 40
fbank_opts.set_device(device)
fbank_opts.device = device
fbank = _kaldifeat.Fbank(fbank_opts)
data = read_wave().to(device)
ans = _kaldifeat.compute(data, fbank)
ans = _kaldifeat.compute_fbank_feats(data, fbank)
expected = read_ark_txt("test-40-no-snip-edges.txt")
assert torch.allclose(ans.cpu(), expected, rtol=1e-2)
@ -161,7 +161,7 @@ def test_compute_batch():
fbank_opts = _kaldifeat.FbankOptions()
fbank_opts.frame_opts.dither = 0
fbank_opts.frame_opts.snip_edges = False
fbank_opts.set_device(device)
fbank_opts.device = device
fbank = _kaldifeat.Fbank(fbank_opts)
def impl(waves: List[torch.Tensor]) -> List[torch.Tensor]:
@ -175,7 +175,9 @@ def test_compute_batch():
]
strided = torch.cat(strided, dim=0)
features = _kaldifeat.compute(strided, fbank).split(num_frames)
features = _kaldifeat.compute_fbank_feats(strided, fbank).split(
num_frames
)
return features

View File

@ -52,6 +52,7 @@ def test_fbank_options():
opts.use_energy = False
opts.use_log_fbank = True
opts.use_power = True
opts.device = "cuda:0"
frame_opts.blackman_coeff = 0.42
frame_opts.dither = 1
@ -75,8 +76,8 @@ def test_fbank_options():
def main():
# test_frame_extraction_options()
# test_mel_banks_options()
test_frame_extraction_options()
test_mel_banks_options()
test_fbank_options()