Support PyTorch 1.6.0

This commit is contained in:
Fangjun Kuang 2021-08-20 20:21:59 +08:00
parent 4a1d08c1fa
commit 0228b6a56d
10 changed files with 102 additions and 8 deletions

View File

@ -6,7 +6,7 @@ project(kaldifeat)
# remember to change the version in
# scripts/conda/kaldifeat/meta.yaml
set(kaldifeat_VERSION "1.5.4")
set(kaldifeat_VERSION "1.6")
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")

View File

@ -208,11 +208,26 @@ for more examples.
# Installation
## From conda
Supported versions of Python, PyTorch, and CUDA toolkit are listed below:
[![Supported Python versions](/doc/source/images/python-3.6_3.7_3.8-blue.svg)](/doc/source/images/python-3.6_3.7_3.8-blue.svg)
[![Supported PyTorch versions](/doc/source/images/pytorch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg)](/doc/source/images/pytorch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg)
[![Supported CUDA versions](/doc/source/images/cuda-10.1_10.2_11.0_11.1-orange.svg)](/doc/source/images/cuda-10.1_10.2_11.0_11.1-orange.svg)
```bash
conda install -c kaldifeat -c pytorch -c conda-forge kaldifeat python=3.8 cudatoolkit=11.1 pytorch=1.8.1
```
You can select the supported Python version, CUDA toolkit version and PyTorch version as you wish.
## From PyPi with pip
You need to install PyTorch and CMake first.
cmake 3.11 is known to work. Other cmake versions may also work.
PyTorch 1.7.1 and 1.8.1 are known to work. Other PyTorch versions may also work.
PyTorch 1.6.0 and above are known to work. Other PyTorch versions may also work.
```bash
pip install -v kaldifeat

View File

@ -23,3 +23,15 @@ execute_process(
)
message(STATUS "PyTorch version: ${TORCH_VERSION}")
# Split torch.__version__ into its major and minor components by asking the
# Python interpreter. The two values are meant to be consumed as integers
# (e.g. as compile definitions used in preprocessor comparisons), so they are
# validated below before anything else can use them.
execute_process(
  COMMAND "${PYTHON_EXECUTABLE}" -c "import torch; print(torch.__version__.split('.')[0])"
  RESULT_VARIABLE kaldifeat_torch_major_status
  OUTPUT_STRIP_TRAILING_WHITESPACE
  OUTPUT_VARIABLE KALDIFEAT_TORCH_VERSION_MAJOR
)

execute_process(
  COMMAND "${PYTHON_EXECUTABLE}" -c "import torch; print(torch.__version__.split('.')[1])"
  RESULT_VARIABLE kaldifeat_torch_minor_status
  OUTPUT_STRIP_TRAILING_WHITESPACE
  OUTPUT_VARIABLE KALDIFEAT_TORCH_VERSION_MINOR
)

# Stop at configure time when Python could not be run, torch could not be
# imported, or the reported components are not plain integers. Without this
# check the variables would be empty and the failure would only surface later
# as a hard-to-read compiler error.
if(NOT "${kaldifeat_torch_major_status}" STREQUAL "0"
   OR NOT "${kaldifeat_torch_minor_status}" STREQUAL "0"
   OR NOT "${KALDIFEAT_TORCH_VERSION_MAJOR}" MATCHES "^[0-9]+$"
   OR NOT "${KALDIFEAT_TORCH_VERSION_MINOR}" MATCHES "^[0-9]+$")
  message(FATAL_ERROR
    "Failed to determine the PyTorch major/minor version using "
    "'${PYTHON_EXECUTABLE}' (got major='${KALDIFEAT_TORCH_VERSION_MAJOR}', "
    "minor='${KALDIFEAT_TORCH_VERSION_MINOR}'). "
    "Please make sure PyTorch is installed for this Python interpreter.")
endif()

View File

@ -14,6 +14,9 @@ set(kaldifeat_srcs
# Core shared library built from the sources listed in kaldifeat_srcs.
add_library(kaldifeat_core SHARED ${kaldifeat_srcs})
target_link_libraries(kaldifeat_core PUBLIC ${TORCH_LIBRARIES})

# Expose the PyTorch version to the C++ sources. The definitions are PUBLIC
# so that targets linking against kaldifeat_core are compiled with the same
# values.
target_compile_definitions(kaldifeat_core
  PUBLIC
    KALDIFEAT_TORCH_VERSION_MAJOR=${KALDIFEAT_TORCH_VERSION_MAJOR}
    KALDIFEAT_TORCH_VERSION_MINOR=${KALDIFEAT_TORCH_VERSION_MINOR}
)

# Test executable linked against the core library.
add_executable(test_kaldifeat test_kaldifeat.cc)
target_link_libraries(test_kaldifeat PRIVATE kaldifeat_core)

View File

@ -7,7 +7,23 @@
#ifndef KALDIFEAT_CSRC_FEATURE_COMMON_H_
#define KALDIFEAT_CSRC_FEATURE_COMMON_H_
#include "kaldifeat/csrc/feature-functions.h"
#include "kaldifeat/csrc/feature-window.h"
// See "The torch.fft module in PyTorch 1.7"
// https://github.com/pytorch/pytorch/wiki/The-torch.fft-module-in-PyTorch-1.7
//
// Select the FFT API according to the PyTorch version given by the
// KALDIFEAT_TORCH_VERSION_MAJOR / KALDIFEAT_TORCH_VERSION_MINOR macros.
// Note: an undefined macro evaluates to 0 inside #if, which selects the
// legacy (#else) branch.
#if KALDIFEAT_TORCH_VERSION_MAJOR > 1 || \
(KALDIFEAT_TORCH_VERSION_MAJOR == 1 && KALDIFEAT_TORCH_VERSION_MINOR > 6)
// PyTorch newer than 1.6: the torch::fft namespace is available
#include "torch/fft.h"
#define KALDIFEAT_HAS_FFT_NAMESPACE
// It uses torch::fft::rfft
// Its input shape is [x, N], output shape is [x, N/2 + 1],
// which is a complex tensor
#else
// PyTorch 1.6 or older: fall back to the legacy function
#include "ATen/Functions.h"
// It uses torch::rfft
// Its input shape is [x, N], output shape is [x, N/2 + 1, 2],
// which contains the real part [..., 0], and imaginary part [..., 1]
#endif
namespace kaldifeat {

View File

@ -8,7 +8,6 @@
#include <cmath>
#include "torch/fft.h"
#include "torch/torch.h"
namespace kaldifeat {
@ -66,7 +65,20 @@ torch::Tensor FbankComputer::Compute(torch::Tensor signal_raw_log_energy,
}
// note spectrum is in magnitude, not power: it is computed with `abs()` on
// the complex FFT output, or with sqrt(real^2 + imag^2) in the legacy branch
#if defined(KALDIFEAT_HAS_FFT_NAMESPACE)
// torch::fft::rfft returns a complex tensor; abs() yields its magnitude
// signal_frame shape: e.g. [x, 512]
// spectrum shape: e.g. [x, 257]
torch::Tensor spectrum = torch::fft::rfft(signal_frame).abs();
#else
// legacy torch::rfft returns a real tensor whose last dimension has size 2
// signal_frame shape: e.g. [x, 512]
// real_imag shape: e.g. [x, 257, 2],
// where [..., 0] is the real part
// [..., 1] is the imaginary part
torch::Tensor real_imag = torch::rfft(signal_frame, 1);
torch::Tensor real = real_imag.index({"...", 0});
torch::Tensor imag = real_imag.index({"...", 1});
// magnitude of each complex bin
torch::Tensor spectrum = (real.square() + imag.square()).sqrt();
#endif
// remove the last column, i.e., the highest fft bin
spectrum = spectrum.index(

View File

@ -7,7 +7,6 @@
#include "kaldifeat/csrc/feature-mfcc.h"
#include "kaldifeat/csrc/matrix-functions.h"
#include "torch/fft.h"
namespace kaldifeat {
@ -92,7 +91,20 @@ torch::Tensor MfccComputer::Compute(torch::Tensor signal_raw_log_energy,
}
// note spectrum is in magnitude, not power: it is computed with `abs()` on
// the complex FFT output, or with sqrt(real^2 + imag^2) in the legacy branch
#if defined(KALDIFEAT_HAS_FFT_NAMESPACE)
// torch::fft::rfft returns a complex tensor; abs() yields its magnitude
// signal_frame shape: e.g. [x, 512]
// spectrum shape: e.g. [x, 257]
torch::Tensor spectrum = torch::fft::rfft(signal_frame).abs();
#else
// legacy torch::rfft returns a real tensor whose last dimension has size 2
// signal_frame shape: e.g. [x, 512]
// real_imag shape: e.g. [x, 257, 2],
// where [..., 0] is the real part
// [..., 1] is the imaginary part
torch::Tensor real_imag = torch::rfft(signal_frame, 1);
torch::Tensor real = real_imag.index({"...", 0});
torch::Tensor imag = real_imag.index({"...", 1});
// magnitude of each complex bin
torch::Tensor spectrum = (real.square() + imag.square()).sqrt();
#endif
// remove the last column, i.e., the highest fft bin
spectrum = spectrum.index(

View File

@ -7,7 +7,6 @@
#include "kaldifeat/csrc/feature-plp.h"
#include "kaldifeat/csrc/feature-functions.h"
#include "torch/fft.h"
#include "torch/torch.h"
namespace kaldifeat {
@ -98,7 +97,20 @@ torch::Tensor PlpComputer::Compute(torch::Tensor signal_raw_log_energy,
}
// note spectrum is in magnitude, not power: it is computed with `abs()` on
// the complex FFT output, or with sqrt(real^2 + imag^2) in the legacy branch
#if defined(KALDIFEAT_HAS_FFT_NAMESPACE)
// torch::fft::rfft returns a complex tensor; abs() yields its magnitude
// signal_frame shape: e.g. [x, 512]
// spectrum shape: e.g. [x, 257]
torch::Tensor spectrum = torch::fft::rfft(signal_frame).abs();
#else
// legacy torch::rfft returns a real tensor whose last dimension has size 2
// signal_frame shape: e.g. [x, 512]
// real_imag shape: e.g. [x, 257, 2],
// where [..., 0] is the real part
// [..., 1] is the imaginary part
torch::Tensor real_imag = torch::rfft(signal_frame, 1);
torch::Tensor real = real_imag.index({"...", 0});
torch::Tensor imag = real_imag.index({"...", 1});
// magnitude of each complex bin
torch::Tensor spectrum = (real.square() + imag.square()).sqrt();
#endif
// remove the last column, i.e., the highest fft bin
spectrum = spectrum.index(

View File

@ -6,8 +6,6 @@
#include "kaldifeat/csrc/feature-spectrogram.h"
#include "torch/fft.h"
namespace kaldifeat {
std::ostream &operator<<(std::ostream &os, const SpectrogramOptions &opts) {
@ -38,7 +36,21 @@ torch::Tensor SpectrogramComputer::Compute(torch::Tensor signal_raw_log_energy,
}
// note spectrum is in magnitude, not power: it is computed with `abs()` on
// the complex FFT output, or with sqrt(real^2 + imag^2) in the legacy branch
#if defined(KALDIFEAT_HAS_FFT_NAMESPACE)
// torch::fft::rfft returns a complex tensor; abs() yields its magnitude
// signal_frame shape: e.g. [x, 512]
// spectrum shape: e.g. [x, 257]
torch::Tensor spectrum = torch::fft::rfft(signal_frame).abs();
#else
// legacy torch::rfft returns a real tensor whose last dimension has size 2
// signal_frame shape: e.g. [x, 512]
// real_imag shape: e.g. [x, 257, 2],
// where [..., 0] is the real part
// [..., 1] is the imaginary part
torch::Tensor real_imag = torch::rfft(signal_frame, 1);
torch::Tensor real = real_imag.index({"...", 0});
torch::Tensor imag = real_imag.index({"...", 1});
// magnitude of each complex bin
torch::Tensor spectrum = (real.square() + imag.square()).sqrt();
#endif
if (opts_.return_raw_fft) {
KALDIFEAT_ERR << "return raw fft is not supported yet";
}

View File

@ -1,6 +1,6 @@
package:
name: kaldifeat
version: "1.5.4"
version: "1.6"
source:
path: "{{ environ.get('KALDIFEAT_ROOT_DIR') }}"