Check the runtime version of PyTorch before importing kaldifeat.

This commit is contained in:
Fangjun Kuang 2022-06-01 17:42:11 +08:00
parent fc78bc2be7
commit da74f96e4d
4 changed files with 28 additions and 0 deletions

1
.gitignore vendored
View File

@ -5,3 +5,4 @@ dist/
__pycache__/
test-1hour.wav
path.sh
torch_version.py

View File

@ -51,5 +51,11 @@ if(kaldifeat_BUILD_TESTS)
enable_testing()
endif()
# TORCH_VERSION is defined in cmake/torch.cmake
configure_file(
${CMAKE_SOURCE_DIR}/kaldifeat/python/kaldifeat/torch_version.py.in
${CMAKE_SOURCE_DIR}/kaldifeat/python/kaldifeat/torch_version.py @ONLY
)
include_directories(${CMAKE_SOURCE_DIR})
add_subdirectory(kaldifeat)

View File

@ -1,4 +1,13 @@
import torch
from .torch_version import kaldifeat_torch_version
if torch.__version__.split("+")[0] != kaldifeat_torch_version.split("+")[0]:
raise ImportError(
f"kaldifeat was built using PyTorch {kaldifeat_torch_version}\n"
f"But you are using PyTorch {torch.__version__} to run it"
)
from _kaldifeat import (
FbankOptions,
FrameExtractionOptions,

View File

@ -0,0 +1,12 @@
# Auto generated by the toplevel CMakeLists.txt.
#
# DO NOT EDIT.
# The torch version used to build kaldifeat. We will check it against the
# torch version that is used to run kaldifeat. If they are not the same,
# `import kaldifeat` will throw.
#
# Some example values are:
# - 1.10.0+cu102
# - 1.5.0+cpu
kaldifeat_torch_version = "@TORCH_VERSION@"