mirror of
https://github.com/csukuangfj/kaldifeat.git
synced 2025-08-09 01:52:39 +00:00
Support PyTorch 1.6.0
This commit is contained in:
parent
4a1d08c1fa
commit
0228b6a56d
@ -6,7 +6,7 @@ project(kaldifeat)
|
||||
|
||||
# remember to change the version in
|
||||
# scripts/conda/kaldifeat/meta.yaml
|
||||
set(kaldifeat_VERSION "1.5.4")
|
||||
set(kaldifeat_VERSION "1.6")
|
||||
|
||||
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
|
||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
|
||||
|
17
README.md
17
README.md
@ -208,11 +208,26 @@ for more examples.
|
||||
|
||||
# Installation
|
||||
|
||||
## From conda
|
||||
|
||||
Supported versions of Python, PyTorch, and CUDA toolkit are listed below:
|
||||
|
||||
[](/doc/source/images/python-3.6_3.7_3.8-blue.svg)
|
||||
[](/doc/source/images/pytorch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg)
|
||||
[](/doc/source/images/cuda-10.1_10.2_11.0_11.1-orange.svg)
|
||||
|
||||
```bash
|
||||
conda install -c kaldifeat -c pytorch -c conda-forge kaldifeat python=3.8 cudatoolkit=11.1 pytorch=1.8.1
|
||||
```
|
||||
|
||||
You can select the supported Python version, CUDA toolkit version and PyTorch version as you wish.
|
||||
|
||||
|
||||
## From PyPi with pip
|
||||
|
||||
You need to install PyTorch and CMake first.
|
||||
cmake 3.11 is known to work. Other cmake versions may also work.
|
||||
PyTorch 1.7.1 and 1.8.1 are known to work. Other PyTorch versions may also work.
|
||||
PyTorch 1.6.1 and above are known to work. Other PyTorch versions may also work.
|
||||
|
||||
```bash
|
||||
pip install -v kaldilm
|
||||
|
@ -23,3 +23,15 @@ execute_process(
|
||||
)
|
||||
|
||||
message(STATUS "PyTorch version: ${TORCH_VERSION}")
|
||||
|
||||
execute_process(
|
||||
COMMAND "${PYTHON_EXECUTABLE}" -c "import torch; print(torch.__version__.split('.')[0])"
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
OUTPUT_VARIABLE KALDIFEAT_TORCH_VERSION_MAJOR
|
||||
)
|
||||
|
||||
execute_process(
|
||||
COMMAND "${PYTHON_EXECUTABLE}" -c "import torch; print(torch.__version__.split('.')[1])"
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
OUTPUT_VARIABLE KALDIFEAT_TORCH_VERSION_MINOR
|
||||
)
|
||||
|
@ -14,6 +14,9 @@ set(kaldifeat_srcs
|
||||
add_library(kaldifeat_core SHARED ${kaldifeat_srcs})
|
||||
target_link_libraries(kaldifeat_core PUBLIC ${TORCH_LIBRARIES})
|
||||
|
||||
target_compile_definitions(kaldifeat_core PUBLIC KALDIFEAT_TORCH_VERSION_MAJOR=${KALDIFEAT_TORCH_VERSION_MAJOR})
|
||||
target_compile_definitions(kaldifeat_core PUBLIC KALDIFEAT_TORCH_VERSION_MINOR=${KALDIFEAT_TORCH_VERSION_MINOR})
|
||||
|
||||
add_executable(test_kaldifeat test_kaldifeat.cc)
|
||||
target_link_libraries(test_kaldifeat PRIVATE kaldifeat_core)
|
||||
|
||||
|
@ -7,7 +7,23 @@
|
||||
#ifndef KALDIFEAT_CSRC_FEATURE_COMMON_H_
|
||||
#define KALDIFEAT_CSRC_FEATURE_COMMON_H_
|
||||
|
||||
#include "kaldifeat/csrc/feature-functions.h"
|
||||
#include "kaldifeat/csrc/feature-window.h"
|
||||
// See "The torch.fft module in PyTorch 1.7"
|
||||
// https://github.com/pytorch/pytorch/wiki/The-torch.fft-module-in-PyTorch-1.7
|
||||
#if KALDIFEAT_TORCH_VERSION_MAJOR > 1 || \
|
||||
(KALDIFEAT_TORCH_VERSION_MAJOR == 1 && KALDIFEAT_TORCH_VERSION_MINOR > 6)
|
||||
#include "torch/fft.h"
|
||||
#define KALDIFEAT_HAS_FFT_NAMESPACE
|
||||
// It uses torch::fft::rfft
|
||||
// Its input shape is [x, N], output shape is [x, N/2]
|
||||
// which is a complex tensor
|
||||
#else
|
||||
#include "ATen/Functions.h"
|
||||
// It uses torch::fft
|
||||
// Its input shape is [x, N], output shape is [x, N/2, 2]
|
||||
// which contains the real part [..., ], and imaginary part [..., 1]
|
||||
#endif
|
||||
|
||||
namespace kaldifeat {
|
||||
|
||||
|
@ -8,7 +8,6 @@
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include "torch/fft.h"
|
||||
#include "torch/torch.h"
|
||||
|
||||
namespace kaldifeat {
|
||||
@ -66,7 +65,20 @@ torch::Tensor FbankComputer::Compute(torch::Tensor signal_raw_log_energy,
|
||||
}
|
||||
|
||||
// note spectrum is in magnitude, not power, because of `abs()`
|
||||
#if defined(KALDIFEAT_HAS_FFT_NAMESPACE)
|
||||
// signal_frame shape: [x, 512]
|
||||
// spectrum shape [x, 257
|
||||
torch::Tensor spectrum = torch::fft::rfft(signal_frame).abs();
|
||||
#else
|
||||
// signal_frame shape [x, 512]
|
||||
// real_imag shape [x, 257, 2],
|
||||
// where [..., 0] is the real part
|
||||
// [..., 1] is the imaginary part
|
||||
torch::Tensor real_imag = torch::rfft(signal_frame, 1);
|
||||
torch::Tensor real = real_imag.index({"...", 0});
|
||||
torch::Tensor imag = real_imag.index({"...", 1});
|
||||
torch::Tensor spectrum = (real.square() + imag.square()).sqrt();
|
||||
#endif
|
||||
|
||||
// remove the last column, i.e., the highest fft bin
|
||||
spectrum = spectrum.index(
|
||||
|
@ -7,7 +7,6 @@
|
||||
#include "kaldifeat/csrc/feature-mfcc.h"
|
||||
|
||||
#include "kaldifeat/csrc/matrix-functions.h"
|
||||
#include "torch/fft.h"
|
||||
|
||||
namespace kaldifeat {
|
||||
|
||||
@ -92,7 +91,20 @@ torch::Tensor MfccComputer::Compute(torch::Tensor signal_raw_log_energy,
|
||||
}
|
||||
|
||||
// note spectrum is in magnitude, not power, because of `abs()`
|
||||
#if defined(KALDIFEAT_HAS_FFT_NAMESPACE)
|
||||
// signal_frame shape: [x, 512]
|
||||
// spectrum shape [x, 257
|
||||
torch::Tensor spectrum = torch::fft::rfft(signal_frame).abs();
|
||||
#else
|
||||
// signal_frame shape [x, 512]
|
||||
// real_imag shape [x, 257, 2],
|
||||
// where [..., 0] is the real part
|
||||
// [..., 1] is the imaginary part
|
||||
torch::Tensor real_imag = torch::rfft(signal_frame, 1);
|
||||
torch::Tensor real = real_imag.index({"...", 0});
|
||||
torch::Tensor imag = real_imag.index({"...", 1});
|
||||
torch::Tensor spectrum = (real.square() + imag.square()).sqrt();
|
||||
#endif
|
||||
|
||||
// remove the last column, i.e., the highest fft bin
|
||||
spectrum = spectrum.index(
|
||||
|
@ -7,7 +7,6 @@
|
||||
#include "kaldifeat/csrc/feature-plp.h"
|
||||
|
||||
#include "kaldifeat/csrc/feature-functions.h"
|
||||
#include "torch/fft.h"
|
||||
#include "torch/torch.h"
|
||||
|
||||
namespace kaldifeat {
|
||||
@ -98,7 +97,20 @@ torch::Tensor PlpComputer::Compute(torch::Tensor signal_raw_log_energy,
|
||||
}
|
||||
|
||||
// note spectrum is in magnitude, not power, because of `abs()`
|
||||
#if defined(KALDIFEAT_HAS_FFT_NAMESPACE)
|
||||
// signal_frame shape: [x, 512]
|
||||
// spectrum shape [x, 257
|
||||
torch::Tensor spectrum = torch::fft::rfft(signal_frame).abs();
|
||||
#else
|
||||
// signal_frame shape [x, 512]
|
||||
// real_imag shape [x, 257, 2],
|
||||
// where [..., 0] is the real part
|
||||
// [..., 1] is the imaginary part
|
||||
torch::Tensor real_imag = torch::rfft(signal_frame, 1);
|
||||
torch::Tensor real = real_imag.index({"...", 0});
|
||||
torch::Tensor imag = real_imag.index({"...", 1});
|
||||
torch::Tensor spectrum = (real.square() + imag.square()).sqrt();
|
||||
#endif
|
||||
|
||||
// remove the last column, i.e., the highest fft bin
|
||||
spectrum = spectrum.index(
|
||||
|
@ -6,8 +6,6 @@
|
||||
|
||||
#include "kaldifeat/csrc/feature-spectrogram.h"
|
||||
|
||||
#include "torch/fft.h"
|
||||
|
||||
namespace kaldifeat {
|
||||
|
||||
std::ostream &operator<<(std::ostream &os, const SpectrogramOptions &opts) {
|
||||
@ -38,7 +36,21 @@ torch::Tensor SpectrogramComputer::Compute(torch::Tensor signal_raw_log_energy,
|
||||
}
|
||||
|
||||
// note spectrum is in magnitude, not power, because of `abs()`
|
||||
#if defined(KALDIFEAT_HAS_FFT_NAMESPACE)
|
||||
// signal_frame shape: [x, 512]
|
||||
// spectrum shape [x, 257
|
||||
torch::Tensor spectrum = torch::fft::rfft(signal_frame).abs();
|
||||
#else
|
||||
// signal_frame shape [x, 512]
|
||||
// real_imag shape [x, 257, 2],
|
||||
// where [..., 0] is the real part
|
||||
// [..., 1] is the imaginary part
|
||||
torch::Tensor real_imag = torch::rfft(signal_frame, 1);
|
||||
torch::Tensor real = real_imag.index({"...", 0});
|
||||
torch::Tensor imag = real_imag.index({"...", 1});
|
||||
torch::Tensor spectrum = (real.square() + imag.square()).sqrt();
|
||||
#endif
|
||||
|
||||
if (opts_.return_raw_fft) {
|
||||
KALDIFEAT_ERR << "return raw fft is not supported yet";
|
||||
}
|
||||
|
@ -1,6 +1,6 @@
|
||||
package:
|
||||
name: kaldifeat
|
||||
version: "1.5.4"
|
||||
version: "1.6"
|
||||
|
||||
source:
|
||||
path: "{{ environ.get('KALDIFEAT_ROOT_DIR') }}"
|
||||
|
Loading…
x
Reference in New Issue
Block a user