mirror of
https://github.com/csukuangfj/kaldifeat.git
synced 2025-08-09 01:52:39 +00:00
Support PyTorch 1.6.0
This commit is contained in:
parent
4a1d08c1fa
commit
0228b6a56d
@ -6,7 +6,7 @@ project(kaldifeat)
|
|||||||
|
|
||||||
# remember to change the version in
|
# remember to change the version in
|
||||||
# scripts/conda/kaldifeat/meta.yaml
|
# scripts/conda/kaldifeat/meta.yaml
|
||||||
set(kaldifeat_VERSION "1.5.4")
|
set(kaldifeat_VERSION "1.6")
|
||||||
|
|
||||||
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
|
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
|
||||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
|
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
|
||||||
|
17
README.md
17
README.md
@ -208,11 +208,26 @@ for more examples.
|
|||||||
|
|
||||||
# Installation
|
# Installation
|
||||||
|
|
||||||
|
## From conda
|
||||||
|
|
||||||
|
Supported versions of Python, PyTorch, and CUDA toolkit are listed below:
|
||||||
|
|
||||||
|
[](/doc/source/images/python-3.6_3.7_3.8-blue.svg)
|
||||||
|
[](/doc/source/images/pytorch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg)
|
||||||
|
[](/doc/source/images/cuda-10.1_10.2_11.0_11.1-orange.svg)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
conda install -c kaldifeat -c pytorch -c conda-forge kaldifeat python=3.8 cudatoolkit=11.1 pytorch=1.8.1
|
||||||
|
```
|
||||||
|
|
||||||
|
You can select the supported Python version, CUDA toolkit version and PyTorch version as you wish.
|
||||||
|
|
||||||
|
|
||||||
## From PyPi with pip
|
## From PyPi with pip
|
||||||
|
|
||||||
You need to install PyTorch and CMake first.
|
You need to install PyTorch and CMake first.
|
||||||
cmake 3.11 is known to work. Other cmake versions may also work.
|
cmake 3.11 is known to work. Other cmake versions may also work.
|
||||||
PyTorch 1.7.1 and 1.8.1 are known to work. Other PyTorch versions may also work.
|
PyTorch 1.6.1 and above are known to work. Other PyTorch versions may also work.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install -v kaldilm
|
pip install -v kaldilm
|
||||||
|
@ -23,3 +23,15 @@ execute_process(
|
|||||||
)
|
)
|
||||||
|
|
||||||
message(STATUS "PyTorch version: ${TORCH_VERSION}")
|
message(STATUS "PyTorch version: ${TORCH_VERSION}")
|
||||||
|
|
||||||
|
execute_process(
|
||||||
|
COMMAND "${PYTHON_EXECUTABLE}" -c "import torch; print(torch.__version__.split('.')[0])"
|
||||||
|
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||||
|
OUTPUT_VARIABLE KALDIFEAT_TORCH_VERSION_MAJOR
|
||||||
|
)
|
||||||
|
|
||||||
|
execute_process(
|
||||||
|
COMMAND "${PYTHON_EXECUTABLE}" -c "import torch; print(torch.__version__.split('.')[1])"
|
||||||
|
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||||
|
OUTPUT_VARIABLE KALDIFEAT_TORCH_VERSION_MINOR
|
||||||
|
)
|
||||||
|
@ -14,6 +14,9 @@ set(kaldifeat_srcs
|
|||||||
add_library(kaldifeat_core SHARED ${kaldifeat_srcs})
|
add_library(kaldifeat_core SHARED ${kaldifeat_srcs})
|
||||||
target_link_libraries(kaldifeat_core PUBLIC ${TORCH_LIBRARIES})
|
target_link_libraries(kaldifeat_core PUBLIC ${TORCH_LIBRARIES})
|
||||||
|
|
||||||
|
target_compile_definitions(kaldifeat_core PUBLIC KALDIFEAT_TORCH_VERSION_MAJOR=${KALDIFEAT_TORCH_VERSION_MAJOR})
|
||||||
|
target_compile_definitions(kaldifeat_core PUBLIC KALDIFEAT_TORCH_VERSION_MINOR=${KALDIFEAT_TORCH_VERSION_MINOR})
|
||||||
|
|
||||||
add_executable(test_kaldifeat test_kaldifeat.cc)
|
add_executable(test_kaldifeat test_kaldifeat.cc)
|
||||||
target_link_libraries(test_kaldifeat PRIVATE kaldifeat_core)
|
target_link_libraries(test_kaldifeat PRIVATE kaldifeat_core)
|
||||||
|
|
||||||
|
@ -7,7 +7,23 @@
|
|||||||
#ifndef KALDIFEAT_CSRC_FEATURE_COMMON_H_
|
#ifndef KALDIFEAT_CSRC_FEATURE_COMMON_H_
|
||||||
#define KALDIFEAT_CSRC_FEATURE_COMMON_H_
|
#define KALDIFEAT_CSRC_FEATURE_COMMON_H_
|
||||||
|
|
||||||
|
#include "kaldifeat/csrc/feature-functions.h"
|
||||||
#include "kaldifeat/csrc/feature-window.h"
|
#include "kaldifeat/csrc/feature-window.h"
|
||||||
|
// See "The torch.fft module in PyTorch 1.7"
|
||||||
|
// https://github.com/pytorch/pytorch/wiki/The-torch.fft-module-in-PyTorch-1.7
|
||||||
|
#if KALDIFEAT_TORCH_VERSION_MAJOR > 1 || \
|
||||||
|
(KALDIFEAT_TORCH_VERSION_MAJOR == 1 && KALDIFEAT_TORCH_VERSION_MINOR > 6)
|
||||||
|
#include "torch/fft.h"
|
||||||
|
#define KALDIFEAT_HAS_FFT_NAMESPACE
|
||||||
|
// It uses torch::fft::rfft
|
||||||
|
// Its input shape is [x, N], output shape is [x, N/2]
|
||||||
|
// which is a complex tensor
|
||||||
|
#else
|
||||||
|
#include "ATen/Functions.h"
|
||||||
|
// It uses torch::fft
|
||||||
|
// Its input shape is [x, N], output shape is [x, N/2, 2]
|
||||||
|
// which contains the real part [..., ], and imaginary part [..., 1]
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace kaldifeat {
|
namespace kaldifeat {
|
||||||
|
|
||||||
|
@ -8,7 +8,6 @@
|
|||||||
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
|
||||||
#include "torch/fft.h"
|
|
||||||
#include "torch/torch.h"
|
#include "torch/torch.h"
|
||||||
|
|
||||||
namespace kaldifeat {
|
namespace kaldifeat {
|
||||||
@ -66,7 +65,20 @@ torch::Tensor FbankComputer::Compute(torch::Tensor signal_raw_log_energy,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// note spectrum is in magnitude, not power, because of `abs()`
|
// note spectrum is in magnitude, not power, because of `abs()`
|
||||||
|
#if defined(KALDIFEAT_HAS_FFT_NAMESPACE)
|
||||||
|
// signal_frame shape: [x, 512]
|
||||||
|
// spectrum shape [x, 257
|
||||||
torch::Tensor spectrum = torch::fft::rfft(signal_frame).abs();
|
torch::Tensor spectrum = torch::fft::rfft(signal_frame).abs();
|
||||||
|
#else
|
||||||
|
// signal_frame shape [x, 512]
|
||||||
|
// real_imag shape [x, 257, 2],
|
||||||
|
// where [..., 0] is the real part
|
||||||
|
// [..., 1] is the imaginary part
|
||||||
|
torch::Tensor real_imag = torch::rfft(signal_frame, 1);
|
||||||
|
torch::Tensor real = real_imag.index({"...", 0});
|
||||||
|
torch::Tensor imag = real_imag.index({"...", 1});
|
||||||
|
torch::Tensor spectrum = (real.square() + imag.square()).sqrt();
|
||||||
|
#endif
|
||||||
|
|
||||||
// remove the last column, i.e., the highest fft bin
|
// remove the last column, i.e., the highest fft bin
|
||||||
spectrum = spectrum.index(
|
spectrum = spectrum.index(
|
||||||
|
@ -7,7 +7,6 @@
|
|||||||
#include "kaldifeat/csrc/feature-mfcc.h"
|
#include "kaldifeat/csrc/feature-mfcc.h"
|
||||||
|
|
||||||
#include "kaldifeat/csrc/matrix-functions.h"
|
#include "kaldifeat/csrc/matrix-functions.h"
|
||||||
#include "torch/fft.h"
|
|
||||||
|
|
||||||
namespace kaldifeat {
|
namespace kaldifeat {
|
||||||
|
|
||||||
@ -92,7 +91,20 @@ torch::Tensor MfccComputer::Compute(torch::Tensor signal_raw_log_energy,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// note spectrum is in magnitude, not power, because of `abs()`
|
// note spectrum is in magnitude, not power, because of `abs()`
|
||||||
|
#if defined(KALDIFEAT_HAS_FFT_NAMESPACE)
|
||||||
|
// signal_frame shape: [x, 512]
|
||||||
|
// spectrum shape [x, 257
|
||||||
torch::Tensor spectrum = torch::fft::rfft(signal_frame).abs();
|
torch::Tensor spectrum = torch::fft::rfft(signal_frame).abs();
|
||||||
|
#else
|
||||||
|
// signal_frame shape [x, 512]
|
||||||
|
// real_imag shape [x, 257, 2],
|
||||||
|
// where [..., 0] is the real part
|
||||||
|
// [..., 1] is the imaginary part
|
||||||
|
torch::Tensor real_imag = torch::rfft(signal_frame, 1);
|
||||||
|
torch::Tensor real = real_imag.index({"...", 0});
|
||||||
|
torch::Tensor imag = real_imag.index({"...", 1});
|
||||||
|
torch::Tensor spectrum = (real.square() + imag.square()).sqrt();
|
||||||
|
#endif
|
||||||
|
|
||||||
// remove the last column, i.e., the highest fft bin
|
// remove the last column, i.e., the highest fft bin
|
||||||
spectrum = spectrum.index(
|
spectrum = spectrum.index(
|
||||||
|
@ -7,7 +7,6 @@
|
|||||||
#include "kaldifeat/csrc/feature-plp.h"
|
#include "kaldifeat/csrc/feature-plp.h"
|
||||||
|
|
||||||
#include "kaldifeat/csrc/feature-functions.h"
|
#include "kaldifeat/csrc/feature-functions.h"
|
||||||
#include "torch/fft.h"
|
|
||||||
#include "torch/torch.h"
|
#include "torch/torch.h"
|
||||||
|
|
||||||
namespace kaldifeat {
|
namespace kaldifeat {
|
||||||
@ -98,7 +97,20 @@ torch::Tensor PlpComputer::Compute(torch::Tensor signal_raw_log_energy,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// note spectrum is in magnitude, not power, because of `abs()`
|
// note spectrum is in magnitude, not power, because of `abs()`
|
||||||
|
#if defined(KALDIFEAT_HAS_FFT_NAMESPACE)
|
||||||
|
// signal_frame shape: [x, 512]
|
||||||
|
// spectrum shape [x, 257
|
||||||
torch::Tensor spectrum = torch::fft::rfft(signal_frame).abs();
|
torch::Tensor spectrum = torch::fft::rfft(signal_frame).abs();
|
||||||
|
#else
|
||||||
|
// signal_frame shape [x, 512]
|
||||||
|
// real_imag shape [x, 257, 2],
|
||||||
|
// where [..., 0] is the real part
|
||||||
|
// [..., 1] is the imaginary part
|
||||||
|
torch::Tensor real_imag = torch::rfft(signal_frame, 1);
|
||||||
|
torch::Tensor real = real_imag.index({"...", 0});
|
||||||
|
torch::Tensor imag = real_imag.index({"...", 1});
|
||||||
|
torch::Tensor spectrum = (real.square() + imag.square()).sqrt();
|
||||||
|
#endif
|
||||||
|
|
||||||
// remove the last column, i.e., the highest fft bin
|
// remove the last column, i.e., the highest fft bin
|
||||||
spectrum = spectrum.index(
|
spectrum = spectrum.index(
|
||||||
|
@ -6,8 +6,6 @@
|
|||||||
|
|
||||||
#include "kaldifeat/csrc/feature-spectrogram.h"
|
#include "kaldifeat/csrc/feature-spectrogram.h"
|
||||||
|
|
||||||
#include "torch/fft.h"
|
|
||||||
|
|
||||||
namespace kaldifeat {
|
namespace kaldifeat {
|
||||||
|
|
||||||
std::ostream &operator<<(std::ostream &os, const SpectrogramOptions &opts) {
|
std::ostream &operator<<(std::ostream &os, const SpectrogramOptions &opts) {
|
||||||
@ -38,7 +36,21 @@ torch::Tensor SpectrogramComputer::Compute(torch::Tensor signal_raw_log_energy,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// note spectrum is in magnitude, not power, because of `abs()`
|
// note spectrum is in magnitude, not power, because of `abs()`
|
||||||
|
#if defined(KALDIFEAT_HAS_FFT_NAMESPACE)
|
||||||
|
// signal_frame shape: [x, 512]
|
||||||
|
// spectrum shape [x, 257
|
||||||
torch::Tensor spectrum = torch::fft::rfft(signal_frame).abs();
|
torch::Tensor spectrum = torch::fft::rfft(signal_frame).abs();
|
||||||
|
#else
|
||||||
|
// signal_frame shape [x, 512]
|
||||||
|
// real_imag shape [x, 257, 2],
|
||||||
|
// where [..., 0] is the real part
|
||||||
|
// [..., 1] is the imaginary part
|
||||||
|
torch::Tensor real_imag = torch::rfft(signal_frame, 1);
|
||||||
|
torch::Tensor real = real_imag.index({"...", 0});
|
||||||
|
torch::Tensor imag = real_imag.index({"...", 1});
|
||||||
|
torch::Tensor spectrum = (real.square() + imag.square()).sqrt();
|
||||||
|
#endif
|
||||||
|
|
||||||
if (opts_.return_raw_fft) {
|
if (opts_.return_raw_fft) {
|
||||||
KALDIFEAT_ERR << "return raw fft is not supported yet";
|
KALDIFEAT_ERR << "return raw fft is not supported yet";
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
package:
|
package:
|
||||||
name: kaldifeat
|
name: kaldifeat
|
||||||
version: "1.5.4"
|
version: "1.6"
|
||||||
|
|
||||||
source:
|
source:
|
||||||
path: "{{ environ.get('KALDIFEAT_ROOT_DIR') }}"
|
path: "{{ environ.get('KALDIFEAT_ROOT_DIR') }}"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user