Merge pull request #1 from csukuangfj/fbank

Add `kaldifeat.Fbank`.
This commit is contained in:
Fangjun Kuang 2021-07-16 18:42:15 +08:00 committed by GitHub
commit 40a9906ad2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 323 additions and 37 deletions

View File

@ -1,5 +1,11 @@
[flake8]
max-line-length = 80
exclude =
.git,
build,
build_release,
kaldifeat/python/kaldifeat/__init__.py
ignore =
E402

1
.gitignore vendored
View File

@ -2,3 +2,4 @@ build/
build*/
*.egg-info*/
dist/
__pycache__/

View File

@ -4,7 +4,7 @@ cmake_minimum_required(VERSION 3.8 FATAL_ERROR)
project(kaldifeat)
set(kaldifeat_VERSION "0.0.1")
set(kaldifeat_VERSION "1.0")
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")

0
cmake/__init__.py Normal file
View File

83
cmake/cmake_extension.py Normal file
View File

@ -0,0 +1,83 @@
# Copyright (c) 2021 Xiaomi Corporation (author: Fangjun Kuang)
import glob
import os
import shutil
import sys
from pathlib import Path
import setuptools
from setuptools.command.build_ext import build_ext
try:
    from wheel.bdist_wheel import bdist_wheel as _bdist_wheel

    class bdist_wheel(_bdist_wheel):
        """``bdist_wheel`` command that marks the wheel as not pure Python."""

        def finalize_options(self):
            super().finalize_options()
            # With root_is_pure = True the generated wheel would have a
            # name of the form
            #   <package>-xxx-pyxx-none-any.whl
            # With root_is_pure = False the generated wheel has a name
            # ending with a platform tag, e.g.
            #   -linux_x86_64.whl
            self.root_is_pure = False


except ImportError:
    # The import above failed; expose None so that importing this name
    # from this module still works.
    bdist_wheel = None
def cmake_extension(name, *args, **kwargs) -> setuptools.Extension:
    """Create a ``setuptools.Extension`` that has no source files.

    Args:
      name:
        Name of the extension.
      args:
        Extra positional arguments forwarded to ``setuptools.Extension``.
      kwargs:
        Extra keyword arguments forwarded to ``setuptools.Extension``.
        The ``language`` entry is always set to ``"c++"``, replacing any
        value given by the caller.
    Returns:
      Return a ``setuptools.Extension`` with an empty list of sources.
    """
    kwargs.update(language="c++")
    return setuptools.Extension(name, [], *args, **kwargs)
class BuildExtension(build_ext):
    """``build_ext`` command that builds ``_kaldifeat`` with cmake and make."""

    def build_extension(self, ext: setuptools.extension.Extension):
        """Configure and build ``_kaldifeat``, then copy the result.

        The following environment variables are read:

          - ``KALDIFEAT_CMAKE_ARGS``: arguments passed to ``cmake``.
            Defaults to ``-DCMAKE_BUILD_TYPE=Release`` if empty or unset.
          - ``KALDIFEAT_MAKE_ARGS``: arguments passed to ``make``.
          - ``MAKEFLAGS``: only checked to decide whether to print a hint
            about parallel compilation.

        Args:
          ext:
            The extension being built. It is not used; the cmake target
            ``_kaldifeat`` is always built.
        Raises:
          Exception: If the build command exits with a non-zero status.
        """
        # Imported locally since it is needed only to quote paths that
        # are interpolated into the shell command below.
        from shlex import quote

        # build/temp.linux-x86_64-3.8
        os.makedirs(self.build_temp, exist_ok=True)

        # build/lib.linux-x86_64-3.8
        os.makedirs(self.build_lib, exist_ok=True)

        kaldifeat_dir = Path(__file__).parent.parent.resolve()

        cmake_args = os.environ.get("KALDIFEAT_CMAKE_ARGS", "")
        make_args = os.environ.get("KALDIFEAT_MAKE_ARGS", "")
        system_make_args = os.environ.get("MAKEFLAGS", "")

        if cmake_args == "":
            cmake_args = "-DCMAKE_BUILD_TYPE=Release"

        if make_args == "" and system_make_args == "":
            print("For fast compilation, run:")
            print('export KALDIFEAT_MAKE_ARGS="-j"; python setup.py install')

        if "PYTHON_EXECUTABLE" not in cmake_args:
            print(f"Setting PYTHON_EXECUTABLE to {sys.executable}")
            cmake_args += f" -DPYTHON_EXECUTABLE={quote(sys.executable)}"

        # The steps are chained with `&&` so that a failing `cd` or `cmake`
        # stops the build. With one command per line and no `&&`, the shell
        # would report only the exit status of the last command (`make`).
        build_steps = [
            f"cd {quote(str(self.build_temp))}",
            f"cmake {cmake_args} {quote(str(kaldifeat_dir))}",
            f"make {make_args} _kaldifeat",
        ]
        build_cmd = " &&\n".join(build_steps)
        print(f"build command is:\n{build_cmd}")

        ret = os.system(build_cmd)
        if ret != 0:
            raise Exception(
                "\nBuild kaldifeat failed. Please check the error message.\n"
                "You can ask for help by creating an issue on GitHub.\n"
                "\nClick:\n\thttps://github.com/csukuangfj/kaldifeat/issues/new\n"  # noqa
            )

        # Copy the built shared libraries next to the Python package.
        lib_so = glob.glob(f"{self.build_temp}/lib/*kaldifeat*.so")
        for so in lib_so:
            print(f"Copying {so} to {self.build_lib}/")
            shutil.copy(f"{so}", f"{self.build_lib}/")

View File

@ -9,6 +9,12 @@ set(kaldifeat_srcs
add_library(kaldifeat_core SHARED ${kaldifeat_srcs})
target_link_libraries(kaldifeat_core PUBLIC ${TORCH_LIBRARIES})
# PYTHON_INCLUDE_DIRS is set by pybind11
target_include_directories(kaldifeat_core PUBLIC ${PYTHON_INCLUDE_DIRS})
# PYTHON_LIBRARY is set by pybind11
target_link_libraries(kaldifeat_core PUBLIC ${PYTHON_LIBRARY})
add_executable(test_kaldifeat test_kaldifeat.cc)
target_link_libraries(test_kaldifeat PRIVATE kaldifeat_core)

View File

@ -8,12 +8,16 @@
#define KALDIFEAT_CSRC_FEATURE_FBANK_H_
#include <map>
#include <string>
#include "kaldifeat/csrc/feature-common.h"
#include "kaldifeat/csrc/feature-window.h"
#include "kaldifeat/csrc/mel-computations.h"
#include "pybind11/pybind11.h"
#include "torch/torch.h"
namespace py = pybind11;
namespace kaldifeat {
struct FbankOptions {
@ -42,6 +46,16 @@ struct FbankOptions {
FbankOptions() : device("cpu") { mel_opts.num_bins = 23; }
// Get/Set methods are for implementing properties in Python
// Return the current device as a Python object created by calling
// `torch.device` with the string form of `device`.
py::object GetDevice() const {
  py::module_ torch_module = py::module_::import("torch");
  py::object device_ctor = torch_module.attr("device");
  return device_ctor(device.str());
}
// Set `device` from a Python object: `obj` is converted to a Python
// string, which is then used to construct a torch::Device.
void SetDevice(py::object obj) {
  const std::string device_str = py::str(obj);
  device = torch::Device(device_str);
}
std::string ToString() const {
std::ostringstream os;
os << "frame_opts: \n";

View File

@ -19,22 +19,8 @@ void PybindFbankOptions(py::module &m) {
.def_readwrite("htk_compat", &FbankOptions::htk_compat)
.def_readwrite("use_log_fbank", &FbankOptions::use_log_fbank)
.def_readwrite("use_power", &FbankOptions::use_power)
.def("set_device",
[](FbankOptions *fbank_opts, py::object device) {
std::string device_type =
static_cast<py::str>(device.attr("type"));
KALDIFEAT_ASSERT(device_type == "cpu" || device_type == "cuda")
<< "Unsupported device type: " << device_type;
auto index_attr = static_cast<py::object>(device.attr("index"));
int32_t device_index = 0;
if (!index_attr.is_none())
device_index = static_cast<py::int_>(index_attr);
if (device_type == "cpu")
fbank_opts->device = torch::Device("cpu");
else
fbank_opts->device = torch::Device(torch::kCUDA, device_index);
})
.def_property("device", &FbankOptions::GetDevice,
&FbankOptions::SetDevice)
.def("__str__", [](const FbankOptions &self) -> std::string {
return self.ToString();
});

View File

@ -27,7 +27,7 @@ PYBIND11_MODULE(_kaldifeat, m) {
PybindMelBanksOptions(m);
PybindFbankOptions(m);
m.def("compute", &Compute, py::arg("wave"), py::arg("fbank"));
m.def("compute_fbank_feats", &Compute, py::arg("wave"), py::arg("fbank"));
// It verifies that the reimplementation produces the same output
// as kaldi using default parameters with dither disabled.

View File

@ -0,0 +1,3 @@
from _kaldifeat import FbankOptions, FrameExtractionOptions, MelBanksOptions
from .fbank import Fbank

View File

@ -0,0 +1,82 @@
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
from typing import List, Union
import _kaldifeat
import torch
import torch.nn as nn
class Fbank(nn.Module):
    """Compute fbank features using the ``_kaldifeat`` extension module."""

    def __init__(self, opts: _kaldifeat.FbankOptions):
        """
        Args:
          opts:
            Options for computing fbank features. It is kept in
            ``self.opts`` and is also used to construct the underlying
            ``_kaldifeat.Fbank`` computer.
        """
        super().__init__()

        self.opts = opts
        self.computer = _kaldifeat.Fbank(opts)

    def forward(
        self, waves: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Compute the fbank features of a single waveform or
        a list of waveforms.

        Args:
          waves:
            A single 1-D tensor or a list of 1-D tensors. Each tensor contains
            audio samples of a soundfile. To get a result compatible with Kaldi,
            you should scale the samples to [-32768, 32767] before calling this
            function. Note: You are not required to scale them if you don't care
            about the compatibility with Kaldi.
        Returns:
          Return a list of 2-D tensors containing the fbank features if the
          input is a list of 1-D tensors. The returned list has as many elements
          as the input list; in particular, an empty input list gives an
          empty list.
          Return a single 2-D tensor if the input is a single tensor.
        """
        if isinstance(waves, list):
            if len(waves) == 0:
                # torch.cat() below raises on an empty list of tensors, so
                # return early to honor "as many elements as the input".
                return []
            is_list = True
        else:
            waves = [waves]
            is_list = False

        num_frames_per_wave = [
            _kaldifeat.num_frames(w.numel(), self.opts.frame_opts)
            for w in waves
        ]

        # Frames of all waves are stacked so that features are computed
        # with a single call, and split per wave afterwards.
        strided = [self.convert_samples_to_frames(w) for w in waves]
        strided = torch.cat(strided, dim=0)

        features = self.compute(strided)

        if is_list:
            return list(features.split(num_frames_per_wave))
        else:
            return features

    def compute(self, x: torch.Tensor) -> torch.Tensor:
        """Compute fbank features given a 2-D tensor containing
        frames data. Each row is a frame of size frame_lens, specified
        in the fbank options.

        Args:
          x:
            A 2-D tensor.
        Returns:
          Return a 2-D tensor with as many rows as the input tensor. Its
          number of columns is the number of mel bins.
        """
        features = _kaldifeat.compute_fbank_feats(x, self.computer)
        return features

    def convert_samples_to_frames(self, wave: torch.Tensor) -> torch.Tensor:
        """Convert a 1-D tensor containing audio samples to a 2-D
        tensor where each row is a frame of samples of size frame length
        specified in the fbank options.

        Args:
          wave:
            A 1-D tensor.
        Returns:
          Return a 2-D tensor.
        """
        return _kaldifeat.get_strided(wave, self.opts.frame_opts)

View File

@ -11,6 +11,12 @@ if [ ! -f test.wav ]; then
sox -n -r 16000 -b 16 test.wav synth 1.2 sine 300-3300
fi
if [ ! -f test2.wav ]; then
# generate a wav of 0.5 seconds, containing a sine-wave
# swept from 300 Hz to 3300 Hz
sox -n -r 16000 -b 16 test2.wav synth 0.5 sine 300-3300
fi
echo "1 test.wav" > test.scp
# We disable dither for testing

Binary file not shown.

View File

@ -0,0 +1,72 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang)
import numpy as np
import soundfile as sf
import torch
import kaldifeat
def read_wave(filename) -> torch.Tensor:
    """Load the samples of a wave file into a 1-D tensor.

    Note:
      Scaling to [-32768, 32767] is not required. It is done here
      only to follow the approach in Kaldi.

    Args:
      filename:
        Filename of a sound file. Its sampling rate must be 16000.
    Returns:
      Return a 1-D tensor containing audio samples.
    """
    with sf.SoundFile(filename) as wave_file:
        assert wave_file.samplerate == 16000
        samples = wave_file.read(dtype=np.float32, always_2d=False)

    samples *= 32768
    return torch.from_numpy(samples)
def test_fbank():
    """Check that batched, single-wave and per-frame results agree."""
    # Use the first GPU if there is one; otherwise use the CPU
    device = (
        torch.device("cuda", 0)
        if torch.cuda.is_available()
        else torch.device("cpu")
    )

    wave0 = read_wave("test_data/test.wav").to(device)
    wave1 = read_wave("test_data/test2.wav").to(device)

    opts = kaldifeat.FbankOptions()
    opts.frame_opts.dither = 0
    opts.device = device

    fbank = kaldifeat.Fbank(opts)

    # Passing a list of waves computes their features in one batch
    features = fbank([wave0, wave1])
    assert isinstance(features, list), f"{type(features)}"
    assert len(features) == 2

    # Passing a single wave gives the features of that wave only
    features0 = fbank(wave0)
    features1 = fbank(wave1)

    assert torch.allclose(features[0], features0)
    assert torch.allclose(features[1], features1)

    # Features can also be computed for selected frames only
    audio_frames = fbank.convert_samples_to_frames(wave0)
    for frame_index in (1, 10):
        frame_features = fbank.compute(audio_frames[frame_index])
        assert torch.allclose(features0[frame_index], frame_features)
if __name__ == "__main__":
test_fbank()

View File

@ -52,7 +52,7 @@ def test_and_benchmark_default_parameters():
for device in devices:
fbank_opts = _kaldifeat.FbankOptions()
fbank_opts.frame_opts.dither = 0
fbank_opts.set_device(device)
fbank_opts.device = device
fbank = _kaldifeat.Fbank(fbank_opts)
data = read_wave().to(device)
@ -74,14 +74,14 @@ def test_use_energy_htk_compat_true():
for device in devices:
fbank_opts = _kaldifeat.FbankOptions()
fbank_opts.frame_opts.dither = 0
fbank_opts.set_device(device)
fbank_opts.device = device
fbank_opts.use_energy = True
fbank_opts.htk_compat = True
fbank = _kaldifeat.Fbank(fbank_opts)
data = read_wave().to(device)
ans = _kaldifeat.compute(data, fbank)
ans = _kaldifeat.compute_fbank_feats(data, fbank)
expected = read_ark_txt("test-htk.txt")
assert torch.allclose(ans.cpu(), expected, rtol=1e-2)
@ -97,12 +97,12 @@ def test_use_energy_htk_compat_false():
fbank_opts.frame_opts.dither = 0
fbank_opts.use_energy = True
fbank_opts.htk_compat = False
fbank_opts.set_device(device)
fbank_opts.device = device
fbank = _kaldifeat.Fbank(fbank_opts)
data = read_wave().to(device)
ans = _kaldifeat.compute(data, fbank)
ans = _kaldifeat.compute_fbank_feats(data, fbank)
expected = read_ark_txt("test-with-energy.txt")
assert torch.allclose(ans.cpu(), expected, rtol=1e-2)
@ -117,12 +117,12 @@ def test_40_mel():
fbank_opts = _kaldifeat.FbankOptions()
fbank_opts.frame_opts.dither = 0
fbank_opts.mel_opts.num_bins = 40
fbank_opts.set_device(device)
fbank_opts.device = device
fbank = _kaldifeat.Fbank(fbank_opts)
data = read_wave().to(device)
ans = _kaldifeat.compute(data, fbank)
ans = _kaldifeat.compute_fbank_feats(data, fbank)
expected = read_ark_txt("test-40.txt")
assert torch.allclose(ans.cpu(), expected, rtol=1e-1)
@ -138,12 +138,12 @@ def test_40_mel_no_snip_edges():
fbank_opts.frame_opts.snip_edges = False
fbank_opts.frame_opts.dither = 0
fbank_opts.mel_opts.num_bins = 40
fbank_opts.set_device(device)
fbank_opts.device = device
fbank = _kaldifeat.Fbank(fbank_opts)
data = read_wave().to(device)
ans = _kaldifeat.compute(data, fbank)
ans = _kaldifeat.compute_fbank_feats(data, fbank)
expected = read_ark_txt("test-40-no-snip-edges.txt")
assert torch.allclose(ans.cpu(), expected, rtol=1e-2)
@ -161,7 +161,7 @@ def test_compute_batch():
fbank_opts = _kaldifeat.FbankOptions()
fbank_opts.frame_opts.dither = 0
fbank_opts.frame_opts.snip_edges = False
fbank_opts.set_device(device)
fbank_opts.device = device
fbank = _kaldifeat.Fbank(fbank_opts)
def impl(waves: List[torch.Tensor]) -> List[torch.Tensor]:
@ -175,7 +175,9 @@ def test_compute_batch():
]
strided = torch.cat(strided, dim=0)
features = _kaldifeat.compute(strided, fbank).split(num_frames)
features = _kaldifeat.compute_fbank_feats(strided, fbank).split(
num_frames
)
return features

View File

@ -52,6 +52,7 @@ def test_fbank_options():
opts.use_energy = False
opts.use_log_fbank = True
opts.use_power = True
opts.device = "cuda:0"
frame_opts.blackman_coeff = 0.42
frame_opts.dither = 1
@ -75,8 +76,8 @@ def test_fbank_options():
def main():
# test_frame_extraction_options()
# test_mel_banks_options()
test_frame_extraction_options()
test_mel_banks_options()
test_fbank_options()

1
requirements.txt Normal file
View File

@ -0,0 +1 @@
torch

View File

@ -5,6 +5,9 @@
import re
import setuptools
import torch
from cmake.cmake_extension import BuildExtension, bdist_wheel, cmake_extension
def read_long_description():
@ -22,24 +25,35 @@ def get_package_version():
return latest_version
def get_pytorch_version():
    """Return ``torch.__version__`` with everything from the first ``+``
    removed, e.g., ``1.7.1+cuda101`` becomes ``1.7.1``.
    """
    version, _, _ = torch.__version__.partition("+")
    return version
install_requires = [
f"torch=={get_pytorch_version()}",
]
package_name = "kaldifeat"
with open("kaldifeat/python/kaldifeat/__init__.py", "a") as f:
f.write(f"__version__ = '{get_package_version()}'\n")
setuptools.setup(
name=package_name,
version=get_package_version(),
author="Fangjun Kuang",
author_email="csukuangfj@gmail.com",
data_files=[("", ["LICENSE", "README.md"])],
package_dir={
package_name: "kaldifeat/python/kaldifeat",
},
package_dir={package_name: "kaldifeat/python/kaldifeat"},
packages=[package_name],
install_requires=install_requires,
url="https://github.com/csukuangfj/kaldifeat",
long_description=read_long_description(),
long_description_content_type="text/markdown",
# ext_modules=[cmake_extension('_kaldifeat')],
# cmdclass={'build_ext': BuildExtension},
zip_safe=False,
ext_modules=[cmake_extension("_kaldifeat")],
cmdclass={"build_ext": BuildExtension, "bdist_wheel": bdist_wheel},
classifiers=[
"Programming Language :: C++",
"Programming Language :: Python",
@ -52,3 +66,12 @@ setuptools.setup(
python_requires=">=3.6.0",
license="Apache licensed, as found in the LICENSE file",
)
# Drop every line containing __version__ from
# kaldifeat/python/kaldifeat/__init__.py, keeping all other lines as is.
init_py = "kaldifeat/python/kaldifeat/__init__.py"
with open(init_py, "r") as f:
    kept_lines = [line for line in f if "__version__" not in line]

with open(init_py, "w") as f:
    f.writelines(kept_lines)