diff --git a/.flake8 b/.flake8 index 0e88669..8285441 100644 --- a/.flake8 +++ b/.flake8 @@ -1,5 +1,11 @@ [flake8] max-line-length = 80 +exclude = + .git, + build, + build_release, + kaldifeat/python/kaldifeat/__init__.py + ignore = E402 diff --git a/.gitignore b/.gitignore index f9f78d8..c697d52 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ build/ build*/ *.egg-info*/ dist/ +__pycache__/ diff --git a/CMakeLists.txt b/CMakeLists.txt index 54f34d7..daaab25 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,7 +4,7 @@ cmake_minimum_required(VERSION 3.8 FATAL_ERROR) project(kaldifeat) -set(kaldifeat_VERSION "0.0.1") +set(kaldifeat_VERSION "1.0") set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") diff --git a/cmake/__init__.py b/cmake/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cmake/cmake_extension.py b/cmake/cmake_extension.py new file mode 100644 index 0000000..9f2d879 --- /dev/null +++ b/cmake/cmake_extension.py @@ -0,0 +1,83 @@ +# Copyright (c) 2021 Xiaomi Corporation (author: Fangjun Kuang) + +import glob +import os +import shutil +import sys +from pathlib import Path + +import setuptools +from setuptools.command.build_ext import build_ext + +try: + from wheel.bdist_wheel import bdist_wheel as _bdist_wheel + + class bdist_wheel(_bdist_wheel): + def finalize_options(self): + _bdist_wheel.finalize_options(self) + # In this case, the generated wheel has a name in the form + # k2-xxx-pyxx-none-any.whl + # self.root_is_pure = True + + # The generated wheel has a name ending with + # -linux_x86_64.whl + self.root_is_pure = False + + +except ImportError: + bdist_wheel = None + + +def cmake_extension(name, *args, **kwargs) -> setuptools.Extension: + kwargs["language"] = "c++" + sources = [] + return setuptools.Extension(name, sources, *args, **kwargs) + + +class BuildExtension(build_ext): + def build_extension(self, ext: setuptools.extension.Extension): + # build/temp.linux-x86_64-3.8 + os.makedirs(self.build_temp, exist_ok=True) + + # build/lib.linux-x86_64-3.8 + os.makedirs(self.build_lib, exist_ok=True) + + kaldifeat_dir = Path(__file__).parent.parent.resolve() + + cmake_args = os.environ.get("KALDIFEAT_CMAKE_ARGS", "") + make_args = os.environ.get("KALDIFEAT_MAKE_ARGS", "") + system_make_args = os.environ.get("MAKEFLAGS", "") + + if cmake_args == "": + cmake_args = "-DCMAKE_BUILD_TYPE=Release" + + if make_args == "" and system_make_args == "": + print("For fast compilation, run:") + print('export KALDIFEAT_MAKE_ARGS="-j"; python setup.py install') + + if "PYTHON_EXECUTABLE" not in cmake_args: + print(f"Setting PYTHON_EXECUTABLE to {sys.executable}") + cmake_args += f" -DPYTHON_EXECUTABLE={sys.executable}" + + build_cmd = f""" + cd {self.build_temp} + + cmake {cmake_args} {kaldifeat_dir} + + + make {make_args} _kaldifeat + """ + print(f"build command is:\n{build_cmd}") + + ret = os.system(build_cmd) + if ret != 0: + raise Exception( + "\nBuild kaldifeat failed. Please check the error message.\n" + "You can ask for help by creating an issue on GitHub.\n" + "\nClick:\n\thttps://github.com/csukuangfj/kaldifeat/issues/new\n" # noqa + ) + + lib_so = glob.glob(f"{self.build_temp}/lib/*kaldifeat*.so") + for so in lib_so: + print(f"Copying {so} to {self.build_lib}/") + shutil.copy(f"{so}", f"{self.build_lib}/") diff --git a/kaldifeat/csrc/CMakeLists.txt b/kaldifeat/csrc/CMakeLists.txt index eddab10..d94d43a 100644 --- a/kaldifeat/csrc/CMakeLists.txt +++ b/kaldifeat/csrc/CMakeLists.txt @@ -9,6 +9,12 @@ set(kaldifeat_srcs add_library(kaldifeat_core SHARED ${kaldifeat_srcs}) target_link_libraries(kaldifeat_core PUBLIC ${TORCH_LIBRARIES}) +# PYTHON_INCLUDE_DIRS is set by pybind11 +target_include_directories(kaldifeat_core PUBLIC ${PYTHON_INCLUDE_DIRS}) + +# PYTHON_LIBRARY is set by pybind11 +target_link_libraries(kaldifeat_core PUBLIC ${PYTHON_LIBRARY}) + add_executable(test_kaldifeat test_kaldifeat.cc) target_link_libraries(test_kaldifeat PRIVATE kaldifeat_core) diff --git a/kaldifeat/csrc/feature-fbank.h b/kaldifeat/csrc/feature-fbank.h index fc40359..80a3ba9 100644 --- a/kaldifeat/csrc/feature-fbank.h +++ b/kaldifeat/csrc/feature-fbank.h @@ -8,12 +8,16 @@ #define KALDIFEAT_CSRC_FEATURE_FBANK_H_ #include +#include #include "kaldifeat/csrc/feature-common.h" #include "kaldifeat/csrc/feature-window.h" #include "kaldifeat/csrc/mel-computations.h" +#include "pybind11/pybind11.h" #include "torch/torch.h" +namespace py = pybind11; + namespace kaldifeat { struct FbankOptions { @@ -42,6 +46,16 @@ struct FbankOptions { FbankOptions() : device("cpu") { mel_opts.num_bins = 23; } + // Get/Set methods are for implementing properties in Python + py::object GetDevice() const { + py::object ans = py::module_::import("torch").attr("device"); + return ans(device.str()); + } + void SetDevice(py::object obj) { + std::string s = static_cast(obj); + device = torch::Device(s); + } + std::string ToString() const { std::ostringstream os; os << "frame_opts: \n"; diff --git a/kaldifeat/python/csrc/feature-fbank.cc b/kaldifeat/python/csrc/feature-fbank.cc index 5122605..5f26a6b 100644 --- a/kaldifeat/python/csrc/feature-fbank.cc +++ b/kaldifeat/python/csrc/feature-fbank.cc @@ -19,22 +19,8 @@ void PybindFbankOptions(py::module &m) { .def_readwrite("htk_compat", &FbankOptions::htk_compat) .def_readwrite("use_log_fbank", &FbankOptions::use_log_fbank) .def_readwrite("use_power", &FbankOptions::use_power) - .def("set_device", - [](FbankOptions *fbank_opts, py::object device) { - std::string device_type = - static_cast(device.attr("type")); - KALDIFEAT_ASSERT(device_type == "cpu" || device_type == "cuda") - << "Unsupported device type: " << device_type; - - auto index_attr = static_cast(device.attr("index")); - int32_t device_index = 0; - if (!index_attr.is_none()) - device_index = static_cast(index_attr); - if (device_type == "cpu") - fbank_opts->device = torch::Device("cpu"); - else - fbank_opts->device = torch::Device(torch::kCUDA, device_index); - }) + .def_property("device", &FbankOptions::GetDevice, + &FbankOptions::SetDevice) .def("__str__", [](const FbankOptions &self) -> std::string { return self.ToString(); }); diff --git a/kaldifeat/python/csrc/kaldifeat.cc b/kaldifeat/python/csrc/kaldifeat.cc index 3398f8d..ca4bd79 100644 --- a/kaldifeat/python/csrc/kaldifeat.cc +++ b/kaldifeat/python/csrc/kaldifeat.cc @@ -27,7 +27,7 @@ PYBIND11_MODULE(_kaldifeat, m) { PybindMelBanksOptions(m); PybindFbankOptions(m); - m.def("compute", &Compute, py::arg("wave"), py::arg("fbank")); + m.def("compute_fbank_feats", &Compute, py::arg("wave"), py::arg("fbank")); // It verifies that the reimplementation produces the same output // as kaldi using default parameters with dither disabled. diff --git a/kaldifeat/python/kaldifeat/__init__.py b/kaldifeat/python/kaldifeat/__init__.py index e69de29..e177288 100644 --- a/kaldifeat/python/kaldifeat/__init__.py +++ b/kaldifeat/python/kaldifeat/__init__.py @@ -0,0 +1,3 @@ +from _kaldifeat import FbankOptions, FrameExtractionOptions, MelBanksOptions + +from .fbank import Fbank diff --git a/kaldifeat/python/kaldifeat/fbank.py b/kaldifeat/python/kaldifeat/fbank.py new file mode 100644 index 0000000..5196956 --- /dev/null +++ b/kaldifeat/python/kaldifeat/fbank.py @@ -0,0 +1,82 @@ +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +from typing import List, Union + +import _kaldifeat +import torch +import torch.nn as nn + + +class Fbank(nn.Module): + def __init__(self, opts: _kaldifeat.FbankOptions): + super().__init__() + + self.opts = opts + self.computer = _kaldifeat.Fbank(opts) + + def forward( + self, waves: Union[torch.Tensor, List[torch.Tensor]] + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """Compute the fbank features of a single waveform or + a list of waveforms. + + Args: + waves: + A single 1-D tensor or a list of 1-D tensors. Each tensor contains + audio samples of a soundfile. To get a result compatible with Kaldi, + you should scale the samples to [-32768, 32767] before calling this + function. Note: You are not required to scale them if you don't care + about the compatibility with Kaldi. + Returns: + Return a list of 2-D tensors containing the fbank features if the + input is a list of 1-D tensors. The returned list has as many elements + as the input list. + Return a single 2-D tensor if the input is a single tensor. + """ + if isinstance(waves, list): + is_list = True + else: + waves = [waves] + is_list = False + + num_frames_per_wave = [ + _kaldifeat.num_frames(w.numel(), self.opts.frame_opts) + for w in waves + ] + + strided = [self.convert_samples_to_frames(w) for w in waves] + strided = torch.cat(strided, dim=0) + + features = self.compute(strided) + + if is_list: + return list(features.split(num_frames_per_wave)) + else: + return features + + def compute(self, x: torch.Tensor) -> torch.Tensor: + """Compute fbank features given a 2-D tensor containing + frames data. Each row is a frame of size frame_lens, specified + in the fbank options. + Args: + x: + A 2-D tensor. + Returns: + Return a 2-D tensor with as many rows as the input tensor. Its + number of columns is the number mel bins. + """ + features = _kaldifeat.compute_fbank_feats(x, self.computer) + return features + + def convert_samples_to_frames(self, wave: torch.Tensor) -> torch.Tensor: + """Convert a 1-D tensor containing audio samples to a 2-D + tensor where each row is a frame of samples of size frame length + specified in the fbank options. + + Args: + waves: + A 1-D tensor. + Returns: + Return a 2-D tensor. + """ + return _kaldifeat.get_strided(wave, self.opts.frame_opts) diff --git a/kaldifeat/python/tests/test_data/run.sh b/kaldifeat/python/tests/test_data/run.sh index 53b3f39..361b657 100755 --- a/kaldifeat/python/tests/test_data/run.sh +++ b/kaldifeat/python/tests/test_data/run.sh @@ -11,6 +11,12 @@ if [ ! -f test.wav ]; then sox -n -r 16000 -b 16 test.wav synth 1.2 sine 300-3300 fi +if [ ! -f test2.wav ]; then + # generate a wav of 0.5 seconds, containing a sine-wave + # swept from 300 Hz to 3300 Hz + sox -n -r 16000 -b 16 test2.wav synth 0.5 sine 300-3300 +fi + echo "1 test.wav" > test.scp # We disable dither for testing diff --git a/kaldifeat/python/tests/test_data/test2.wav b/kaldifeat/python/tests/test_data/test2.wav new file mode 100644 index 0000000..0016f50 Binary files /dev/null and b/kaldifeat/python/tests/test_data/test2.wav differ diff --git a/kaldifeat/python/tests/test_fbank.py b/kaldifeat/python/tests/test_fbank.py new file mode 100755 index 0000000..0f39a1c --- /dev/null +++ b/kaldifeat/python/tests/test_fbank.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 + +# Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +import numpy as np +import soundfile as sf +import torch + +import kaldifeat + + +def read_wave(filename) -> torch.Tensor: + """Read a wave file and return it as a 1-D tensor. + + Note: + You don't need to scale it to [-32768, 32767]. + We use scaling here to follow the approach in Kaldi. + + Args: + filename: + Filename of a sound file. + Returns: + Return a 1-D tensor containing audio samples. + """ + with sf.SoundFile(filename) as sf_desc: + sampling_rate = sf_desc.samplerate + assert sampling_rate == 16000 + data = sf_desc.read(dtype=np.float32, always_2d=False) + data *= 32768 + return torch.from_numpy(data) + + +def test_fbank(): + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + wave0 = read_wave("test_data/test.wav") + wave1 = read_wave("test_data/test2.wav") + + wave0 = wave0.to(device) + wave1 = wave1.to(device) + + opts = kaldifeat.FbankOptions() + opts.frame_opts.dither = 0 + opts.device = device + + fbank = kaldifeat.Fbank(opts) + + # We can compute fbank features in batches + features = fbank([wave0, wave1]) + assert isinstance(features, list), f"{type(features)}" + assert len(features) == 2 + + # We can also compute fbank features for a single wave + features0 = fbank(wave0) + features1 = fbank(wave1) + + assert torch.allclose(features[0], features0) + assert torch.allclose(features[1], features1) + + # To compute fbank features for only a specified frame + audio_frames = fbank.convert_samples_to_frames(wave0) + feature_frame_1 = fbank.compute(audio_frames[1]) + feature_frame_10 = fbank.compute(audio_frames[10]) + + assert torch.allclose(features0[1], feature_frame_1) + assert torch.allclose(features0[10], feature_frame_10) + + +if __name__ == "__main__": + test_fbank() diff --git a/kaldifeat/python/tests/test_kaldifeat.py b/kaldifeat/python/tests/test_kaldifeat.py index c41712c..bcf14c8 100755 --- a/kaldifeat/python/tests/test_kaldifeat.py +++ b/kaldifeat/python/tests/test_kaldifeat.py @@ -52,7 +52,7 @@ def test_and_benchmark_default_parameters(): for device in devices: fbank_opts = _kaldifeat.FbankOptions() fbank_opts.frame_opts.dither = 0 - fbank_opts.set_device(device) + fbank_opts.device = device fbank = _kaldifeat.Fbank(fbank_opts) data = read_wave().to(device) @@ -74,14 +74,14 @@ def test_use_energy_htk_compat_true(): for device in devices: fbank_opts = _kaldifeat.FbankOptions() fbank_opts.frame_opts.dither = 0 - fbank_opts.set_device(device) + fbank_opts.device = device fbank_opts.use_energy = True fbank_opts.htk_compat = True fbank = _kaldifeat.Fbank(fbank_opts) data = read_wave().to(device) - ans = _kaldifeat.compute(data, fbank) + ans = _kaldifeat.compute_fbank_feats(data, fbank) expected = read_ark_txt("test-htk.txt") assert torch.allclose(ans.cpu(), expected, rtol=1e-2) @@ -97,12 +97,12 @@ def test_use_energy_htk_compat_false(): fbank_opts.frame_opts.dither = 0 fbank_opts.use_energy = True fbank_opts.htk_compat = False - fbank_opts.set_device(device) + fbank_opts.device = device fbank = _kaldifeat.Fbank(fbank_opts) data = read_wave().to(device) - ans = _kaldifeat.compute(data, fbank) + ans = _kaldifeat.compute_fbank_feats(data, fbank) expected = read_ark_txt("test-with-energy.txt") assert torch.allclose(ans.cpu(), expected, rtol=1e-2) @@ -117,12 +117,12 @@ def test_40_mel(): fbank_opts = _kaldifeat.FbankOptions() fbank_opts.frame_opts.dither = 0 fbank_opts.mel_opts.num_bins = 40 - fbank_opts.set_device(device) + fbank_opts.device = device fbank = _kaldifeat.Fbank(fbank_opts) data = read_wave().to(device) - ans = _kaldifeat.compute(data, fbank) + ans = _kaldifeat.compute_fbank_feats(data, fbank) expected = read_ark_txt("test-40.txt") assert torch.allclose(ans.cpu(), expected, rtol=1e-1) @@ -138,12 +138,12 @@ def test_40_mel_no_snip_edges(): fbank_opts.frame_opts.snip_edges = False fbank_opts.frame_opts.dither = 0 fbank_opts.mel_opts.num_bins = 40 - fbank_opts.set_device(device) + fbank_opts.device = device fbank = _kaldifeat.Fbank(fbank_opts) data = read_wave().to(device) - ans = _kaldifeat.compute(data, fbank) + ans = _kaldifeat.compute_fbank_feats(data, fbank) expected = read_ark_txt("test-40-no-snip-edges.txt") assert torch.allclose(ans.cpu(), expected, rtol=1e-2) @@ -161,7 +161,7 @@ def test_compute_batch(): fbank_opts = _kaldifeat.FbankOptions() fbank_opts.frame_opts.dither = 0 fbank_opts.frame_opts.snip_edges = False - fbank_opts.set_device(device) + fbank_opts.device = device fbank = _kaldifeat.Fbank(fbank_opts) def impl(waves: List[torch.Tensor]) -> List[torch.Tensor]: @@ -175,7 +175,9 @@ def test_compute_batch(): ] strided = torch.cat(strided, dim=0) - features = _kaldifeat.compute(strided, fbank).split(num_frames) + features = _kaldifeat.compute_fbank_feats(strided, fbank).split( + num_frames + ) return features diff --git a/kaldifeat/python/tests/test_options.py b/kaldifeat/python/tests/test_options.py index 41e660b..e4a01a7 100755 --- a/kaldifeat/python/tests/test_options.py +++ b/kaldifeat/python/tests/test_options.py @@ -52,6 +52,7 @@ def test_fbank_options(): opts.use_energy = False opts.use_log_fbank = True opts.use_power = True + opts.device = "cuda:0" frame_opts.blackman_coeff = 0.42 frame_opts.dither = 1 @@ -75,8 +76,8 @@ def test_fbank_options(): def main(): - # test_frame_extraction_options() - # test_mel_banks_options() + test_frame_extraction_options() + test_mel_banks_options() test_fbank_options() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..12c6d5d --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +torch diff --git a/setup.py b/setup.py index e2bc6b7..3f4fb5e 100644 --- a/setup.py +++ b/setup.py @@ -5,6 +5,9 @@ import re import setuptools +import torch + +from cmake.cmake_extension import BuildExtension, bdist_wheel, cmake_extension def read_long_description(): @@ -22,24 +25,35 @@ def get_package_version(): return latest_version +def get_pytorch_version(): + # if it is 1.7.1+cuda101, then strip +cuda101 + return torch.__version__.split("+")[0] + + +install_requires = [ + f"torch=={get_pytorch_version()}", +] + + package_name = "kaldifeat" +with open("kaldifeat/python/kaldifeat/__init__.py", "a") as f: + f.write(f"__version__ = '{get_package_version()}'\n") + setuptools.setup( name=package_name, version=get_package_version(), author="Fangjun Kuang", author_email="csukuangfj@gmail.com", data_files=[("", ["LICENSE", "README.md"])], - package_dir={ - package_name: "kaldifeat/python/kaldifeat", - }, + package_dir={package_name: "kaldifeat/python/kaldifeat"}, packages=[package_name], + install_requires=install_requires, url="https://github.com/csukuangfj/kaldifeat", long_description=read_long_description(), long_description_content_type="text/markdown", - # ext_modules=[cmake_extension('_kaldifeat')], - # cmdclass={'build_ext': BuildExtension}, - zip_safe=False, + ext_modules=[cmake_extension("_kaldifeat")], + cmdclass={"build_ext": BuildExtension, "bdist_wheel": bdist_wheel}, classifiers=[ "Programming Language :: C++", "Programming Language :: Python", @@ -52,3 +66,12 @@ setuptools.setup( python_requires=">=3.6.0", license="Apache licensed, as found in the LICENSE file", ) + +# remove the line __version__ from kaldifeat/python/kaldifeat/__init__.py +with open("kaldifeat/python/kaldifeat/__init__.py", "r") as f: + lines = f.readlines() + +with open("kaldifeat/python/kaldifeat/__init__.py", "w") as f: + for line in lines: + if "__version__" not in line: + f.write(line)