diff --git a/.flake8 b/.flake8 index 3551e08..8285441 100644 --- a/.flake8 +++ b/.flake8 @@ -3,6 +3,8 @@ max-line-length = 80 exclude = .git, + build, + build_release, kaldifeat/python/kaldifeat/__init__.py ignore = diff --git a/CMakeLists.txt b/CMakeLists.txt index 54f34d7..daaab25 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,7 +4,7 @@ cmake_minimum_required(VERSION 3.8 FATAL_ERROR) project(kaldifeat) -set(kaldifeat_VERSION "0.0.1") +set(kaldifeat_VERSION "1.0") set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") diff --git a/cmake/__init__.py b/cmake/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cmake/cmake_extension.py b/cmake/cmake_extension.py new file mode 100644 index 0000000..9f2d879 --- /dev/null +++ b/cmake/cmake_extension.py @@ -0,0 +1,83 @@ +# Copyright (c) 2021 Xiaomi Corporation (author: Fangjun Kuang) + +import glob +import os +import shutil +import sys +from pathlib import Path + +import setuptools +from setuptools.command.build_ext import build_ext + +try: + from wheel.bdist_wheel import bdist_wheel as _bdist_wheel + + class bdist_wheel(_bdist_wheel): + def finalize_options(self): + _bdist_wheel.finalize_options(self) + # In this case, the generated wheel has a name in the form + # k2-xxx-pyxx-none-any.whl + # self.root_is_pure = True + + # The generated wheel has a name ending with + # -linux_x86_64.whl + self.root_is_pure = False + + +except ImportError: + bdist_wheel = None + + +def cmake_extension(name, *args, **kwargs) -> setuptools.Extension: + kwargs["language"] = "c++" + sources = [] + return setuptools.Extension(name, sources, *args, **kwargs) + + +class BuildExtension(build_ext): + def build_extension(self, ext: setuptools.extension.Extension): + # build/temp.linux-x86_64-3.8 + os.makedirs(self.build_temp, exist_ok=True) + + # build/lib.linux-x86_64-3.8 + os.makedirs(self.build_lib, exist_ok=True) + + kaldifeat_dir = Path(__file__).parent.parent.resolve() + + cmake_args = os.environ.get("KALDIFEAT_CMAKE_ARGS", "") + make_args = os.environ.get("KALDIFEAT_MAKE_ARGS", "") + system_make_args = os.environ.get("MAKEFLAGS", "") + + if cmake_args == "": + cmake_args = "-DCMAKE_BUILD_TYPE=Release" + + if make_args == "" and system_make_args == "": + print("For fast compilation, run:") + print('export KALDIFEAT_MAKE_ARGS="-j"; python setup.py install') + + if "PYTHON_EXECUTABLE" not in cmake_args: + print(f"Setting PYTHON_EXECUTABLE to {sys.executable}") + cmake_args += f" -DPYTHON_EXECUTABLE={sys.executable}" + + build_cmd = f""" + cd {self.build_temp} + + cmake {cmake_args} {kaldifeat_dir} + + + make {make_args} _kaldifeat + """ + print(f"build command is:\n{build_cmd}") + + ret = os.system(build_cmd) + if ret != 0: + raise Exception( + "\nBuild kaldifeat failed. Please check the error message.\n" + "You can ask for help by creating an issue on GitHub.\n" + "\nClick:\n\thttps://github.com/csukuangfj/kaldifeat/issues/new\n" # noqa + ) + + lib_so = glob.glob(f"{self.build_temp}/lib/*kaldifeat*.so") + for so in lib_so: + print(f"Copying {so} to {self.build_lib}/") + shutil.copy(f"{so}", f"{self.build_lib}/") diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..12c6d5d --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +torch diff --git a/setup.py b/setup.py index e2bc6b7..3f4fb5e 100644 --- a/setup.py +++ b/setup.py @@ -5,6 +5,9 @@ import re import setuptools +import torch + +from cmake.cmake_extension import BuildExtension, bdist_wheel, cmake_extension def read_long_description(): @@ -22,24 +25,35 @@ def get_package_version(): return latest_version +def get_pytorch_version(): + # if it is 1.7.1+cuda101, then strip +cuda101 + return torch.__version__.split("+")[0] + + +install_requires = [ + f"torch=={get_pytorch_version()}", +] + + package_name = "kaldifeat" +with open("kaldifeat/python/kaldifeat/__init__.py", "a") as f: + f.write(f"__version__ = '{get_package_version()}'\n") + setuptools.setup( name=package_name, version=get_package_version(), author="Fangjun Kuang", author_email="csukuangfj@gmail.com", data_files=[("", ["LICENSE", "README.md"])], - package_dir={ - package_name: "kaldifeat/python/kaldifeat", - }, + package_dir={package_name: "kaldifeat/python/kaldifeat"}, packages=[package_name], + install_requires=install_requires, url="https://github.com/csukuangfj/kaldifeat", long_description=read_long_description(), long_description_content_type="text/markdown", - # ext_modules=[cmake_extension('_kaldifeat')], - # cmdclass={'build_ext': BuildExtension}, - zip_safe=False, + ext_modules=[cmake_extension("_kaldifeat")], + cmdclass={"build_ext": BuildExtension, "bdist_wheel": bdist_wheel}, classifiers=[ "Programming Language :: C++", "Programming Language :: Python", @@ -52,3 +66,12 @@ setuptools.setup( python_requires=">=3.6.0", license="Apache licensed, as found in the LICENSE file", ) + +# remove the line __version__ from kaldifeat/python/kaldifeat/__init__.py +with open("kaldifeat/python/kaldifeat/__init__.py", "r") as f: + lines = f.readlines() + +with open("kaldifeat/python/kaldifeat/__init__.py", "w") as f: + for line in lines: + if "__version__" not in line: + f.write(line)