From da74f96e4d558efa8ae57e47dd4bbac088df5a50 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 1 Jun 2022 17:42:11 +0800 Subject: [PATCH] Check the runtime version of PyTorch before importing kaldifeat. --- .gitignore | 1 + CMakeLists.txt | 6 ++++++ kaldifeat/python/kaldifeat/__init__.py | 9 +++++++++ kaldifeat/python/kaldifeat/torch_version.py.in | 12 ++++++++++++ 4 files changed, 28 insertions(+) create mode 100644 kaldifeat/python/kaldifeat/torch_version.py.in diff --git a/.gitignore b/.gitignore index 52da5e5..d6c034b 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ dist/ __pycache__/ test-1hour.wav path.sh +torch_version.py diff --git a/CMakeLists.txt b/CMakeLists.txt index ea65477..f594460 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -51,5 +51,11 @@ if(kaldifeat_BUILD_TESTS) enable_testing() endif() +# TORCH_VERSION is defined in cmake/torch.cmake +configure_file( + ${CMAKE_SOURCE_DIR}/kaldifeat/python/kaldifeat/torch_version.py.in + ${CMAKE_SOURCE_DIR}/kaldifeat/python/kaldifeat/torch_version.py @ONLY +) + include_directories(${CMAKE_SOURCE_DIR}) add_subdirectory(kaldifeat) diff --git a/kaldifeat/python/kaldifeat/__init__.py b/kaldifeat/python/kaldifeat/__init__.py index ea39003..adf7d79 100644 --- a/kaldifeat/python/kaldifeat/__init__.py +++ b/kaldifeat/python/kaldifeat/__init__.py @@ -1,4 +1,13 @@ import torch + +from .torch_version import kaldifeat_torch_version + +if torch.__version__.split("+")[0] != kaldifeat_torch_version.split("+")[0]: + raise ImportError( + f"kaldifeat was built using PyTorch {kaldifeat_torch_version}\n" + f"But you are using PyTorch {torch.__version__} to run it" + ) + from _kaldifeat import ( FbankOptions, FrameExtractionOptions, diff --git a/kaldifeat/python/kaldifeat/torch_version.py.in b/kaldifeat/python/kaldifeat/torch_version.py.in new file mode 100644 index 0000000..e6365fa --- /dev/null +++ b/kaldifeat/python/kaldifeat/torch_version.py.in @@ -0,0 +1,12 @@ +# Auto generated by the toplevel CMakeLists.txt. +# +# DO NOT EDIT. + +# The torch version used to build kaldifeat. We will check it against the +# torch version that is used to run kaldifeat. If they are not the same, +# `import kaldifeat` will throw. +# +# Some example values are: +# - 1.10.0+cu102 +# - 1.5.0+cpu +kaldifeat_torch_version = "@TORCH_VERSION@"