From 0228b6a56dbeb7fbd2727c395023a3a77921e796 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 20 Aug 2021 20:21:59 +0800 Subject: [PATCH] Support PyTorch 1.6.0 --- CMakeLists.txt | 2 +- README.md | 17 ++++++++++++++++- cmake/torch.cmake | 12 ++++++++++++ kaldifeat/csrc/CMakeLists.txt | 3 +++ kaldifeat/csrc/feature-common.h | 16 ++++++++++++++++ kaldifeat/csrc/feature-fbank.cc | 14 +++++++++++++- kaldifeat/csrc/feature-mfcc.cc | 14 +++++++++++++- kaldifeat/csrc/feature-plp.cc | 14 +++++++++++++- kaldifeat/csrc/feature-spectrogram.cc | 16 ++++++++++++++-- scripts/conda/kaldifeat/meta.yaml | 2 +- 10 files changed, 102 insertions(+), 8 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 0f3f901..d418afe 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -6,7 +6,7 @@ project(kaldifeat) # remember to change the version in # scripts/conda/kaldifeat/meta.yaml -set(kaldifeat_VERSION "1.5.4") +set(kaldifeat_VERSION "1.6") set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") diff --git a/README.md b/README.md index 140e421..736530c 100644 --- a/README.md +++ b/README.md @@ -208,11 +208,26 @@ for more examples. # Installation +## From conda + +Supported versions of Python, PyTorch, and CUDA toolkit are listed below: + +[![Supported Python versions](/doc/source/images/python-3.6_3.7_3.8-blue.svg)](/doc/source/images/python-3.6_3.7_3.8-blue.svg) +[![Supported PyTorch versions](/doc/source/images/pytorch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg)](/doc/source/images/pytorch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg) +[![Supported CUDA versions](/doc/source/images/cuda-10.1_10.2_11.0_11.1-orange.svg)](/doc/source/images/cuda-10.1_10.2_11.0_11.1-orange.svg) + +```bash +conda install -c kaldifeat -c pytorch -c conda-forge kaldifeat python=3.8 cudatoolkit=11.1 pytorch=1.8.1 +``` + +You can select the supported Python version, CUDA toolkit version and PyTorch version as you wish. + + ## From PyPi with pip You need to install PyTorch and CMake first. cmake 3.11 is known to work. Other cmake versions may also work. -PyTorch 1.7.1 and 1.8.1 are known to work. Other PyTorch versions may also work. +PyTorch 1.6.1 and above are known to work. Other PyTorch versions may also work. ```bash pip install -v kaldilm diff --git a/cmake/torch.cmake b/cmake/torch.cmake index c35210b..6c5829d 100644 --- a/cmake/torch.cmake +++ b/cmake/torch.cmake @@ -23,3 +23,15 @@ execute_process( ) message(STATUS "PyTorch version: ${TORCH_VERSION}") + +execute_process( + COMMAND "${PYTHON_EXECUTABLE}" -c "import torch; print(torch.__version__.split('.')[0])" + OUTPUT_STRIP_TRAILING_WHITESPACE + OUTPUT_VARIABLE KALDIFEAT_TORCH_VERSION_MAJOR +) + +execute_process( + COMMAND "${PYTHON_EXECUTABLE}" -c "import torch; print(torch.__version__.split('.')[1])" + OUTPUT_STRIP_TRAILING_WHITESPACE + OUTPUT_VARIABLE KALDIFEAT_TORCH_VERSION_MINOR +) diff --git a/kaldifeat/csrc/CMakeLists.txt b/kaldifeat/csrc/CMakeLists.txt index acd6c57..13b9f9b 100644 --- a/kaldifeat/csrc/CMakeLists.txt +++ b/kaldifeat/csrc/CMakeLists.txt @@ -14,6 +14,9 @@ set(kaldifeat_srcs add_library(kaldifeat_core SHARED ${kaldifeat_srcs}) target_link_libraries(kaldifeat_core PUBLIC ${TORCH_LIBRARIES}) +target_compile_definitions(kaldifeat_core PUBLIC KALDIFEAT_TORCH_VERSION_MAJOR=${KALDIFEAT_TORCH_VERSION_MAJOR}) +target_compile_definitions(kaldifeat_core PUBLIC KALDIFEAT_TORCH_VERSION_MINOR=${KALDIFEAT_TORCH_VERSION_MINOR}) + add_executable(test_kaldifeat test_kaldifeat.cc) target_link_libraries(test_kaldifeat PRIVATE kaldifeat_core) diff --git a/kaldifeat/csrc/feature-common.h b/kaldifeat/csrc/feature-common.h index 8d15dc9..5710c22 100644 --- a/kaldifeat/csrc/feature-common.h +++ b/kaldifeat/csrc/feature-common.h @@ -7,7 +7,23 @@ #ifndef KALDIFEAT_CSRC_FEATURE_COMMON_H_ #define KALDIFEAT_CSRC_FEATURE_COMMON_H_ +#include "kaldifeat/csrc/feature-functions.h" #include "kaldifeat/csrc/feature-window.h" +// See "The torch.fft module in PyTorch 1.7" +// https://github.com/pytorch/pytorch/wiki/The-torch.fft-module-in-PyTorch-1.7 +#if KALDIFEAT_TORCH_VERSION_MAJOR > 1 || \ + (KALDIFEAT_TORCH_VERSION_MAJOR == 1 && KALDIFEAT_TORCH_VERSION_MINOR > 6) +#include "torch/fft.h" +#define KALDIFEAT_HAS_FFT_NAMESPACE +// It uses torch::fft::rfft +// Its input shape is [x, N], output shape is [x, N/2] +// which is a complex tensor +#else +#include "ATen/Functions.h" +// It uses torch::fft +// Its input shape is [x, N], output shape is [x, N/2, 2] +// which contains the real part [..., ], and imaginary part [..., 1] +#endif namespace kaldifeat { diff --git a/kaldifeat/csrc/feature-fbank.cc b/kaldifeat/csrc/feature-fbank.cc index a5c1fd8..e740201 100644 --- a/kaldifeat/csrc/feature-fbank.cc +++ b/kaldifeat/csrc/feature-fbank.cc @@ -8,7 +8,6 @@ #include -#include "torch/fft.h" #include "torch/torch.h" namespace kaldifeat { @@ -66,7 +65,20 @@ torch::Tensor FbankComputer::Compute(torch::Tensor signal_raw_log_energy, } // note spectrum is in magnitude, not power, because of `abs()` +#if defined(KALDIFEAT_HAS_FFT_NAMESPACE) + // signal_frame shape: [x, 512] + // spectrum shape [x, 257 torch::Tensor spectrum = torch::fft::rfft(signal_frame).abs(); +#else + // signal_frame shape [x, 512] + // real_imag shape [x, 257, 2], + // where [..., 0] is the real part + // [..., 1] is the imaginary part + torch::Tensor real_imag = torch::rfft(signal_frame, 1); + torch::Tensor real = real_imag.index({"...", 0}); + torch::Tensor imag = real_imag.index({"...", 1}); + torch::Tensor spectrum = (real.square() + imag.square()).sqrt(); +#endif // remove the last column, i.e., the highest fft bin spectrum = spectrum.index( diff --git a/kaldifeat/csrc/feature-mfcc.cc b/kaldifeat/csrc/feature-mfcc.cc index c233881..365d682 100644 --- a/kaldifeat/csrc/feature-mfcc.cc +++ b/kaldifeat/csrc/feature-mfcc.cc @@ -7,7 +7,6 @@ #include "kaldifeat/csrc/feature-mfcc.h" #include "kaldifeat/csrc/matrix-functions.h" -#include "torch/fft.h" namespace kaldifeat { @@ -92,7 +91,20 @@ torch::Tensor MfccComputer::Compute(torch::Tensor signal_raw_log_energy, } // note spectrum is in magnitude, not power, because of `abs()` +#if defined(KALDIFEAT_HAS_FFT_NAMESPACE) + // signal_frame shape: [x, 512] + // spectrum shape [x, 257 torch::Tensor spectrum = torch::fft::rfft(signal_frame).abs(); +#else + // signal_frame shape [x, 512] + // real_imag shape [x, 257, 2], + // where [..., 0] is the real part + // [..., 1] is the imaginary part + torch::Tensor real_imag = torch::rfft(signal_frame, 1); + torch::Tensor real = real_imag.index({"...", 0}); + torch::Tensor imag = real_imag.index({"...", 1}); + torch::Tensor spectrum = (real.square() + imag.square()).sqrt(); +#endif // remove the last column, i.e., the highest fft bin spectrum = spectrum.index( diff --git a/kaldifeat/csrc/feature-plp.cc b/kaldifeat/csrc/feature-plp.cc index b41a151..6e0a077 100644 --- a/kaldifeat/csrc/feature-plp.cc +++ b/kaldifeat/csrc/feature-plp.cc @@ -7,7 +7,6 @@ #include "kaldifeat/csrc/feature-plp.h" #include "kaldifeat/csrc/feature-functions.h" -#include "torch/fft.h" #include "torch/torch.h" namespace kaldifeat { @@ -98,7 +97,20 @@ torch::Tensor PlpComputer::Compute(torch::Tensor signal_raw_log_energy, } // note spectrum is in magnitude, not power, because of `abs()` +#if defined(KALDIFEAT_HAS_FFT_NAMESPACE) + // signal_frame shape: [x, 512] + // spectrum shape [x, 257 torch::Tensor spectrum = torch::fft::rfft(signal_frame).abs(); +#else + // signal_frame shape [x, 512] + // real_imag shape [x, 257, 2], + // where [..., 0] is the real part + // [..., 1] is the imaginary part + torch::Tensor real_imag = torch::rfft(signal_frame, 1); + torch::Tensor real = real_imag.index({"...", 0}); + torch::Tensor imag = real_imag.index({"...", 1}); + torch::Tensor spectrum = (real.square() + imag.square()).sqrt(); +#endif // remove the last column, i.e., the highest fft bin spectrum = spectrum.index( diff --git a/kaldifeat/csrc/feature-spectrogram.cc b/kaldifeat/csrc/feature-spectrogram.cc index d4d37e3..08f4f02 100644 --- a/kaldifeat/csrc/feature-spectrogram.cc +++ b/kaldifeat/csrc/feature-spectrogram.cc @@ -6,8 +6,6 @@ #include "kaldifeat/csrc/feature-spectrogram.h" -#include "torch/fft.h" - namespace kaldifeat { std::ostream &operator<<(std::ostream &os, const SpectrogramOptions &opts) { @@ -38,7 +36,21 @@ torch::Tensor SpectrogramComputer::Compute(torch::Tensor signal_raw_log_energy, } // note spectrum is in magnitude, not power, because of `abs()` +#if defined(KALDIFEAT_HAS_FFT_NAMESPACE) + // signal_frame shape: [x, 512] + // spectrum shape [x, 257 torch::Tensor spectrum = torch::fft::rfft(signal_frame).abs(); +#else + // signal_frame shape [x, 512] + // real_imag shape [x, 257, 2], + // where [..., 0] is the real part + // [..., 1] is the imaginary part + torch::Tensor real_imag = torch::rfft(signal_frame, 1); + torch::Tensor real = real_imag.index({"...", 0}); + torch::Tensor imag = real_imag.index({"...", 1}); + torch::Tensor spectrum = (real.square() + imag.square()).sqrt(); +#endif + if (opts_.return_raw_fft) { KALDIFEAT_ERR << "return raw fft is not supported yet"; } diff --git a/scripts/conda/kaldifeat/meta.yaml b/scripts/conda/kaldifeat/meta.yaml index 3b1c39a..8420daa 100644 --- a/scripts/conda/kaldifeat/meta.yaml +++ b/scripts/conda/kaldifeat/meta.yaml @@ -1,6 +1,6 @@ package: name: kaldifeat - version: "1.5.4" + version: "1.6" source: path: "{{ environ.get('KALDIFEAT_ROOT_DIR') }}"