Compare commits

...

No commits in common. "v0.0.1" and "master" have entirely different histories.

168 changed files with 19062 additions and 84 deletions

13
.flake8 Normal file
View File

@@ -0,0 +1,13 @@
[flake8]
max-line-length = 80
exclude =
.git,
doc,
build,
build_release,
cmake/cmake_extension.py,
kaldifeat/python/kaldifeat/__init__.py
ignore =
E402

81
.github/workflows/build-doc.yml vendored Normal file
View File

@@ -0,0 +1,81 @@
# Copyright 2022 Xiaomi Corp. (author: Fangjun Kuang)
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# refer to https://github.com/actions/starter-workflows/pull/47/files
# You can access it at https://csukuangfj.github.io/kaldifeat
name: Generate doc
on:
push:
branches:
- master
- doc
workflow_dispatch:
jobs:
build-doc:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
python-version: [3.8]
steps:
# refer to https://github.com/actions/checkout
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Display Python version
run: python -c "import sys; print(sys.version)"
- name: Update wheels
shell: bash
run: |
export KALDIFEAT_DIR=$PWD
ls -lh $KALDIFEAT_DIR
export GIT_LFS_SKIP_SMUDGE=1
export GIT_CLONE_PROTECTION_ACTIVE=false
git clone https://huggingface.co/csukuangfj/kaldifeat huggingface
cd huggingface
./run.sh
- name: Build doc
shell: bash
run: |
cd doc
git status
python3 -m pip install -r ./requirements.txt
make html
cp source/cpu.html build/html/
cp source/cuda.html build/html/
cp source/cpu-cn.html build/html/
cp source/cuda-cn.html build/html/
touch build/html/.nojekyll
- name: Deploy
uses: peaceiris/actions-gh-pages@v3
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
publish_dir: ./doc/build/html
publish_branch: gh-pages

121
.github/workflows/macos-cpu-wheels.yml vendored Normal file
View File

@@ -0,0 +1,121 @@
name: build-wheels-cpu-macos
on:
push:
branches:
# - wheel
- torch-2.8.0
tags:
- '*'
workflow_dispatch:
concurrency:
group: build-wheels-cpu-macos-${{ github.ref }}
cancel-in-progress: true
jobs:
generate_build_matrix:
# see https://github.com/pytorch/pytorch/pull/50633
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Generating build matrix
id: set-matrix
run: |
# outputting for debugging purposes
# python ./scripts/github_actions/generate_build_matrix.py --for-macos
# MATRIX=$(python ./scripts/github_actions/generate_build_matrix.py --for-macos)
python ./scripts/github_actions/generate_build_matrix.py --for-macos --test-only-latest-torch
MATRIX=$(python ./scripts/github_actions/generate_build_matrix.py --for-macos --test-only-latest-torch)
echo "::set-output name=matrix::${MATRIX}"
build_wheels_macos_cpu:
needs: generate_build_matrix
name: ${{ matrix.torch }} ${{ matrix.python-version }}
runs-on: macos-14
strategy:
fail-fast: false
matrix:
${{ fromJson(needs.generate_build_matrix.outputs.matrix) }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
shell: bash
run: |
pip install -q torch==${{ matrix.torch}} cmake numpy wheel>=0.40.0 twine setuptools
- name: Build wheel
shell: bash
run: |
python3 setup.py bdist_wheel
mkdir wheelhouse
cp -v dist/* wheelhouse
- name: Display wheels (before fix)
shell: bash
run: |
ls -lh ./wheelhouse/
- name: Fix wheel platform tag
run: |
# See https://github.com/glencoesoftware/zeroc-ice-py-macos-x86_64/pull/3/files
# See:
# * https://github.com/pypa/wheel/issues/406
python -m wheel tags \
--platform-tag=macosx_11_0_arm64 \
--remove wheelhouse/*.whl
- name: Display wheels (after fix)
shell: bash
run: |
ls -lh ./wheelhouse/
- name: Upload Wheel
uses: actions/upload-artifact@v4
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-macos-latest-cpu
path: wheelhouse/*.whl
# https://huggingface.co/docs/hub/spaces-github-actions
- name: Publish to huggingface
if: github.repository_owner == 'csukuangfj'
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
uses: nick-fields/retry@v2
with:
max_attempts: 20
timeout_seconds: 200
shell: bash
command: |
git config --global user.email "csukuangfj@gmail.com"
git config --global user.name "Fangjun Kuang"
rm -rf huggingface
export GIT_LFS_SKIP_SMUDGE=1
git clone https://huggingface.co/csukuangfj/kaldifeat huggingface
cd huggingface
git pull
d=cpu/1.25.5.dev20241029/macos
mkdir -p $d
cp -v ../wheelhouse/*.whl ./$d
git status
git lfs track "*.whl"
git add .
git commit -m "upload macos wheel for torch ${{ matrix.torch }} python ${{ matrix.python-version }}"
git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/kaldifeat main

View File

@@ -1,26 +1,46 @@
# Copyright 2021 Xiaomi Corp. (author: Fangjun Kuang)
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
name: Publish to PyPI
on:
push:
tags:
- '*'
workflow_dispatch:
jobs:
pypi:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Python
uses: actions/setup-python@v2
with:
python-version: 3.6
python-version: 3.8
- name: Install Python dependencies
shell: bash
run: |
python3 -m pip install --upgrade pip
python3 -m pip install wheel twine setuptools
python3 -m pip install torch==1.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
- name: Build
shell: bash

View File

@@ -0,0 +1,85 @@
# Copyright 2021 Xiaomi Corp. (author: Fangjun Kuang)
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
name: Run tests macos cpu
on:
push:
branches:
- master
pull_request:
branches:
- master
jobs:
generate_build_matrix:
# see https://github.com/pytorch/pytorch/pull/50633
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Generating build matrix
id: set-matrix
run: |
# outputting for debugging purposes
python scripts/github_actions/generate_build_matrix.py --test-only-latest-torch
MATRIX=$(python scripts/github_actions/generate_build_matrix.py --test-only-latest-torch)
echo "::set-output name=matrix::${MATRIX}"
run_tests_macos_cpu:
needs: generate_build_matrix
runs-on: macos-latest
strategy:
fail-fast: false
matrix:
${{ fromJson(needs.generate_build_matrix.outputs.matrix) }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Python
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install PyTorch ${{ matrix.torch }}
shell: bash
run: |
python3 -m pip install -qq --upgrade pip
python3 -m pip install -qq wheel twine typing_extensions soundfile numpy
python3 -m pip install -qq torch==${{ matrix.torch }} -f https://download.pytorch.org/whl/torch_stable.html || python3 -m pip install -qq torch==${{ matrix.torch }} -f https://download.pytorch.org/whl/torch/
python3 -c "import torch; print('torch version:', torch.__version__)"
- name: Build
shell: bash
run: |
mkdir build_release
cd build_release
cmake -DCMAKE_CXX_STANDARD=17 ..
make VERBOSE=1 -j3
- name: Run tests
shell: bash
run: |
cd build_release
ctest --output-on-failure

View File

@@ -0,0 +1,88 @@
# Copyright 2021 Xiaomi Corp. (author: Fangjun Kuang)
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
name: Run tests ubuntu cpu
on:
push:
branches:
- master
pull_request:
branches:
- master
jobs:
generate_build_matrix:
# see https://github.com/pytorch/pytorch/pull/50633
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Generating build matrix
id: set-matrix
run: |
# outputting for debugging purposes
python scripts/github_actions/generate_build_matrix.py --test-only-latest-torch
MATRIX=$(python scripts/github_actions/generate_build_matrix.py --test-only-latest-torch)
echo "::set-output name=matrix::${MATRIX}"
run_tests_ubuntu_cpu:
needs: generate_build_matrix
runs-on: ubuntu-18.04
strategy:
fail-fast: false
matrix:
${{ fromJson(needs.generate_build_matrix.outputs.matrix) }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Python
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install PyTorch ${{ matrix.torch }}
shell: bash
run: |
sudo apt-get update
sudo apt-get install -y libsndfile1-dev libsndfile1 ffmpeg
python3 -m pip install --upgrade pip
python3 -m pip install wheel twine typing_extensions soundfile
python3 -m pip install bs4 requests tqdm numpy
python3 -m pip install -qq torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/torch_stable.html || python3 -m pip install -qq torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/torch/
python3 -c "import torch; print('torch version:', torch.__version__)"
- name: Build
shell: bash
run: |
mkdir build_release
cd build_release
cmake -DCMAKE_CXX_STANDARD=17 ..
make VERBOSE=1 -j3
- name: Run tests
shell: bash
run: |
cd build_release
ctest --output-on-failure

View File

@@ -0,0 +1,112 @@
# Copyright 2021 Xiaomi Corp. (author: Fangjun Kuang)
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
name: Run tests ubuntu cuda
on:
push:
branches:
- master
pull_request:
branches:
- master
jobs:
generate_build_matrix:
# see https://github.com/pytorch/pytorch/pull/50633
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Generating build matrix
id: set-matrix
run: |
# outputting for debugging purposes
python scripts/github_actions/generate_build_matrix.py --enable-cuda --test-only-latest-torch
MATRIX=$(python scripts/github_actions/generate_build_matrix.py --enable-cuda --test-only-latest-torch)
echo "::set-output name=matrix::${MATRIX}"
run_tests_ubuntu_cuda:
needs: generate_build_matrix
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
${{ fromJson(needs.generate_build_matrix.outputs.matrix) }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Python
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install CUDA Toolkit ${{ matrix.cuda }}
shell: bash
env:
cuda: ${{ matrix.cuda }}
run: |
source ./scripts/github_actions/install_cuda.sh
echo "CUDA_HOME=${CUDA_HOME}" >> $GITHUB_ENV
echo "${CUDA_HOME}/bin" >> $GITHUB_PATH
echo "LD_LIBRARY_PATH=${CUDA_HOME}/lib:${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}" >> $GITHUB_ENV
- name: Display NVCC version
run: |
which nvcc
nvcc --version
- name: Install PyTorch ${{ matrix.torch }}
env:
cuda: ${{ matrix.cuda }}
torch: ${{ matrix.torch }}
shell: bash
run: |
sudo apt-get update
sudo apt-get install -y libsndfile1-dev libsndfile1 ffmpeg
python3 -m pip install --upgrade pip
python3 -m pip install wheel twine typing_extensions soundfile
python3 -m pip install bs4 requests tqdm numpy
./scripts/github_actions/install_torch.sh
python3 -c "import torch; print('torch version:', torch.__version__)"
- name: Download cudnn 8.0
env:
cuda: ${{ matrix.cuda }}
run: |
./scripts/github_actions/install_cudnn.sh
- name: Build
shell: bash
run: |
mkdir build_release
cd build_release
cmake -DCMAKE_CXX_STANDARD=17 ..
make VERBOSE=1 -j3
- name: Run tests
shell: bash
run: |
cd build_release
ctest --output-on-failure

View File

@@ -0,0 +1,121 @@
# Copyright 2021 Xiaomi Corp. (author: Fangjun Kuang)
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
name: Run tests windows cpu
on:
push:
branches:
- master
pull_request:
branches:
- master
jobs:
generate_build_matrix:
# see https://github.com/pytorch/pytorch/pull/50633
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Generating build matrix
id: set-matrix
run: |
# outputting for debugging purposes
python scripts/github_actions/generate_build_matrix.py --test-only-latest-torch
MATRIX=$(python scripts/github_actions/generate_build_matrix.py --test-only-latest-torch)
echo "::set-output name=matrix::${MATRIX}"
run_tests_windows_cpu:
# see https://github.com/actions/virtual-environments/blob/win19/20210525.0/images/win/Windows2019-Readme.md
needs: generate_build_matrix
runs-on: windows-latest
strategy:
fail-fast: false
matrix:
${{ fromJson(needs.generate_build_matrix.outputs.matrix) }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
# see https://github.com/microsoft/setup-msbuild
- name: Add msbuild to PATH
uses: microsoft/setup-msbuild@v1.0.2
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Display Python version
run: python -c "import sys; print(sys.version)"
- name: Install PyTorch ${{ matrix.torch }}
run: |
pip3 install -qq torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/torch_stable.html || pip3 install -qq torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/torch/
pip3 install -qq wheel twine dataclasses numpy typing_extensions soundfile
- name: Display CMake version
run: |
cmake --version
cmake --help
- name: Configure CMake
shell: bash
run: |
mkdir build_release
cd build_release
cmake -DCMAKE_BUILD_TYPE=$BUILD_TYPE ..
ls -lh
- name: Build kaldifeat
run: |
cd build_release
cmake --build . --target _kaldifeat --config Release
- name: Display generated files
shell: bash
run: |
cd build_release
ls -lh lib/*/*
- name: Build wheel
shell: bash
run: |
python3 setup.py bdist_wheel
ls -lh dist/
pip install ./dist/*.whl
python3 -c "import kaldifeat; print(kaldifeat.__version__)"
- name: Upload Wheel
uses: actions/upload-artifact@v4
with:
name: python-${{ matrix.python-version }}-${{ matrix.os }}-cpu
path: dist/*.whl
- name: Build tests
shell: bash
run: |
cd build_release
cmake --build . --target ALL_BUILD --config Release
ls -lh bin/*/*
ctest -C Release --verbose --output-on-failure

View File

@@ -0,0 +1,173 @@
# Copyright 2021 Xiaomi Corp. (author: Fangjun Kuang)
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: Run tests windows cuda
on:
push:
branches:
- master
pull_request:
branches:
- master
jobs:
generate_build_matrix:
# see https://github.com/pytorch/pytorch/pull/50633
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Generating build matrix
id: set-matrix
run: |
# outputting for debugging purposes
python scripts/github_actions/generate_build_matrix.py --enable-cuda --for-windows --test-only-latest-torch
MATRIX=$(python scripts/github_actions/generate_build_matrix.py --enable-cuda --for-windows --test-only-latest-torch)
echo "::set-output name=matrix::${MATRIX}"
run_tests_windows_cuda:
needs: generate_build_matrix
runs-on: windows-latest
strategy:
fail-fast: false
matrix:
${{ fromJson(needs.generate_build_matrix.outputs.matrix) }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
# see https://github.com/microsoft/setup-msbuild
- name: Add msbuild to PATH
uses: microsoft/setup-msbuild@v1.0.2
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Display Python version
run: python -c "import sys; print(sys.version)"
# See https://github.com/Jimver/cuda-toolkit/blob/master/src/links/windows-links.ts
# for available CUDA versions
- uses: Jimver/cuda-toolkit@v0.2.7
id: cuda-toolkit
with:
cuda: ${{ matrix.cuda }}
- name: Display CUDA version
shell: bash
run: |
echo "Installed cuda version is: ${{ steps.cuda-toolkit.outputs.cuda }}"
echo "Cuda install location: ${{steps.cuda-toolkit.outputs.CUDA_PATH}}"
nvcc --version
- name: Remove CUDA installation package
shell: bash
run: |
rm "C:/hostedtoolcache/windows/cuda_installer-windows/${{ matrix.cuda }}/x64/cuda_installer_${{ matrix.cuda }}.exe"
- name: Download cuDNN
shell: bash
run: |
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/cudnn-for-windows
cd cudnn-for-windows
git lfs pull --include="cudnn-windows-x86_64-8.4.1.50_cuda11.6-archive.zip"
unzip cudnn-windows-x86_64-8.4.1.50_cuda11.6-archive.zip
rm cudnn-windows-x86_64-8.4.1.50_cuda11.6-archive.zip
ls -lh *
ls -lh */*
echo "PWD: $PWD"
- name: Install PyTorch ${{ matrix.torch }}
shell: bash
run: |
version=${{ matrix.cuda }}
major=${version:0:2}
minor=${version:3:1}
v=${major}${minor}
if [ ${v} -eq 102 ]; then v=""; else v="+cu${v}"; fi
python3 -m pip install -qq --upgrade pip
python3 -m pip install -qq wheel twine numpy typing_extensions
python3 -m pip install -qq dataclasses soundfile numpy
python3 -m pip install -qq torch==${{ matrix.torch }}${v} -f https://download.pytorch.org/whl/torch_stable.html numpy || python3 -m pip install -qq torch==${{ matrix.torch }}${v} -f https://download.pytorch.org/whl/torch/ numpy
python3 -c "import torch; print('torch version:', torch.__version__)"
python3 -m torch.utils.collect_env
- name: Display CMake version
run: |
cmake --version
cmake --help
- name: Configure CMake
shell: bash
run: |
echo "PWD: $PWD"
ls -lh
mkdir build_release
cd build_release
cmake -DCMAKE_BUILD_TYPE=$BUILD_TYPE -DCUDNN_INCLUDE_PATH=d:/a/kaldifeat/kaldifeat/cudnn-for-windows/cudnn-windows-x86_64-8.4.1.50_cuda11.6-archive/include -DCUDNN_LIBRARY_PATH=d:/a/kaldifeat/kaldifeat/cudnn-for-windows/cudnn-windows-x86_64-8.4.1.50_cuda11.6-archive/lib/cudnn.lib ..
ls -lh
- name: Build kaldifeat
shell: bash
run: |
cd build_release
cmake --build . --target _kaldifeat --config Release
- name: Display generated files
shell: bash
run: |
cd build_release
ls -lh lib/*/*
- name: Build wheel
shell: bash
run: |
echo $PWD
ls -lh ./*
export KALDIFEAT_CMAKE_ARGS="-DCMAKE_BUILD_TYPE=$BUILD_TYPE -DCUDNN_INCLUDE_PATH=d:/a/kaldifeat/kaldifeat/cudnn-for-windows/cudnn-windows-x86_64-8.4.1.50_cuda11.6-archive/include -DCUDNN_LIBRARY_PATH=d:/a/kaldifeat/kaldifeat/cudnn-for-windows/cudnn-windows-x86_64-8.4.1.50_cuda11.6-archive/lib/cudnn.lib"
python3 setup.py bdist_wheel
ls -lh dist/
pip install ./dist/*.whl
python3 -c "import kaldifeat; print(kaldifeat.__version__)"
- name: Upload Wheel
uses: actions/upload-artifact@v4
with:
name: python-${{ matrix.python-version }}-${{ matrix.os }}-cuda-${{ matrix.cuda }}
path: dist/*.whl
- name: Build tests
shell: bash
run: |
cd build_release
cmake --build . --target ALL_BUILD --config Release
ls -lh bin/*/*
ctest -C Release --verbose --output-on-failure

64
.github/workflows/style_check.yml vendored Normal file
View File

@@ -0,0 +1,64 @@
# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: style_check
on:
push:
branches:
- master
pull_request:
branches:
- master
jobs:
style_check:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest, macos-latest]
python-version: ["3.8"]
fail-fast: false
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python-version }}
- name: Install Python dependencies
run: |
python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2 click==8.0.4
# See https://github.com/psf/black/issues/2964
# The version of click should be selected from 8.0.0, 8.0.1, 8.0.2, 8.0.3, and 8.0.4
- name: Run flake8
shell: bash
working-directory: ${{github.workspace}}
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --show-source --statistics
flake8 .
- name: Run black
shell: bash
working-directory: ${{github.workspace}}
run: |
black --check --diff .

67
.github/workflows/test-wheels.yml vendored Normal file
View File

@@ -0,0 +1,67 @@
name: Test pre-compiled wheels
on:
workflow_dispatch:
inputs:
torch_version:
description: "torch version, e.g., 2.0.1"
required: true
kaldifeat_version:
description: "kaldifeat version, e.g., 1.25.0.dev20230726"
required: true
jobs:
Test_pre_compiled_wheels:
name: ${{ matrix.os }} ${{ github.event.inputs.torch_version }} ${{ github.event.inputs.kaldifeat_version }} ${{ matrix.python-version }}
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
python-version: ["3.8", "3.9", "3.10"]
steps:
# refer to https://github.com/actions/checkout
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Display Python version
run: python -c "import sys; print(sys.version)"
- name: Install dependencies
shell: bash
run: |
pip install numpy
- name: Install torch
if: startsWith(matrix.os, 'macos')
shell: bash
run: |
pip install torch==${{ github.event.inputs.torch_version }}
- name: Install torch
if: startsWith(matrix.os, 'ubuntu') || startsWith(matrix.os, 'windows')
shell: bash
run: |
pip install torch==${{ github.event.inputs.torch_version }}+cpu -f https://download.pytorch.org/whl/torch_stable.html || pip install torch==${{ github.event.inputs.torch_version }}+cpu -f https://download.pytorch.org/whl/torch/
- name: Install kaldifeat
shell: bash
run: |
pip install kaldifeat==${{ github.event.inputs.kaldifeat_version }}+cpu.torch${{ github.event.inputs.torch_version }} -f https://csukuangfj.github.io/kaldifeat/cpu.html
- name: Run tests
shell: bash
run: |
cd kaldifeat/python/tests
python3 -c "import kaldifeat; print(kaldifeat.__file__)"
python3 -c "import kaldifeat; print(kaldifeat.__version__)"
python3 ./test_fbank_options.py
python3 ./test_mfcc_options.py

View File

@@ -0,0 +1,168 @@
name: build-wheels-cpu-arm64-ubuntu
on:
push:
branches:
# - wheel
- torch-2.8.0
tags:
- '*'
workflow_dispatch:
concurrency:
group: build-wheels-cpu-arm64-ubuntu-${{ github.ref }}
cancel-in-progress: true
jobs:
generate_build_matrix:
# see https://github.com/pytorch/pytorch/pull/50633
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Generating build matrix
id: set-matrix
run: |
# outputting for debugging purposes
# python ./scripts/github_actions/generate_build_matrix.py --for-arm64
# MATRIX=$(python ./scripts/github_actions/generate_build_matrix.py --for-arm64)
python ./scripts/github_actions/generate_build_matrix.py --test-only-latest-torch --for-arm64
MATRIX=$(python ./scripts/github_actions/generate_build_matrix.py --test-only-latest-torch --for-arm64)
echo "::set-output name=matrix::${MATRIX}"
build-manylinux-wheels:
needs: generate_build_matrix
name: ${{ matrix.torch }} ${{ matrix.python-version }}
runs-on: ubuntu-22.04-arm
strategy:
fail-fast: false
matrix:
${{ fromJson(needs.generate_build_matrix.outputs.matrix) }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
# see https://github.com/pytorch/test-infra/blob/9e3d392690719fac85bad0c9b67f530e48375ca1/tools/scripts/generate_binary_build_matrix.py
# https://github.com/pytorch/builder/tree/main/manywheel
# https://github.com/pytorch/builder/pull/476
# https://github.com/k2-fsa/k2/issues/733
# https://github.com/pytorch/pytorch/pull/50633 (generate build matrix)
- name: Run the build process with Docker
uses: addnab/docker-run-action@v3
with:
image: ${{ matrix.image }}
options: -v ${{ github.workspace }}:/var/www -e IS_2_28=${{ matrix.is_2_28 }} -e PYTHON_VERSION=${{ matrix.python-version }} -e TORCH_VERSION=${{ matrix.torch }}
run: |
echo "pwd: $PWD"
uname -a
id
cat /etc/*release
gcc --version
python3 --version
which python3
ls -lh /opt/python/
echo "---"
ls -lh /opt/python/cp*
ls -lh /opt/python/*/bin
echo "---"
find /opt/python/cp* -name "libpython*"
echo "-----"
find /opt/_internal/cp* -name "libpython*"
echo "-----"
find / -name "libpython*"
echo "----"
ls -lh /usr/lib64/libpython3.so
# cp36-cp36m
# cp37-cp37m
# cp38-cp38
# cp39-cp39
# cp310-cp310
# cp311-cp311
# cp312-cp312
# cp313-cp313
# cp313-cp313t (no gil)
if [[ $PYTHON_VERSION == "3.6" ]]; then
python_dir=/opt/python/cp36-cp36m
export PYTHONPATH=/opt/python/cp36-cp36m/lib/python3.6/site-packages:$PYTHONPATH
elif [[ $PYTHON_VERSION == "3.7" ]]; then
python_dir=/opt/python/cp37-cp37m
export PYTHONPATH=/opt/python/cp37-cp37m/lib/python3.7/site-packages:$PYTHONPATH
elif [[ $PYTHON_VERSION == "3.8" ]]; then
python_dir=/opt/python/cp38-cp38
export PYTHONPATH=/opt/python/cp38-cp38/lib/python3.8/site-packages:$PYTHONPATH
elif [[ $PYTHON_VERSION == "3.9" ]]; then
python_dir=/opt/python/cp39-cp39
export PYTHONPATH=/opt/python/cp39-cp39/lib/python3.9/site-packages:$PYTHONPATH
elif [[ $PYTHON_VERSION == "3.10" ]]; then
python_dir=/opt/python/cp310-cp310
export PYTHONPATH=/opt/python/cp310-cp310/lib/python3.10/site-packages:$PYTHONPATH
elif [[ $PYTHON_VERSION == "3.11" ]]; then
python_dir=/opt/python/cp311-cp311
export PYTHONPATH=/opt/python/cp311-cp311/lib/python3.11/site-packages:$PYTHONPATH
elif [[ $PYTHON_VERSION == "3.12" ]]; then
python_dir=/opt/python/cp312-cp312
export PYTHONPATH=/opt/python/cp312-cp312/lib/python3.12/site-packages:$PYTHONPATH
elif [[ $PYTHON_VERSION == "3.13" ]]; then
python_dir=/opt/python/cp313-cp313
export PYTHONPATH=/opt/python/cp313-cp313/lib/python3.13/site-packages:$PYTHONPATH
else
echo "Unsupported Python version $PYTHON_VERSION"
exit 1
fi
export PYTHON_INSTALL_DIR=$python_dir
export PATH=$PYTHON_INSTALL_DIR/bin:$PATH
python3 --version
which python3
/var/www/scripts/github_actions/build-ubuntu-cpu-arm64.sh
- name: Display wheels
shell: bash
run: |
ls -lh ./wheelhouse/
# https://huggingface.co/docs/hub/spaces-github-actions
- name: Publish to huggingface
if: github.repository_owner == 'csukuangfj'
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
uses: nick-fields/retry@v2
with:
max_attempts: 20
timeout_seconds: 200
shell: bash
command: |
git config --global user.email "csukuangfj@gmail.com"
git config --global user.name "Fangjun Kuang"
rm -rf huggingface
export GIT_LFS_SKIP_SMUDGE=1
git clone https://huggingface.co/csukuangfj/kaldifeat huggingface
cd huggingface
git pull
d=cpu/1.25.5.dev20250307/linux-arm64
mkdir -p $d
cp -v ../wheelhouse/*.whl ./$d
git status
git lfs track "*.whl"
git add .
git commit -m "upload ubuntu-arm64-cpu wheel for torch ${{ matrix.torch }} python ${{ matrix.python-version }}"
git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/kaldifeat main

168
.github/workflows/ubuntu-cpu-wheels.yml vendored Normal file
View File

@@ -0,0 +1,168 @@
name: build-wheels-cpu-ubuntu
on:
push:
branches:
# - wheel
- torch-2.8.0
tags:
- '*'
workflow_dispatch:
concurrency:
group: build-wheels-cpu-ubuntu-${{ github.ref }}
cancel-in-progress: true
jobs:
generate_build_matrix:
# see https://github.com/pytorch/pytorch/pull/50633
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Generating build matrix
id: set-matrix
run: |
# outputting for debugging purposes
# python ./scripts/github_actions/generate_build_matrix.py
# MATRIX=$(python ./scripts/github_actions/generate_build_matrix.py)
python ./scripts/github_actions/generate_build_matrix.py --test-only-latest-torch
MATRIX=$(python ./scripts/github_actions/generate_build_matrix.py --test-only-latest-torch)
echo "::set-output name=matrix::${MATRIX}"
build-manylinux-wheels:
needs: generate_build_matrix
name: ${{ matrix.torch }} ${{ matrix.python-version }}
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
${{ fromJson(needs.generate_build_matrix.outputs.matrix) }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
# see https://github.com/pytorch/test-infra/blob/9e3d392690719fac85bad0c9b67f530e48375ca1/tools/scripts/generate_binary_build_matrix.py
# https://github.com/pytorch/builder/tree/main/manywheel
# https://github.com/pytorch/builder/pull/476
# https://github.com/k2-fsa/k2/issues/733
# https://github.com/pytorch/pytorch/pull/50633 (generate build matrix)
# Build the wheel inside a manylinux container. The workspace is mounted
# at /var/www; the Python/torch versions come in via environment variables.
- name: Run the build process with Docker
  uses: addnab/docker-run-action@v3
  with:
    image: ${{ matrix.image }}
    options: -v ${{ github.workspace }}:/var/www -e IS_2_28=${{ matrix.is_2_28 }} -e PYTHON_VERSION=${{ matrix.python-version }} -e TORCH_VERSION=${{ matrix.torch }}
    run: |
      echo "pwd: $PWD"
      uname -a
      id
      cat /etc/*release
      gcc --version
      python3 --version
      which python3

      ls -lh /opt/python/
      echo "---"
      ls -lh /opt/python/cp*
      ls -lh /opt/python/*/bin
      echo "---"

      find /opt/python/cp* -name "libpython*"
      echo "-----"
      find /opt/_internal/cp* -name "libpython*"
      echo "-----"
      find / -name "libpython*"
      echo "----"
      ls -lh /usr/lib64/libpython3.so || true

      # Map PYTHON_VERSION to the manylinux interpreter directory
      # /opt/python/<abi-tag>. CPython <= 3.7 ABI tags carry a trailing
      # "m" (cp36-cp36m, cp37-cp37m); 3.8+ do not (cp38-cp38 ... cp313-cp313).
      v=$(echo $PYTHON_VERSION | tr -d .)
      case $PYTHON_VERSION in
        3.6|3.7)
          python_dir=/opt/python/cp${v}-cp${v}m
          ;;
        3.8|3.9|3.10|3.11|3.12|3.13)
          python_dir=/opt/python/cp${v}-cp${v}
          ;;
        *)
          echo "Unsupported Python version $PYTHON_VERSION"
          exit 1
          ;;
      esac
      export PYTHONPATH=$python_dir/lib/python$PYTHON_VERSION/site-packages:$PYTHONPATH

      export PYTHON_INSTALL_DIR=$python_dir
      export PATH=$PYTHON_INSTALL_DIR/bin:$PATH

      python3 --version
      which python3

      /var/www/scripts/github_actions/build-ubuntu-cpu.sh
# List the wheels produced by the Docker build step (written to
# ./wheelhouse via the mounted workspace).
- name: Display wheels
  shell: bash
  run: |
    ls -lh ./wheelhouse/

# https://huggingface.co/docs/hub/spaces-github-actions
# NOTE(review): the push is retried up to 20 times — presumably to
# tolerate transient failures when multiple matrix jobs push to the same
# huggingface repo concurrently; confirm.
- name: Publish to huggingface
  if: github.repository_owner == 'csukuangfj'
  env:
    HF_TOKEN: ${{ secrets.HF_TOKEN }}
  uses: nick-fields/retry@v2
  with:
    max_attempts: 20
    timeout_seconds: 200
    shell: bash
    command: |
      git config --global user.email "csukuangfj@gmail.com"
      git config --global user.name "Fangjun Kuang"
      rm -rf huggingface
      export GIT_LFS_SKIP_SMUDGE=1
      git clone https://huggingface.co/csukuangfj/kaldifeat huggingface
      cd huggingface
      git pull
      d=cpu/1.25.5.dev20250307/linux-x64
      mkdir -p $d
      cp -v ../wheelhouse/*.whl ./$d
      git status
      git lfs track "*.whl"
      git add .
      git commit -m "upload ubuntu-cpu wheel for torch ${{ matrix.torch }} python ${{ matrix.python-version }}"
      git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/kaldifeat main

194
.github/workflows/ubuntu-cuda-wheels.yml vendored Normal file
View File

@ -0,0 +1,194 @@
name: build-wheels-cuda-ubuntu
on:
push:
branches:
- wheel
# - torch-2.7.1
tags:
- '*'
workflow_dispatch:
concurrency:
group: build-wheels-cuda-ubuntu-${{ github.ref }}
cancel-in-progress: true
jobs:
generate_build_matrix:
# see https://github.com/pytorch/pytorch/pull/50633
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Generating build matrix
  id: set-matrix
  run: |
    # outputting for debugging purposes
    # python ./scripts/github_actions/generate_build_matrix.py --enable-cuda
    # MATRIX=$(python ./scripts/github_actions/generate_build_matrix.py --enable-cuda)
    python ./scripts/github_actions/generate_build_matrix.py --enable-cuda --test-only-latest-torch
    MATRIX=$(python ./scripts/github_actions/generate_build_matrix.py --enable-cuda --test-only-latest-torch)
    # "::set-output" is deprecated and disabled on current GitHub runners;
    # step outputs must be written to the $GITHUB_OUTPUT file instead.
    echo "matrix=${MATRIX}" >> "$GITHUB_OUTPUT"
build-manylinux-wheels:
needs: generate_build_matrix
name: ${{ matrix.torch }} ${{ matrix.python-version }} cuda${{ matrix.cuda }}
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
${{ fromJson(needs.generate_build_matrix.outputs.matrix) }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
# Remove the runner's pre-installed toolcache to free disk space before
# pulling the (large) CUDA build image.
- name: Free space
  shell: bash
  run: |
    df -h
    rm -rf /opt/hostedtoolcache
    df -h
    echo "pwd: $PWD"
    echo "github.workspace ${{ github.workspace }}"
# see https://github.com/pytorch/test-infra/blob/9e3d392690719fac85bad0c9b67f530e48375ca1/tools/scripts/generate_binary_build_matrix.py
# https://github.com/pytorch/builder/tree/main/manywheel
# https://github.com/pytorch/builder/pull/476
# https://github.com/k2-fsa/k2/issues/733
# https://github.com/pytorch/pytorch/pull/50633 (generate build matrix)
# Build the CUDA wheel inside a manylinux container. The workspace is
# mounted at /var/www; Python/torch/CUDA versions come in via env vars.
- name: Run the build process with Docker
  uses: addnab/docker-run-action@v3
  with:
    image: ${{ matrix.image }}
    options: -v ${{ github.workspace }}:/var/www -e IS_2_28=${{ matrix.is_2_28 }} -e PYTHON_VERSION=${{ matrix.python-version }} -e TORCH_VERSION=${{ matrix.torch }} -e CUDA_VERSION=${{ matrix.cuda }}
    run: |
      echo "pwd: $PWD"
      uname -a
      id
      cat /etc/*release
      gcc --version
      python3 --version
      which python3

      ls -lh /opt/python/
      echo "---"
      ls -lh /opt/python/cp*
      ls -lh /opt/python/*/bin
      echo "---"

      find /opt/python/cp* -name "libpython*"
      echo "-----"
      find /opt/_internal/cp* -name "libpython*"
      echo "-----"
      find / -name "libpython*"

      # Map PYTHON_VERSION to the manylinux interpreter directory
      # /opt/python/<abi-tag>. CPython <= 3.7 ABI tags carry a trailing
      # "m" (cp36-cp36m, cp37-cp37m); 3.8+ do not (cp38-cp38 ... cp313-cp313).
      v=$(echo $PYTHON_VERSION | tr -d .)
      case $PYTHON_VERSION in
        3.6|3.7)
          python_dir=/opt/python/cp${v}-cp${v}m
          ;;
        3.8|3.9|3.10|3.11|3.12|3.13)
          python_dir=/opt/python/cp${v}-cp${v}
          ;;
        *)
          echo "Unsupported Python version $PYTHON_VERSION"
          exit 1
          ;;
      esac
      export PYTHONPATH=$python_dir/lib/python$PYTHON_VERSION/site-packages:$PYTHONPATH

      export PYTHON_INSTALL_DIR=$python_dir
      export PATH=$PYTHON_INSTALL_DIR/bin:$PATH

      # There are no libpython.so inside $PYTHON_INSTALL_DIR
      # since they are statically linked.
      python3 --version
      which python3

      # Point /usr/local/cuda at the CUDA toolkit version requested by
      # the matrix entry.
      pushd /usr/local
      rm cuda
      ln -s cuda-$CUDA_VERSION cuda
      popd
      which nvcc
      nvcc --version

      cp /var/www/scripts/github_actions/install_torch.sh .
      chmod +x install_torch.sh

      /var/www/scripts/github_actions/build-ubuntu-cuda.sh
# List the wheels produced by the Docker build step.
- name: Display wheels
  shell: bash
  run: |
    ls -lh ./wheelhouse/

# Artifact upload is currently disabled (`if: false`); the wheels are
# published to huggingface below instead.
- name: Upload Wheel
  if: false
  uses: actions/upload-artifact@v4
  with:
    name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cuda-is_2_28-${{ matrix.is_2_28 }}
    path: wheelhouse/*.whl

# https://huggingface.co/docs/hub/spaces-github-actions
# NOTE(review): the push is retried up to 20 times — presumably to
# tolerate transient failures when multiple matrix jobs push to the same
# huggingface repo concurrently; confirm.
- name: Publish to huggingface
  if: github.repository_owner == 'csukuangfj'
  env:
    HF_TOKEN: ${{ secrets.HF_TOKEN }}
  uses: nick-fields/retry@v2
  with:
    max_attempts: 20
    timeout_seconds: 200
    shell: bash
    command: |
      git config --global user.email "csukuangfj@gmail.com"
      git config --global user.name "Fangjun Kuang"
      rm -rf huggingface
      export GIT_LFS_SKIP_SMUDGE=1
      git clone https://huggingface.co/csukuangfj/kaldifeat huggingface
      cd huggingface
      git pull
      d=cuda/1.25.5.dev20241029/linux
      mkdir -p $d
      cp -v ../wheelhouse/*.whl ./$d
      git status
      git lfs track "*.whl"
      git add .
      git commit -m "upload ubuntu-cuda wheel for torch ${{ matrix.torch }} python ${{ matrix.python-version }}"
      git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/kaldifeat main

View File

@ -0,0 +1,108 @@
name: build-wheels-cpu-win64
on:
push:
branches:
# - wheel
- torch-2.8.0
tags:
- '*'
workflow_dispatch:
concurrency:
group: build-wheels-cpu-win64-${{ github.ref }}
cancel-in-progress: true
jobs:
generate_build_matrix:
# see https://github.com/pytorch/pytorch/pull/50633
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Generating build matrix
  id: set-matrix
  run: |
    # outputting for debugging purposes
    # python ./scripts/github_actions/generate_build_matrix.py --for-windows
    # MATRIX=$(python ./scripts/github_actions/generate_build_matrix.py --for-windows)
    python ./scripts/github_actions/generate_build_matrix.py --for-windows --test-only-latest-torch
    MATRIX=$(python ./scripts/github_actions/generate_build_matrix.py --for-windows --test-only-latest-torch)
    # "::set-output" is deprecated and disabled on current GitHub runners;
    # step outputs must be written to the $GITHUB_OUTPUT file instead.
    echo "matrix=${MATRIX}" >> "$GITHUB_OUTPUT"
build_wheels_win64_cpu:
needs: generate_build_matrix
name: ${{ matrix.torch }} ${{ matrix.python-version }}
runs-on: windows-latest
strategy:
fail-fast: false
matrix:
${{ fromJson(needs.generate_build_matrix.outputs.matrix) }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
  shell: bash
  run: |
    # "wheel>=0.40.0" must be quoted: unquoted, the shell parses ">=0.40.0"
    # as an output redirection to a file named "=0.40.0", so the version
    # constraint was silently dropped.
    pip install -q torch==${{ matrix.torch}} cmake numpy "wheel>=0.40.0" twine setuptools
    pip install torch==${{ matrix.torch}}+cpu -f https://download.pytorch.org/whl/torch_stable.html cmake numpy || pip install torch==${{ matrix.torch}}+cpu -f https://download.pytorch.org/whl/torch/ cmake numpy
# Build the wheel with setuptools; the CMake build is driven from setup.py.
- name: Build wheel
  shell: bash
  run: |
    python3 setup.py bdist_wheel
    mkdir wheelhouse
    cp -v dist/* wheelhouse

- name: Display wheels
  shell: bash
  run: |
    ls -lh ./wheelhouse/

- name: Upload Wheel
  uses: actions/upload-artifact@v4
  with:
    name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-windows-latest-cpu
    path: wheelhouse/*.whl

# https://huggingface.co/docs/hub/spaces-github-actions
# NOTE(review): the push is retried up to 20 times — presumably to
# tolerate transient failures when multiple matrix jobs push to the same
# huggingface repo concurrently; confirm.
- name: Publish to huggingface
  if: github.repository_owner == 'csukuangfj'
  env:
    HF_TOKEN: ${{ secrets.HF_TOKEN }}
  uses: nick-fields/retry@v2
  with:
    max_attempts: 20
    timeout_seconds: 200
    shell: bash
    command: |
      git config --global user.email "csukuangfj@gmail.com"
      git config --global user.name "Fangjun Kuang"
      rm -rf huggingface
      export GIT_LFS_SKIP_SMUDGE=1
      git clone https://huggingface.co/csukuangfj/kaldifeat huggingface
      cd huggingface
      git pull
      d=cpu/1.25.5.dev20241029/windows
      mkdir -p $d
      cp -v ../wheelhouse/*.whl ./$d
      git status
      git lfs track "*.whl"
      git add .
      git commit -m "upload windows-cpu wheel for torch ${{ matrix.torch }} python ${{ matrix.python-version }}"
      git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/kaldifeat main

7
.gitignore vendored
View File

@ -1,3 +1,10 @@
build/
build*/
*.egg-info*/
dist/
__pycache__/
test-1hour.wav
path.sh
torch_version.py
cpu*.html
cuda*.html

26
.pre-commit-config.yaml Normal file
View File

@ -0,0 +1,26 @@
repos:
- repo: https://github.com/psf/black
rev: 21.6b0
hooks:
- id: black
args: [--line-length=80]
- repo: https://github.com/PyCQA/flake8
rev: 3.9.2
hooks:
- id: flake8
args: [--max-line-length=80]
- repo: https://github.com/pycqa/isort
rev: 5.9.2
hooks:
- id: isort
args: [--profile=black]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.0.1
hooks:
- id: check-executables-have-shebangs
- id: end-of-file-fixer
- id: mixed-line-ending
- id: trailing-whitespace

View File

@ -1,10 +1,16 @@
# Copyright (c) 2021 Xiaomi Corporation (author: Fangjun Kuang)
if (CMAKE_VERSION VERSION_GREATER_EQUAL "4.0.0")
set(CMAKE_POLICY_VERSION_MINIMUM 3.5)
endif()
cmake_minimum_required(VERSION 3.8 FATAL_ERROR)
project(kaldifeat)
set(kaldifeat_VERSION "0.0.1")
# remember to change the version in
# scripts/conda/kaldifeat/meta.yaml
# scripts/conda-cpu/kaldifeat/meta.yaml
set(kaldifeat_VERSION "1.25.5")
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
@ -13,24 +19,102 @@ set(CMAKE_RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin")
set(CMAKE_SKIP_BUILD_RPATH FALSE)
set(BUILD_RPATH_USE_ORIGIN TRUE)
set(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE)
set(CMAKE_INSTALL_RPATH "$ORIGIN")
set(CMAKE_BUILD_RPATH "$ORIGIN")
if(NOT APPLE)
set(kaldifeat_rpath_origin "$ORIGIN")
else()
set(kaldifeat_rpath_origin "@loader_path")
endif()
set(CMAKE_INSTALL_RPATH ${kaldifeat_rpath_origin})
set(CMAKE_BUILD_RPATH ${kaldifeat_rpath_origin})
if(NOT CMAKE_BUILD_TYPE)
message(STATUS "No CMAKE_BUILD_TYPE given, default to Release")
set(CMAKE_BUILD_TYPE Release)
endif()
set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
if (NOT CMAKE_CXX_STANDARD)
set(CMAKE_CXX_STANDARD 17 CACHE STRING "The C++ version to be used.")
endif()
message(STATUS "C++ Standard version: ${CMAKE_CXX_STANDARD}")
set(CMAKE_CXX_EXTENSIONS OFF)
message(STATUS "C++ Standard version: ${CMAKE_CXX_STANDARD}")
list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/Modules)
list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake)
option(BUILD_SHARED_LIBS "Whether to build shared libraries" ON)
option(kaldifeat_BUILD_TESTS "Whether to build tests or not" OFF)
option(kaldifeat_BUILD_PYMODULE "Whether to build python module or not" ON)
include(pybind11)
message(STATUS "BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}")
if(BUILD_SHARED_LIBS AND MSVC)
set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)
endif()
if(kaldifeat_BUILD_PYMODULE)
include(pybind11)
endif()
# to prevent cmake from trying to link with system installed mkl since we not directly use it
# mkl libraries should be linked with pytorch already
# ref: https://github.com/pytorch/pytorch/blob/master/cmake/public/mkl.cmake
set(CMAKE_DISABLE_FIND_PACKAGE_MKL TRUE)
include(torch)
include_directories(${CMAKE_SOURCE_DIR})
if(kaldifeat_BUILD_TESTS)
include(googletest)
enable_testing()
endif()
if(WIN32)
# disable various warnings for MSVC
# 4624: destructor was implicitly defined as deleted because a base class destructor is inaccessible or deleted
set(disabled_warnings
/wd4624
)
message(STATUS "Disabled warnings: ${disabled_warnings}")
foreach(w IN LISTS disabled_warnings)
string(APPEND CMAKE_CXX_FLAGS " ${w} ")
endforeach()
endif()
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
message(STATUS "CMAKE_INSTALL_PREFIX: ${CMAKE_INSTALL_PREFIX}")
add_subdirectory(kaldifeat)
# TORCH_VERSION is defined in cmake/torch.cmake
configure_file(
${PROJECT_SOURCE_DIR}/kaldifeat/python/kaldifeat/torch_version.py.in
${PROJECT_SOURCE_DIR}/kaldifeat/python/kaldifeat/torch_version.py @ONLY
)
configure_file(
${PROJECT_SOURCE_DIR}/cmake/kaldifeatConfigVersion.cmake.in
${PROJECT_BINARY_DIR}/kaldifeatConfigVersion.cmake
@ONLY
)
configure_file(
${PROJECT_SOURCE_DIR}/cmake/kaldifeatConfig.cmake.in
${PROJECT_BINARY_DIR}/kaldifeatConfig.cmake
@ONLY
)
install(FILES
${PROJECT_BINARY_DIR}/kaldifeatConfigVersion.cmake
${PROJECT_BINARY_DIR}/kaldifeatConfig.cmake
DESTINATION share/cmake/kaldifeat
)
install(FILES
${PROJECT_SOURCE_DIR}/kaldifeat/python/kaldifeat/torch_version.py
DESTINATION ./
)

6
MANIFEST.in Normal file
View File

@ -0,0 +1,6 @@
include LICENSE
include README.md
include CMakeLists.txt
exclude pyproject.toml
recursive-include kaldifeat *.*
recursive-include cmake *.*

324
README.md
View File

@ -1,11 +1,329 @@
# kaldifeat
Wrap kaldi's feature computations to Python with PyTorch support.
<div align="center">
<img src="/doc/source/images/os-green.svg">
<img src="/doc/source/images/python_ge_3.6-blue.svg">
<img src="/doc/source/images/pytorch_ge_1.5.0-green.svg">
<img src="/doc/source/images/cuda_ge_10.1-orange.svg">
</div>
[![Documentation Status](https://github.com/csukuangfj/kaldifeat/actions/workflows/build-doc.yml/badge.svg)](https://csukuangfj.github.io/kaldifeat/)
**Documentation**: <https://csukuangfj.github.io/kaldifeat>
**Note**: If you are looking for a version that does not depend on PyTorch,
please see <https://github.com/csukuangfj/kaldi-native-fbank>
# Installation
`kaldifeat` can be installed by
Refer to
<https://csukuangfj.github.io/kaldifeat/installation/from_wheels.html>
for installation.
> Never use `pip install kaldifeat`
> Never use `pip install kaldifeat`
> Never use `pip install kaldifeat`
<sub>
<table>
<tr>
<th>Comments</th>
<th>Options</th>
<th>Feature Computer</th>
<th>Usage</th>
</tr>
<tr>
<td>Fbank for <a href="https://github.com/openai/whisper">Whisper</a></td>
<td><code>kaldifeat.WhisperFbankOptions</code></td>
<td><code>kaldifeat.WhisperFbank</code></td>
<td>
<pre lang="python">
opts = kaldifeat.WhisperFbankOptions()
opts.device = torch.device('cuda', 0)
fbank = kaldifeat.WhisperFbank(opts)
features = fbank(wave)
</pre>
See <a href="https://github.com/csukuangfj/kaldifeat/pull/82">#82</a>
</td>
</tr>
<tr>
<td>Fbank for <a href="https://github.com/openai/whisper">Whisper-V3</a></td>
<td><code>kaldifeat.WhisperFbankOptions</code></td>
<td><code>kaldifeat.WhisperFbank</code></td>
<td>
<pre lang="python">
opts = kaldifeat.WhisperFbankOptions()
opts.num_mels = 128
opts.device = torch.device('cuda', 0)
fbank = kaldifeat.WhisperFbank(opts)
features = fbank(wave)
</pre>
</td>
</tr>
<tr>
<td>FBANK</td>
<td><code>kaldifeat.FbankOptions</code></td>
<td><code>kaldifeat.Fbank</code></td>
<td>
<pre lang="python">
opts = kaldifeat.FbankOptions()
opts.device = torch.device('cuda', 0)
opts.frame_opts.window_type = 'povey'
fbank = kaldifeat.Fbank(opts)
features = fbank(wave)
</pre>
</td>
</tr>
<tr>
<td>Streaming FBANK</td>
<td><code>kaldifeat.FbankOptions</code></td>
<td><code>kaldifeat.OnlineFbank</code></td>
<td>
See <a href="./kaldifeat/python/tests/test_fbank.py">
./kaldifeat/python/tests/test_fbank.py
</a>
</td>
</tr>
<tr>
<td>MFCC</td>
<td><code>kaldifeat.MfccOptions</code></td>
<td><code>kaldifeat.Mfcc</code></td>
<td>
<pre lang="python">
opts = kaldifeat.MfccOptions();
opts.num_ceps = 13
mfcc = kaldifeat.Mfcc(opts)
features = mfcc(wave)
</pre>
</td>
</tr>
<tr>
<td>Streaming MFCC</td>
<td><code>kaldifeat.MfccOptions</code></td>
<td><code>kaldifeat.OnlineMfcc</code></td>
<td>
See <a href="./kaldifeat/python/tests/test_mfcc.py">
./kaldifeat/python/tests/test_mfcc.py
</a>
</td>
</tr>
<tr>
<td>PLP</td>
<td><code>kaldifeat.PlpOptions</code></td>
<td><code>kaldifeat.Plp</code></td>
<td>
<pre lang="python">
opts = kaldifeat.PlpOptions();
opts.mel_opts.num_bins = 23
plp = kaldifeat.Plp(opts)
features = plp(wave)
</pre>
</td>
</tr>
<tr>
<td>Streaming PLP</td>
<td><code>kaldifeat.PlpOptions</code></td>
<td><code>kaldifeat.OnlinePlp</code></td>
<td>
See <a href="./kaldifeat/python/tests/test_plp.py">
./kaldifeat/python/tests/test_plp.py
</a>
</td>
</tr>
<tr>
<td>Spectrogram</td>
<td><code>kaldifeat.SpectrogramOptions</code></td>
<td><code>kaldifeat.Spectrogram</code></td>
<td>
<pre lang="python">
opts = kaldifeat.SpectrogramOptions();
print(opts)
spectrogram = kaldifeat.Spectrogram(opts)
features = spectrogram(wave)
</pre>
</td>
</tr>
</table>
</sub>
Feature extraction compatible with `Kaldi` using PyTorch, supporting
CUDA, batch processing, chunk processing, and autograd.
The following kaldi-compatible commandline tools are implemented:
- `compute-fbank-feats`
- `compute-mfcc-feats`
- `compute-plp-feats`
- `compute-spectrogram-feats`
(**NOTE**: We will implement other types of features, e.g., Pitch, ivector, etc., soon.)
**HINT**: It supports also streaming feature extractors for Fbank, MFCC, and Plp.
# Usage
Let us first generate a test wave using sox:
```bash
pip install kaldifeat
# generate a wave of 1.2 seconds, containing a sine-wave
# swept from 300 Hz to 3300 Hz
sox -n -r 16000 -b 16 test.wav synth 1.2 sine 300-3300
```
**HINT**: Download [test.wav][test_wav].
[test_wav]: kaldifeat/python/tests/test_data/test.wav
## Fbank
```python
import torchaudio
import kaldifeat
filename = "./test.wav"
wave, samp_freq = torchaudio.load(filename)
wave = wave.squeeze()
opts = kaldifeat.FbankOptions()
opts.frame_opts.dither = 0
# Yes, it has the same options as `Kaldi`
fbank = kaldifeat.Fbank(opts)
features = fbank(wave)
```
To compute features that are compatible with `Kaldi`, wave samples have to be
scaled to the range `[-32768, 32768]`. **WARNING**: You don't have to do this if
you don't care about the compatibility with `Kaldi`.
The following is an example:
```python
wave *= 32768
fbank = kaldifeat.Fbank(opts)
features = fbank(wave)
print(features[:3])
```
The output is:
```
tensor([[15.0074, 21.1730, 25.5286, 24.4644, 16.6994, 13.8480, 11.2087, 11.7952,
10.3911, 10.4491, 10.3012, 9.8743, 9.6997, 9.3751, 9.3476, 9.3559,
9.1074, 9.0032, 9.0312, 8.8399, 9.0822, 8.7442, 8.4023],
[13.8785, 20.5647, 25.4956, 24.6966, 16.9541, 13.9163, 11.3364, 11.8449,
10.2565, 10.5871, 10.3484, 9.7474, 9.6123, 9.3964, 9.0695, 9.1177,
8.9136, 8.8425, 8.5920, 8.8315, 8.6226, 8.8605, 8.9763],
[13.9475, 19.9410, 25.4494, 24.9051, 17.0004, 13.9207, 11.6667, 11.8217,
10.3411, 10.7258, 10.0983, 9.8109, 9.6762, 9.4218, 9.1246, 8.7744,
9.0863, 8.7488, 8.4695, 8.6710, 8.7728, 8.7405, 8.9824]])
```
You can compute the fbank feature for the same wave with `Kaldi` using the following commands:
```bash
echo "1 test.wav" > test.scp
compute-fbank-feats --dither=0 scp:test.scp ark,t:test.txt
head -n4 test.txt
```
The output is:
```
1 [
15.00744 21.17303 25.52861 24.46438 16.69938 13.84804 11.2087 11.79517 10.3911 10.44909 10.30123 9.874329 9.699727 9.37509 9.347578 9.355928 9.107419 9.00323 9.031268 8.839916 9.082197 8.744139 8.40221
13.87853 20.56466 25.49562 24.69662 16.9541 13.91633 11.33638 11.84495 10.25656 10.58718 10.34841 9.747416 9.612316 9.39642 9.06955 9.117751 8.913527 8.842571 8.59212 8.831518 8.622513 8.86048 8.976251
13.94753 19.94101 25.4494 24.90511 17.00044 13.92074 11.66673 11.82172 10.34108 10.72575 10.09829 9.810879 9.676199 9.421767 9.124647 8.774353 9.086291 8.74897 8.469534 8.670973 8.772754 8.740549 8.982433
```
You can see that ``kaldifeat`` produces the same output as `Kaldi` (within some tolerance due to numerical precision).
**HINT**: Download [test.scp][test_scp] and [test.txt][test_txt].
[test_scp]: kaldifeat/python/tests/test_data/test.scp
[test_txt]: kaldifeat/python/tests/test_data/test.txt
To use GPU, you can use:
```python
import torch
opts = kaldifeat.FbankOptions()
opts.device = torch.device("cuda", 0)
fbank = kaldifeat.Fbank(opts)
features = fbank(wave.to(opts.device))
```
## MFCC, PLP, Spectrogram
To compute MFCC features, please replace `kaldifeat.FbankOptions` and `kaldifeat.Fbank`
with `kaldifeat.MfccOptions` and `kaldifeat.Mfcc`, respectively. The same goes
for `PLP` and `Spectrogram`.
Please refer to
- [kaldifeat/python/tests/test_fbank.py](kaldifeat/python/tests/test_fbank.py)
- [kaldifeat/python/tests/test_mfcc.py](kaldifeat/python/tests/test_mfcc.py)
- [kaldifeat/python/tests/test_plp.py](kaldifeat/python/tests/test_plp.py)
- [kaldifeat/python/tests/test_spectrogram.py](kaldifeat/python/tests/test_spectrogram.py)
- [kaldifeat/python/tests/test_frame_extraction_options.py](kaldifeat/python/tests/test_frame_extraction_options.py)
- [kaldifeat/python/tests/test_mel_bank_options.py](kaldifeat/python/tests/test_mel_bank_options.py)
- [kaldifeat/python/tests/test_fbank_options.py](kaldifeat/python/tests/test_fbank_options.py)
- [kaldifeat/python/tests/test_mfcc_options.py](kaldifeat/python/tests/test_mfcc_options.py)
- [kaldifeat/python/tests/test_spectrogram_options.py](kaldifeat/python/tests/test_spectrogram_options.py)
- [kaldifeat/python/tests/test_plp_options.py](kaldifeat/python/tests/test_plp_options.py)
for more examples.
**HINT**: In the examples, you can find that
- ``kaldifeat`` supports batch processing as well as chunk processing
- ``kaldifeat`` uses the same options as `Kaldi`'s `compute-fbank-feats` and `compute-mfcc-feats`
# Usage in other projects
## icefall
[icefall](https://github.com/k2-fsa/icefall) uses kaldifeat to extract features for a pre-trained model.
See <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/conformer_ctc/pretrained.py>.
## k2
[k2](https://github.com/k2-fsa/k2) uses kaldifeat's C++ API.
See <https://github.com/k2-fsa/k2/blob/v2.0-pre/k2/torch/csrc/features.cu>.
## lhotse
[lhotse](https://github.com/lhotse-speech/lhotse) uses kaldifeat to extract features on GPU.
See <https://github.com/lhotse-speech/lhotse/blob/master/lhotse/features/kaldifeat.py>.
## sherpa
[sherpa](https://github.com/k2-fsa/sherpa) uses kaldifeat for streaming feature
extraction.
See <https://github.com/k2-fsa/sherpa/blob/master/sherpa/bin/pruned_stateless_emformer_rnnt2/decode.py>

0
cmake/__init__.py Normal file
View File

138
cmake/cmake_extension.py Normal file
View File

@ -0,0 +1,138 @@
# Copyright (c) 2021 Xiaomi Corporation (author: Fangjun Kuang)
import glob
import os
import platform
import shutil
import sys
from pathlib import Path
import setuptools
import torch
from setuptools.command.build_ext import build_ext
def get_pytorch_version():
    """Return ``torch.__version__`` with any local build suffix removed.

    For example, ``1.7.1+cuda101`` becomes ``1.7.1``.
    """
    version, _, _ = torch.__version__.partition("+")
    return version
def is_for_pypi():
    """Return True iff the ``KALDIFEAT_IS_FOR_PYPI`` env variable is set.

    Any value counts, including the empty string; only an unset variable
    yields False.
    """
    return os.environ.get("KALDIFEAT_IS_FOR_PYPI") is not None
def is_macos():
    """Return True when the current platform is macOS (Darwin)."""
    system_name = platform.system()
    return system_name == "Darwin"
def is_windows():
    """Return True when the current platform is Windows."""
    system_name = platform.system()
    return system_name == "Windows"
# ``wheel`` is an optional dependency: when it is importable we subclass
# its bdist_wheel command so the resulting wheel is tagged correctly;
# otherwise we fall back to setuptools' default by leaving the command
# set to None.
try:
    from wheel.bdist_wheel import bdist_wheel as _bdist_wheel

    class bdist_wheel(_bdist_wheel):
        def finalize_options(self):
            _bdist_wheel.finalize_options(self)
            # In this case, the generated wheel has a name in the form
            # kaldifeat-xxx-pyxx-none-any.whl
            if is_for_pypi() and not is_macos():
                self.root_is_pure = True
            else:
                # The generated wheel has a name ending with
                # -linux_x86_64.whl
                self.root_is_pure = False

except ImportError:
    bdist_wheel = None
def cmake_extension(name, *args, **kwargs) -> setuptools.Extension:
    """Create a source-less Extension placeholder.

    The extension has no source files because the actual compilation is
    performed by CMake inside :class:`BuildExtension`.
    """
    kwargs["language"] = "c++"
    return setuptools.Extension(name, [], *args, **kwargs)
class BuildExtension(build_ext):
    """Custom ``build_ext`` that delegates the real build to CMake.

    setuptools calls :meth:`build_extension` for the source-less extension
    produced by ``cmake_extension``; we invoke cmake (and make on
    non-Windows) ourselves and install the artifacts into
    ``self.build_lib`` so they end up inside the wheel.
    """

    def build_extension(self, ext: setuptools.extension.Extension):
        # build/temp.linux-x86_64-3.8
        os.makedirs(self.build_temp, exist_ok=True)

        # build/lib.linux-x86_64-3.8
        os.makedirs(self.build_lib, exist_ok=True)

        # Project root; this file lives one directory below it.
        kaldifeat_dir = Path(__file__).parent.parent.resolve()

        # User-supplied overrides via environment variables.
        cmake_args = os.environ.get("KALDIFEAT_CMAKE_ARGS", "")
        make_args = os.environ.get("KALDIFEAT_MAKE_ARGS", "")
        system_make_args = os.environ.get("MAKEFLAGS", "")

        if cmake_args == "":
            cmake_args = "-DCMAKE_BUILD_TYPE=Release"

        extra_cmake_args = " -Dkaldifeat_BUILD_TESTS=OFF "
        # Install directly into the wheel staging directory.
        extra_cmake_args += f" -DCMAKE_INSTALL_PREFIX={Path(self.build_lib).resolve()}/kaldifeat "  # noqa

        major, minor = get_pytorch_version().split(".")[:2]
        print("major, minor", major, minor)
        major = int(major)
        minor = int(minor)
        # PyTorch >= 2.1 requires building with C++17.
        if major > 2 or (major == 2 and minor >= 1):
            extra_cmake_args += f" -DCMAKE_CXX_STANDARD=17 "

        if "PYTHON_EXECUTABLE" not in cmake_args:
            print(f"Setting PYTHON_EXECUTABLE to {sys.executable}")
            cmake_args += f" -DPYTHON_EXECUTABLE={sys.executable}"

        cmake_args += extra_cmake_args

        if is_windows():
            # On Windows the three steps (configure, build, install) are
            # run as separate os.system() calls; build_cmd is printed only
            # for diagnostics.
            build_cmd = f"""
              cmake {cmake_args} -B {self.build_temp} -S {kaldifeat_dir}
              cmake --build {self.build_temp} --target _kaldifeat --config Release -- -m
              cmake --build {self.build_temp} --target install --config Release -- -m
            """
            print(f"build command is:\n{build_cmd}")

            ret = os.system(
                f"cmake {cmake_args} -B {self.build_temp} -S {kaldifeat_dir}"
            )
            if ret != 0:
                raise Exception("Failed to configure kaldifeat")

            ret = os.system(
                f"cmake --build {self.build_temp} --target _kaldifeat --config Release -- -m"
            )
            if ret != 0:
                raise Exception("Failed to build kaldifeat")

            ret = os.system(
                f"cmake --build {self.build_temp} --target install --config Release -- -m"
            )
            if ret != 0:
                raise Exception("Failed to install kaldifeat")
        else:
            # Default to a modest parallel build when the user gave no
            # make flags at all.
            if make_args == "" and system_make_args == "":
                print("For fast compilation, run:")
                print('export KALDIFEAT_MAKE_ARGS="-j"; python setup.py install')
                make_args = " -j4 "
                print("Setting make_args to '-j4'")

            build_cmd = f"""
                cd {self.build_temp}

                cmake {cmake_args} {kaldifeat_dir}

                make {make_args} _kaldifeat install
            """
            print(f"build command is:\n{build_cmd}")

            ret = os.system(build_cmd)
            if ret != 0:
                raise Exception(
                    "\nBuild kaldifeat failed. Please check the error message.\n"
                    "You can ask for help by creating an issue on GitHub.\n"
                    "\nClick:\n\thttps://github.com/csukuangfj/kaldifeat/issues/new\n"  # noqa
                )

98
cmake/googletest.cmake Normal file
View File

@ -0,0 +1,98 @@
# Copyright 2020 Fangjun Kuang (csukuangfj@gmail.com)
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Download (or locate a pre-downloaded copy of) googletest 1.13.0 and add
# it to the build via FetchContent.
# NOTE(review): the function name is missing an "e" ("googltest"); it is
# kept as-is because the single call site at the bottom of this file must
# match.
function(download_googltest)
  if(CMAKE_VERSION VERSION_LESS 3.11)
    # FetchContent is available since 3.11,
    # we've copied it to ${CMAKE_SOURCE_DIR}/cmake/Modules
    # so that it can be used in lower CMake versions.
    message(STATUS "Use FetchContent provided by kaldifeat")
    list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
  endif()

  include(FetchContent)

  # Primary URL, mirror URL, and the expected archive hash.
  set(googletest_URL  "https://github.com/google/googletest/archive/refs/tags/v1.13.0.tar.gz")
  set(googletest_URL2 "https://huggingface.co/csukuangfj/k2-cmake-deps/resolve/main/googletest-1.13.0.tar.gz")
  set(googletest_HASH "SHA256=ad7fdba11ea011c1d925b3289cf4af2c66a352e18d4c7264392fead75e919363")

  # If you don't have access to the Internet,
  # please pre-download googletest
  set(possible_file_locations
    $ENV{HOME}/Downloads/googletest-1.13.0.tar.gz
    ${PROJECT_SOURCE_DIR}/googletest-1.13.0.tar.gz
    ${PROJECT_BINARY_DIR}/googletest-1.13.0.tar.gz
    /tmp/googletest-1.13.0.tar.gz
    /star-fj/fangjun/download/github/googletest-1.13.0.tar.gz
  )

  # Prefer a local archive when one exists; clearing the mirror URL keeps
  # FetchContent from touching the network.
  foreach(f IN LISTS possible_file_locations)
    if(EXISTS ${f})
      set(googletest_URL  "${f}")
      file(TO_CMAKE_PATH "${googletest_URL}" googletest_URL)
      set(googletest_URL2)
      break()
    endif()
  endforeach()

  # Configure googletest's own build options before it is added.
  set(BUILD_GMOCK ON CACHE BOOL "" FORCE)
  set(INSTALL_GTEST OFF CACHE BOOL "" FORCE)
  set(gtest_disable_pthreads ON CACHE BOOL "" FORCE)
  set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)

  FetchContent_Declare(googletest
    URL
      ${googletest_URL}
      ${googletest_URL2}
    URL_HASH ${googletest_HASH}
  )

  FetchContent_GetProperties(googletest)
  if(NOT googletest_POPULATED)
    message(STATUS "Downloading googletest from ${googletest_URL}")
    FetchContent_Populate(googletest)
  endif()
  message(STATUS "googletest is downloaded to ${googletest_SOURCE_DIR}")
  message(STATUS "googletest's binary dir is ${googletest_BINARY_DIR}")

  if(APPLE)
    set(CMAKE_MACOSX_RPATH ON) # to solve the following warning on macOS
  endif()
  #[==[
  -- Generating done
  Policy CMP0042 is not set: MACOSX_RPATH is enabled by default. Run "cmake
  --help-policy CMP0042" for policy details. Use the cmake_policy command to
  set the policy and suppress this warning.

  MACOSX_RPATH is not specified for the following targets:
    gmock
    gmock_main
    gtest
    gtest_main

  This warning is for project developers. Use -Wno-dev to suppress it.
  ]==]

  add_subdirectory(${googletest_SOURCE_DIR} ${googletest_BINARY_DIR} EXCLUDE_FROM_ALL)

  # Export the gtest/gmock headers to consumers of the gtest target.
  target_include_directories(gtest
    INTERFACE
      ${googletest_SOURCE_DIR}/googletest/include
      ${googletest_SOURCE_DIR}/googlemock/include
  )
endfunction()

download_googltest()

View File

@ -0,0 +1,65 @@
# Findkaldifeat
# -------------
#
# Finds the kaldifeat library
#
# This will define the following variables:
#
# KALDIFEAT_FOUND -- True if the system has the kaldifeat library
# KALDIFEAT_INCLUDE_DIRS -- The include directories for kaldifeat
# KALDIFEAT_LIBRARIES -- Libraries to link against
# KALDIFEAT_CXX_FLAGS -- Additional (required) compiler flags
# KALDIFEAT_TORCH_VERSION_MAJOR -- The major version of PyTorch used to compile kaldifeat
# KALDIFEAT_TORCH_VERSION_MINOR -- The minor version of PyTorch used to compile kaldifeat
# KALDIFEAT_VERSION -- The version of kaldifeat
#
# and the following imported targets:
#
# kaldifeat_core
# This file is modified from pytorch/cmake/TorchConfig.cmake.in
# Values between @...@ are filled in by configure_file() when kaldifeat's
# build generates the installed kaldifeatConfig.cmake.
set(KALDIFEAT_CXX_FLAGS "@CMAKE_CXX_FLAGS@")
set(KALDIFEAT_TORCH_VERSION_MAJOR @KALDIFEAT_TORCH_VERSION_MAJOR@)
set(KALDIFEAT_TORCH_VERSION_MINOR @KALDIFEAT_TORCH_VERSION_MINOR@)
set(KALDIFEAT_VERSION @kaldifeat_VERSION@)
# Allow overriding the install prefix via the environment; otherwise derive
# it from the on-disk location of this config file.
if(DEFINED ENV{KALDIFEAT_INSTALL_PREFIX})
set(KALDIFEAT_INSTALL_PREFIX $ENV{KALDIFEAT_INSTALL_PREFIX})
else()
# Assume we are in <install-prefix>/share/cmake/kaldifeat/kaldifeatConfig.cmake
get_filename_component(CMAKE_CURRENT_LIST_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH)
get_filename_component(KALDIFEAT_INSTALL_PREFIX "${CMAKE_CURRENT_LIST_DIR}/../../../" ABSOLUTE)
endif()
set(KALDIFEAT_INCLUDE_DIRS ${KALDIFEAT_INSTALL_PREFIX}/include)
set(KALDIFEAT_LIBRARIES kaldifeat_core)
# Create one IMPORTED target per library and record where its binary lives.
foreach(lib IN LISTS KALDIFEAT_LIBRARIES)
find_library(location_${lib} ${lib}
PATHS
"${KALDIFEAT_INSTALL_PREFIX}/lib"
"${KALDIFEAT_INSTALL_PREFIX}/lib64"
)
# kaldifeat is distributed as a shared library except under MSVC,
# where a static library is produced.
if(NOT MSVC)
add_library(${lib} SHARED IMPORTED)
else()
add_library(${lib} STATIC IMPORTED)
endif()
set_target_properties(${lib} PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES "${KALDIFEAT_INCLUDE_DIRS}"
IMPORTED_LOCATION "${location_${lib}}"
# NOTE(review): CXX_STANDARD on an IMPORTED target is not propagated to
# consumers (imported targets are never compiled); INTERFACE_COMPILE_FEATURES
# cxx_std_14 may be what was intended -- confirm.
CXX_STANDARD 14
)
set_property(TARGET ${lib} PROPERTY INTERFACE_COMPILE_OPTIONS @CMAKE_CXX_FLAGS@)
endforeach()
# Report KALDIFEAT_FOUND based on whether the core library file was located.
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(kaldifeat DEFAULT_MSG
location_kaldifeat_core
)

View File

@ -0,0 +1,12 @@
# This file is modified from pytorch/cmake/TorchConfigVersion.cmake.in
set(PACKAGE_VERSION "@kaldifeat_VERSION@")
# Decide whether the installed version satisfies the version requested via
# find_package(kaldifeat <version>). Start from "compatible" and fall back
# to "incompatible" only when the installed version is too old.
set(PACKAGE_VERSION_COMPATIBLE TRUE)
if("${PACKAGE_VERSION}" VERSION_LESS "${PACKAGE_FIND_VERSION}")
  set(PACKAGE_VERSION_COMPATIBLE FALSE)
elseif("${PACKAGE_VERSION}" VERSION_EQUAL "${PACKAGE_FIND_VERSION}")
  # An exact match implies "not less", so this branch is only reached
  # when the version is also compatible.
  set(PACKAGE_VERSION_EXACT TRUE)
endif()

View File

@ -8,23 +8,39 @@ function(download_pybind11)
include(FetchContent)
set(pybind11_URL "https://github.com/pybind/pybind11/archive/v2.6.0.tar.gz")
set(pybind11_HASH "SHA256=90b705137b69ee3b5fc655eaca66d0dc9862ea1759226f7ccd3098425ae69571")
set(pybind11_URL "https://github.com/pybind/pybind11/archive/refs/tags/v2.12.0.tar.gz")
set(pybind11_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/pybind11-2.12.0.tar.gz")
set(pybind11_HASH "SHA256=bf8f242abd1abcd375d516a7067490fb71abd79519a282d22b6e4d19282185a7")
# If you don't have access to the Internet,
# please pre-download pybind11
set(possible_file_locations
$ENV{HOME}/Downloads/pybind11-2.12.0.tar.gz
${CMAKE_SOURCE_DIR}/pybind11-2.12.0.tar.gz
${CMAKE_BINARY_DIR}/pybind11-2.12.0.tar.gz
/tmp/pybind11-2.12.0.tar.gz
/star-fj/fangjun/download/github/pybind11-2.12.0.tar.gz
)
foreach(f IN LISTS possible_file_locations)
if(EXISTS ${f})
set(pybind11_URL "${f}")
file(TO_CMAKE_PATH "${pybind11_URL}" pybind11_URL)
set(pybind11_URL2)
break()
endif()
endforeach()
# Helper variables holding characters that are awkward to escape directly
# inside the sed patterns below.
set(double_quotes "\"")
set(dollar "\$")
set(semicolon "\;")
FetchContent_Declare(pybind11
  URL
    ${pybind11_URL}
    ${pybind11_URL2}
  URL_HASH ${pybind11_HASH}
  # NVCC does not understand -flto / -fno-fat-lto-objects; rewrite them as
  # -Xcompiler=... so they are forwarded to the host compiler instead.
  # Fix: the second sed previously referenced ${seimcolon} (a typo for
  # ${semicolon}), which expanded to an empty string and left the pattern
  # unanchored.
  PATCH_COMMAND
    sed -i s/\\${double_quotes}-flto\\\\${dollar}/\\${double_quotes}-Xcompiler=-flto${dollar}/g "tools/pybind11Tools.cmake" &&
    sed -i s/${semicolon}-fno-fat-lto-objects/${semicolon}-Xcompiler=-fno-fat-lto-objects/g "tools/pybind11Tools.cmake"
)
FetchContent_GetProperties(pybind11)
if(NOT pybind11_POPULATED)
message(STATUS "Downloading pybind11")
message(STATUS "Downloading pybind11 from ${pybind11_URL}")
FetchContent_Populate(pybind11)
endif()
message(STATUS "pybind11 is downloaded to ${pybind11_SOURCE_DIR}")

View File

@ -8,6 +8,7 @@ execute_process(
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE TORCH_DIR
)
message(STATUS "TORCH_DIR: ${TORCH_DIR}")
list(APPEND CMAKE_PREFIX_PATH "${TORCH_DIR}")
find_package(Torch REQUIRED)
@ -24,16 +25,14 @@ execute_process(
message(STATUS "PyTorch version: ${TORCH_VERSION}")
# Solve the following error for NVCC:
# unknown option `-Wall`
#
# It contains only some -Wno-* flags, so it is OK
# to set them to empty
set_property(TARGET torch_cuda
PROPERTY
INTERFACE_COMPILE_OPTIONS ""
execute_process(
COMMAND "${PYTHON_EXECUTABLE}" -c "import torch; print(torch.__version__.split('.')[0])"
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE KALDIFEAT_TORCH_VERSION_MAJOR
)
set_property(TARGET torch_cpu
PROPERTY
INTERFACE_COMPILE_OPTIONS ""
execute_process(
COMMAND "${PYTHON_EXECUTABLE}" -c "import torch; print(torch.__version__.split('.')[1])"
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE KALDIFEAT_TORCH_VERSION_MINOR
)

20
doc/Makefile Normal file
View File

@ -0,0 +1,20 @@
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = source
BUILDDIR = build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

35
doc/make.bat Normal file
View File

@ -0,0 +1,35 @@
@ECHO OFF
pushd %~dp0
REM Command file for Sphinx documentation
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build
if "%1" == "" goto help
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.http://sphinx-doc.org/
exit /b 1
)
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
:end
popd

6
doc/requirements.txt Normal file
View File

@ -0,0 +1,6 @@
dataclasses
recommonmark
sphinx<7.0
sphinx-autodoc-typehints
sphinx_rtd_theme
sphinxcontrib-bibtex

View File

136
doc/source/conf.py Normal file
View File

@ -0,0 +1,136 @@
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html
# -- Path setup --------------------------------------------------------------
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
# import os
import re
import sphinx_rtd_theme
# import sys
# sys.path.insert(0, os.path.abspath('.'))
# -- Project information -----------------------------------------------------
project = "kaldifeat"
copyright = "2021, Fangjun Kuang"
author = "Fangjun Kuang"
def get_version():
    """Extract the kaldifeat version from the top-level CMakeLists.txt.

    Returns:
      The version string, e.g. ``"1.25.4"``, with surrounding double
      quotes removed.

    Raises:
      RuntimeError: If no ``set(kaldifeat_VERSION ...)`` line is found,
        instead of the cryptic ``AttributeError`` an unguarded
        ``re.search(...).group(1)`` would raise.
    """
    cmake_file = "../../CMakeLists.txt"
    with open(cmake_file) as f:
        content = f.read()

    match = re.search(r"set\(kaldifeat_VERSION (.*)\)", content)
    if match is None:
        raise RuntimeError(
            f"Could not find kaldifeat_VERSION in {cmake_file}"
        )
    return match.group(1).strip('"')
version = get_version()
release = version
# -- General configuration ---------------------------------------------------
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
"recommonmark",
"sphinx.ext.autodoc",
"sphinx.ext.githubpages",
"sphinx.ext.napoleon",
"sphinx_autodoc_typehints",
"sphinx_rtd_theme",
]
# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ["images/*.md"]
source_suffix = {
".rst": "restructuredtext",
".md": "markdown",
}
master_doc = "index"
# -- Options for HTML output -------------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = "sphinx_rtd_theme"
html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
html_show_sourcelink = True
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_static"]
pygments_style = "sphinx"
numfig = True
# Settings consumed by sphinx_rtd_theme to build the "Edit on GitHub" link.
html_context = {
    "display_github": True,
    "github_user": "csukuangfj",
    "github_repo": "kaldifeat",
    "github_version": "master",
    # Path from the repository root to the directory containing conf.py.
    # The docs live in "doc/source" (see doc/Makefile), not "docs/source",
    # and the repo name must not be repeated here -- it already comes from
    # "github_repo" above.
    "conf_py_path": "/doc/source/",
}
# refer to
# https://sphinx-rtd-theme.readthedocs.io/en/latest/configuring.html
html_theme_options = {
"logo_only": False,
"display_version": True,
"prev_next_buttons_location": "bottom",
"style_external_links": True,
}
rst_epilog = """
.. _kaldifeat: https://github.com/csukuangfj/kaldifeat
.. _Kaldi: https://github.com/kaldi-asr/kaldi
.. _PyTorch: https://pytorch.org/
.. _kaldifeat.Fbank: https://github.com/csukuangfj/kaldifeat/blob/master/kaldifeat/python/kaldifeat/fbank.py#L10
.. _kaldifeat.Mfcc: https://github.com/csukuangfj/kaldifeat/blob/master/kaldifeat/python/kaldifeat/mfcc.py#L10
.. _kaldifeat.Plp: https://github.com/csukuangfj/kaldifeat/blob/master/kaldifeat/python/kaldifeat/plp.py#L10
.. _kaldifeat.Spectrogram: https://github.com/csukuangfj/kaldifeat/blob/master/kaldifeat/python/kaldifeat/spectrogram.py#L9
.. _kaldifeat.OnlineFbank: https://github.com/csukuangfj/kaldifeat/blob/master/kaldifeat/python/kaldifeat/fbank.py#L16
.. _kaldifeat.OnlineMfcc: https://github.com/csukuangfj/kaldifeat/blob/master/kaldifeat/python/kaldifeat/mfcc.py#L16
.. _kaldifeat.OnlinePlp: https://github.com/csukuangfj/kaldifeat/blob/master/kaldifeat/python/kaldifeat/plp.py#L16
.. _compute-fbank-feats: https://github.com/kaldi-asr/kaldi/blob/master/src/featbin/compute-fbank-feats.cc
.. _compute-mfcc-feats: https://github.com/kaldi-asr/kaldi/blob/master/src/featbin/compute-mfcc-feats.cc
.. _compute-plp-feats: https://github.com/kaldi-asr/kaldi/blob/master/src/featbin/compute-plp-feats.cc
.. _compute-spectrogram-feats: https://github.com/kaldi-asr/kaldi/blob/master/src/featbin/compute-spectrogram-feats.cc
.. _kaldi::OnlineFbank: https://github.com/kaldi-asr/kaldi/blob/master/src/feat/online-feature.h#L160
.. _kaldi::OnlineMfcc: https://github.com/kaldi-asr/kaldi/blob/master/src/feat/online-feature.h#L158
.. _kaldi::OnlinePlp: https://github.com/kaldi-asr/kaldi/blob/master/src/feat/online-feature.h#L159
.. _kaldifeat.FbankOptions: https://github.com/csukuangfj/kaldifeat/blob/master/kaldifeat/csrc/feature-fbank.h#L19
.. _kaldi::FbankOptions: https://github.com/kaldi-asr/kaldi/blob/master/src/feat/feature-fbank.h#L41
.. _kaldifeat.MfccOptions: https://github.com/csukuangfj/kaldifeat/blob/master/kaldifeat/csrc/feature-mfcc.h#L22
.. _kaldi::MfccOptions: https://github.com/kaldi-asr/kaldi/blob/master/src/feat/feature-mfcc.h#L38
.. _kaldifeat.PlpOptions: https://github.com/csukuangfj/kaldifeat/blob/master/kaldifeat/csrc/feature-plp.h#L24
.. _kaldi::PlpOptions: https://github.com/kaldi-asr/kaldi/blob/master/src/feat/feature-plp.h#L42
.. _kaldifeat.SpectrogramOptions: https://github.com/csukuangfj/kaldifeat/blob/master/kaldifeat/csrc/feature-spectrogram.h#L18
.. _kaldi::SpectrogramOptions: https://github.com/kaldi-asr/kaldi/blob/master/src/feat/feature-spectrogram.h#L38
.. _kaldifeat.FrameExtractionOptions: https://github.com/csukuangfj/kaldifeat/blob/master/kaldifeat/csrc/feature-window.h#L30
.. _kaldi::FrameExtractionOptions: https://github.com/kaldi-asr/kaldi/blob/master/src/feat/feature-window.h#L35
.. _kaldifeat.MelBanksOptions: https://github.com/csukuangfj/kaldifeat/blob/master/kaldifeat/csrc/mel-computations.h#L17
.. _kaldi::MelBanksOptions: https://github.com/kaldi-asr/kaldi/blob/master/src/feat/mel-computations.h#L43
"""

View File

@ -0,0 +1,8 @@
## File description
<https://shields.io/> is used to create the following files:
- ./os.svg
- ./python_ge_3.6-blue.svg
- ./cuda_ge_10.1-orange.svg
- ./pytorch_ge_1.5.0-green.svg

View File

@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" width="94" height="20" role="img" aria-label="cuda: &gt;= 10.1"><title>cuda: &gt;= 10.1</title><linearGradient id="s" x2="0" y2="100%"><stop offset="0" stop-color="#bbb" stop-opacity=".1"/><stop offset="1" stop-opacity=".1"/></linearGradient><clipPath id="r"><rect width="94" height="20" rx="3" fill="#fff"/></clipPath><g clip-path="url(#r)"><rect width="37" height="20" fill="#555"/><rect x="37" width="57" height="20" fill="#fe7d37"/><rect width="94" height="20" fill="url(#s)"/></g><g fill="#fff" text-anchor="middle" font-family="Verdana,Geneva,DejaVu Sans,sans-serif" text-rendering="geometricPrecision" font-size="110"><text aria-hidden="true" x="195" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="270">cuda</text><text x="195" y="140" transform="scale(.1)" fill="#fff" textLength="270">cuda</text><text aria-hidden="true" x="645" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="470">&gt;= 10.1</text><text x="645" y="140" transform="scale(.1)" fill="#fff" textLength="470">&gt;= 10.1</text></g></svg>

After

Width:  |  Height:  |  Size: 1.1 KiB

View File

@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" width="176" height="20" role="img" aria-label="os: Linux | macOS | Windows"><title>os: Linux | macOS | Windows</title><linearGradient id="s" x2="0" y2="100%"><stop offset="0" stop-color="#bbb" stop-opacity=".1"/><stop offset="1" stop-opacity=".1"/></linearGradient><clipPath id="r"><rect width="176" height="20" rx="3" fill="#fff"/></clipPath><g clip-path="url(#r)"><rect width="23" height="20" fill="#555"/><rect x="23" width="153" height="20" fill="#97ca00"/><rect width="176" height="20" fill="url(#s)"/></g><g fill="#fff" text-anchor="middle" font-family="Verdana,Geneva,DejaVu Sans,sans-serif" text-rendering="geometricPrecision" font-size="110"><text aria-hidden="true" x="125" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="130">os</text><text x="125" y="140" transform="scale(.1)" fill="#fff" textLength="130">os</text><text aria-hidden="true" x="985" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="1430">Linux | macOS | Windows</text><text x="985" y="140" transform="scale(.1)" fill="#fff" textLength="1430">Linux | macOS | Windows</text></g></svg>

After

Width:  |  Height:  |  Size: 1.2 KiB

View File

@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" width="98" height="20" role="img" aria-label="python: &gt;= 3.6"><title>python: &gt;= 3.6</title><linearGradient id="s" x2="0" y2="100%"><stop offset="0" stop-color="#bbb" stop-opacity=".1"/><stop offset="1" stop-opacity=".1"/></linearGradient><clipPath id="r"><rect width="98" height="20" rx="3" fill="#fff"/></clipPath><g clip-path="url(#r)"><rect width="49" height="20" fill="#555"/><rect x="49" width="49" height="20" fill="#007ec6"/><rect width="98" height="20" fill="url(#s)"/></g><g fill="#fff" text-anchor="middle" font-family="Verdana,Geneva,DejaVu Sans,sans-serif" text-rendering="geometricPrecision" font-size="110"><text aria-hidden="true" x="255" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="390">python</text><text x="255" y="140" transform="scale(.1)" fill="#fff" textLength="390">python</text><text aria-hidden="true" x="725" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="390">&gt;= 3.6</text><text x="725" y="140" transform="scale(.1)" fill="#fff" textLength="390">&gt;= 3.6</text></g></svg>

After

Width:  |  Height:  |  Size: 1.1 KiB

View File

@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" width="112" height="20" role="img" aria-label="pytorch: &gt;= 1.5.0"><title>pytorch: &gt;= 1.5.0</title><linearGradient id="s" x2="0" y2="100%"><stop offset="0" stop-color="#bbb" stop-opacity=".1"/><stop offset="1" stop-opacity=".1"/></linearGradient><clipPath id="r"><rect width="112" height="20" rx="3" fill="#fff"/></clipPath><g clip-path="url(#r)"><rect width="51" height="20" fill="#555"/><rect x="51" width="61" height="20" fill="#97ca00"/><rect width="112" height="20" fill="url(#s)"/></g><g fill="#fff" text-anchor="middle" font-family="Verdana,Geneva,DejaVu Sans,sans-serif" text-rendering="geometricPrecision" font-size="110"><text aria-hidden="true" x="265" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="410">pytorch</text><text x="265" y="140" transform="scale(.1)" fill="#fff" textLength="410">pytorch</text><text aria-hidden="true" x="805" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="510">&gt;= 1.5.0</text><text x="805" y="140" transform="scale(.1)" fill="#fff" textLength="510">&gt;= 1.5.0</text></g></svg>

After

Width:  |  Height:  |  Size: 1.1 KiB

16
doc/source/index.rst Normal file
View File

@ -0,0 +1,16 @@
.. kaldifeat documentation master file, created by
sphinx-quickstart on Fri Jul 16 20:15:27 2021.
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
kaldifeat
=========
.. toctree::
:maxdepth: 2
:caption: Contents
intro
installation/index
usage/index

View File

@ -0,0 +1,48 @@
FAQs
====
How to install a CUDA version of kaldifeat from source
------------------------------------------------------
You need to first install a CUDA version of `PyTorch`_ and then install `kaldifeat`_.
.. note::
You can use a CUDA version of `kaldifeat`_ on machines with no GPUs.
How to install a CPU version of kaldifeat from source
-----------------------------------------------------
You need to first install a CPU version of `PyTorch`_ and then install `kaldifeat`_.
How to fix `Caffe2: Cannot find cuDNN library`
----------------------------------------------
.. code-block::
Your installed Caffe2 version uses cuDNN but I cannot find the cuDNN
libraries. Please set the proper cuDNN prefixes and / or install cuDNN.
You will have such an error when you want to install a CUDA version of `kaldifeat`_
by ``pip install kaldifeat`` or from source.
You need to first install cuDNN. Assume you have installed cuDNN to the
path ``/path/to/cudnn``. You can fix the error by using ``one`` of the following
commands.
(1) Fix for installation using ``pip install``
.. code-block:: bash
export KALDIFEAT_CMAKE_ARGS="-DCMAKE_BUILD_TYPE=Release -DCUDNN_LIBRARY_PATH=/path/to/cudnn/lib/libcudnn.so -DCUDNN_INCLUDE_PATH=/path/to/cudnn/include"
pip install --verbose kaldifeat
(2) Fix for installation from source
.. code-block:: bash
mkdir /some/path
git clone https://github.com/csukuangfj/kaldifeat.git
cd kaldifeat
export KALDIFEAT_CMAKE_ARGS="-DCMAKE_BUILD_TYPE=Release -DCUDNN_LIBRARY_PATH=/path/to/cudnn/lib/libcudnn.so -DCUDNN_INCLUDE_PATH=/path/to/cudnn/include"
python setup.py install

View File

@ -0,0 +1,47 @@
.. _from source:
Install kaldifeat from source
=============================
You have to install ``cmake`` and `PyTorch`_ first.
- ``cmake`` 3.11 is known to work. Other CMake versions may also work.
- `PyTorch`_ >= 1.5.0 is known to work. Other PyTorch versions may also work.
- Python >= 3.6
- A compiler that supports C++ 14
The commands to install `kaldifeat`_ from source are:
.. code-block:: bash
git clone https://github.com/csukuangfj/kaldifeat
cd kaldifeat
python3 setup.py install
To test that you have installed `kaldifeat`_ successfully, please run:
.. code-block:: bash
python3 -c "import kaldifeat; print(kaldifeat.__version__)"
It should print the version, e.g., ``1.0``.
.. _from PyPI:
Install kaldifeat from PyPI
---------------------------
The command to install `kaldifeat`_ from PyPI is:
.. code-block:: bash
pip install --verbose kaldifeat
To test that you have installed `kaldifeat`_ successfully, please run:
.. code-block:: bash
python3 -c "import kaldifeat; print(kaldifeat.__version__)"
It should print the version, e.g., ``1.0``.

View File

@ -0,0 +1,139 @@
From pre-compiled wheels (Recommended)
=======================================
You can find pre-compiled wheels at
- CPU wheels: `<https://csukuangfj.github.io/kaldifeat/cpu.html>`_
- CUDA wheels: `<https://csukuangfj.github.io/kaldifeat/cuda.html>`_
We give a few examples below to show you how to install `kaldifeat`_ from
pre-compiled wheels.
.. hint::
The following lists only some examples. We suggest that you always select the
latest version of ``kaldifeat``.
Linux (CPU)
-----------
Suppose you want to install the following wheel:
.. code-block:: bash
https://huggingface.co/csukuangfj/kaldifeat/resolve/main/ubuntu-cpu/kaldifeat-1.25.4.dev20240725+cpu.torch2.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
you can use one of the following methods:
.. code-block:: bash
# method 1
pip install torch==2.4.0+cpu -f https://download.pytorch.org/whl/torch/
pip install kaldifeat==1.25.4.dev20240725+cpu.torch2.4.0 -f https://csukuangfj.github.io/kaldifeat/cpu.html
# For users from China
# 中国国内用户,如果访问不了 huggingface, 请使用
# pip install kaldifeat==1.25.4.dev20240725+cpu.torch2.4.0 -f https://csukuangfj.github.io/kaldifeat/cpu-cn.html
# method 2
pip install torch==2.4.0+cpu -f https://download.pytorch.org/whl/torch/
wget https://huggingface.co/csukuangfj/kaldifeat/resolve/main/ubuntu-cpu/kaldifeat-1.25.4.dev20240725+cpu.torch2.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
# For users from China
# 中国国内用户,如果访问不了 huggingface, 请使用
# wget https://hf-mirror.com/csukuangfj/kaldifeat/resolve/main/ubuntu-cpu/kaldifeat-1.25.4.dev20240725+cpu.torch2.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
pip install ./kaldifeat-1.25.4.dev20240725+cpu.torch2.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Windows (CPU)
--------------
Suppose you want to install the following wheel:
.. code-block:: bash
https://huggingface.co/csukuangfj/kaldifeat/resolve/main/windows-cpu/kaldifeat-1.25.4.dev20240725+cpu.torch2.4.0-cp312-cp312-win_amd64.whl
you can use one of the following methods:
.. code-block:: bash
# method 1
pip install torch==2.4.0+cpu -f https://download.pytorch.org/whl/torch/
pip install kaldifeat==1.25.4.dev20240725+cpu.torch2.4.0 -f https://csukuangfj.github.io/kaldifeat/cpu.html
# For users from China
# 中国国内用户,如果访问不了 huggingface, 请使用
# pip install kaldifeat==1.25.4.dev20240725+cpu.torch2.4.0 -f https://csukuangfj.github.io/kaldifeat/cpu-cn.html
# method 2
pip install torch==2.4.0+cpu -f https://download.pytorch.org/whl/torch/
wget https://huggingface.co/csukuangfj/kaldifeat/resolve/main/windows-cpu/kaldifeat-1.25.4.dev20240725+cpu.torch2.4.0-cp312-cp312-win_amd64.whl
# For users from China
# 中国国内用户,如果访问不了 huggingface, 请使用
# wget https://hf-mirror.com/csukuangfj/kaldifeat/resolve/main/windows-cpu/kaldifeat-1.25.4.dev20240725+cpu.torch2.4.0-cp312-cp312-win_amd64.whl
pip install ./kaldifeat-1.25.4.dev20240725+cpu.torch2.4.0-cp312-cp312-win_amd64.whl
macOS (CPU)
-----------
Suppose you want to install the following wheel:
.. code-block:: bash
https://huggingface.co/csukuangfj/kaldifeat/resolve/main/macos/kaldifeat-1.25.4.dev20240725+cpu.torch2.4.0-cp311-cp311-macosx_11_0_arm64.whl
you can use one of the following methods:
.. code-block:: bash
# method 1
pip install torch==2.4.0
pip install kaldifeat==1.25.4.dev20240725+cpu.torch2.4.0 -f https://csukuangfj.github.io/kaldifeat/cpu.html
# For users from China
# 中国国内用户,如果访问不了 huggingface, 请使用
# pip install kaldifeat==1.25.4.dev20240725+cpu.torch2.4.0 -f https://csukuangfj.github.io/kaldifeat/cpu-cn.html
# method 2
pip install torch==2.4.0 -f https://download.pytorch.org/whl/torch/
wget https://huggingface.co/csukuangfj/kaldifeat/resolve/main/macos/kaldifeat-1.25.4.dev20240725+cpu.torch2.4.0-cp311-cp311-macosx_11_0_arm64.whl
# For users from China
# 中国国内用户,如果访问不了 huggingface, 请使用
# wget https://hf-mirror.com/csukuangfj/kaldifeat/resolve/main/macos/kaldifeat-1.25.4.dev20240725+cpu.torch2.4.0-cp311-cp311-macosx_11_0_arm64.whl
pip install ./kaldifeat-1.25.4.dev20240725+cpu.torch2.4.0-cp311-cp311-macosx_11_0_arm64.whl
Linux (CUDA)
------------
Suppose you want to install the following wheel:
.. code-block:: bash
https://huggingface.co/csukuangfj/kaldifeat/resolve/main/ubuntu-cuda/kaldifeat-1.25.4.dev20240725+cuda12.4.torch2.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
you can use one of the following methods:
.. code-block:: bash
# method 1
pip install torch==2.4.0+cu124 -f https://download.pytorch.org/whl/torch/
pip install kaldifeat==1.25.4.dev20240725+cuda12.4.torch2.4.0 -f https://csukuangfj.github.io/kaldifeat/cuda.html
# For users from China
# 中国国内用户,如果访问不了 huggingface, 请使用
# pip install kaldifeat==1.25.4.dev20240725+cuda12.4.torch2.4.0 -f https://csukuangfj.github.io/kaldifeat/cuda-cn.html
# method 2
pip install torch==2.4.0+cu124 -f https://download.pytorch.org/whl/torch/
wget https://huggingface.co/csukuangfj/kaldifeat/resolve/main/ubuntu-cuda/kaldifeat-1.25.4.dev20240725+cuda12.4.torch2.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
# For users from China
# 中国国内用户,如果访问不了 huggingface, 请使用
# wget https://hf-mirror.com/csukuangfj/kaldifeat/resolve/main/ubuntu-cuda/kaldifeat-1.25.4.dev20240725+cuda12.4.torch2.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
pip install ./kaldifeat-1.25.4.dev20240725+cuda12.4.torch2.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

View File

@ -0,0 +1,11 @@
Installation
============
.. toctree::
:maxdepth: 3
./from_wheels.rst
./from_source.rst
./faq.rst

103
doc/source/intro.rst Normal file
View File

@ -0,0 +1,103 @@
Introduction
============
`kaldifeat`_ implements
speech feature extraction algorithms **compatible** with `Kaldi`_ using `PyTorch`_,
supporting CUDA as well as autograd.
`kaldifeat`_ has the following features:
- Fully compatible with `Kaldi`_
.. note::
The underlying C++ code is copied & modified from `Kaldi`_ directly.
  It is rewritten with `PyTorch`_ C++ APIs.
- Provide not only ``C++ APIs`` but also ``Python APIs``
.. note::
You can access `kaldifeat`_ from ``Python``.
- Support autograd
- Support ``CUDA`` and ``CPU``
.. note::
You can use CUDA for feature extraction.
- Support ``online`` (i.e., ``streaming``) and ``offline`` (i.e., ``non-streaming``)
feature extraction
- Support chunk-based processing
.. note::
    This is especially useful if you want to process audios of several
hours long, which may cause OOM if you send them for computation at once.
    With chunk-based processing, you can process audios of arbitrary length.
- Support batch processing
.. note::
With `kaldifeat`_ you can extract features for a batch of audios
.. see https://sublime-and-sphinx-guide.readthedocs.io/en/latest/tables.html
Currently implemented speech features and their counterparts in `Kaldi`_ are
listed in the following table.
.. list-table:: Supported speech features
:widths: 50 50
:header-rows: 1
* - Supported speech features
- Counterpart in `Kaldi`_
* - `kaldifeat.Fbank`_
- `compute-fbank-feats`_
* - `kaldifeat.Mfcc`_
- `compute-mfcc-feats`_
* - `kaldifeat.Plp`_
- `compute-plp-feats`_
* - `kaldifeat.Spectrogram`_
- `compute-spectrogram-feats`_
* - `kaldifeat.OnlineFbank`_
- `kaldi::OnlineFbank`_
* - `kaldifeat.OnlineMfcc`_
- `kaldi::OnlineMfcc`_
* - `kaldifeat.OnlinePlp`_
- `kaldi::OnlinePlp`_
Each feature computer needs an option. The following table lists the options
for each computer and the corresponding options in `Kaldi`_.
.. hint::
Note that we reuse the parameter names from `Kaldi`_.
Also, both online feature computers and offline feature computers share the
same option.
.. list-table:: Feature computer options
:widths: 50 50
:header-rows: 1
* - Options in `kaldifeat`_
- Corresponding options in `Kaldi`_
* - `kaldifeat.FbankOptions`_
- `kaldi::FbankOptions`_
* - `kaldifeat.MfccOptions`_
- `kaldi::MfccOptions`_
* - `kaldifeat.PlpOptions`_
- `kaldi::PlpOptions`_
* - `kaldifeat.SpectrogramOptions`_
- `kaldi::SpectrogramOptions`_
* - `kaldifeat.FrameExtractionOptions`_
- `kaldi::FrameExtractionOptions`_
* - `kaldifeat.MelBanksOptions`_
- `kaldi::MelBanksOptions`_
Read more to learn how to install `kaldifeat`_ and how to use each feature
computer.

View File

@ -0,0 +1,46 @@
compute-fbank-feats
Create Mel-filter bank (FBANK) feature files.
Usage: compute-fbank-feats [options...] <wav-rspecifier> <feats-wspecifier>
Options:
--allow-downsample : If true, allow the input waveform to have a higher frequency than the specified --sample-frequency (and we'll downsample). (bool, default = false)
--allow-upsample : If true, allow the input waveform to have a lower frequency than the specified --sample-frequency (and we'll upsample). (bool, default = false)
--blackman-coeff : Constant coefficient for generalized Blackman window. (float, default = 0.42)
--channel : Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (int, default = -1)
--debug-mel : Print out debugging information for mel bin computation (bool, default = false)
--dither : Dithering constant (0.0 means no dither). If you turn this off, you should set the --energy-floor option, e.g. to 1.0 or 0.1 (float, default = 1)
--energy-floor : Floor on energy (absolute, not relative) in FBANK computation. Only makes a difference if --use-energy=true; only necessary if --dither=0.0. Suggested values: 0.1 or 1.0 (float, default = 0)
--frame-length : Frame length in milliseconds (float, default = 25)
--frame-shift : Frame shift in milliseconds (float, default = 10)
--high-freq : High cutoff frequency for mel bins (if <= 0, offset from Nyquist) (float, default = 0)
--htk-compat : If true, put energy last. Warning: not sufficient to get HTK compatible features (need to change other parameters). (bool, default = false)
--low-freq : Low cutoff frequency for mel bins (float, default = 20)
--max-feature-vectors : Memory optimization. If larger than 0, periodically remove feature vectors so that only this number of the latest feature vectors is retained. (int, default = -1)
--min-duration : Minimum duration of segments to process (in seconds). (float, default = 0)
--num-mel-bins : Number of triangular mel-frequency bins (int, default = 23)
--output-format : Format of the output files [kaldi, htk] (string, default = "kaldi")
--preemphasis-coefficient : Coefficient for use in signal preemphasis (float, default = 0.97)
--raw-energy : If true, compute energy before preemphasis and windowing (bool, default = true)
--remove-dc-offset : Subtract mean from waveform on each frame (bool, default = true)
--round-to-power-of-two : If true, round window size to power of two by zero-padding input to FFT. (bool, default = true)
--sample-frequency : Waveform data sample frequency (must match the waveform file, if specified there) (float, default = 16000)
--snip-edges : If true, end effects will be handled by outputting only frames that completely fit in the file, and the number of frames depends on the frame-length. If false, the number of frames depends only on the frame-shift, and we reflect the data at the ends. (bool, default = true)
--subtract-mean : Subtract mean of each feature file [CMS]; not recommended to do it this way. (bool, default = false)
--use-energy : Add an extra dimension with energy to the FBANK output. (bool, default = false)
--use-log-fbank : If true, produce log-filterbank, else produce linear. (bool, default = true)
--use-power : If true, use power, else use magnitude. (bool, default = true)
--utt2spk : Utterance to speaker-id map (if doing VTLN and you have warps per speaker) (string, default = "")
--vtln-high : High inflection point in piecewise linear VTLN warping function (if negative, offset from high-mel-freq (float, default = -500)
--vtln-low : Low inflection point in piecewise linear VTLN warping function (float, default = 100)
--vtln-map : Map from utterance or speaker-id to vtln warp factor (rspecifier) (string, default = "")
--vtln-warp : Vtln warp factor (only applicable if vtln-map not specified) (float, default = 1)
--window-type : Type of window ("hamming"|"hanning"|"povey"|"rectangular"|"sine"|"blackmann") (string, default = "povey")
--write-utt2dur : Wspecifier to write duration of each utterance in seconds, e.g. 'ark,t:utt2dur'. (string, default = "")
Standard options:
--config : Configuration file to read (this option may be repeated) (string, default = "")
--help : Print out usage message (bool, default = false)
--print-args : Print the command line arguments (to stderr) (bool, default = true)
--verbose : Verbose level (higher->more logging) (int, default = 0)

View File

@ -0,0 +1,65 @@
$ python3
Python 3.8.0 (default, Oct 28 2019, 16:14:01)
[GCC 8.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import kaldifeat
>>> opts = kaldifeat.FbankOptions()
>>> print(opts)
frame_opts:
samp_freq: 16000
frame_shift_ms: 10
frame_length_ms: 25
dither: 1
preemph_coeff: 0.97
remove_dc_offset: 1
window_type: povey
round_to_power_of_two: 1
blackman_coeff: 0.42
snip_edges: 1
max_feature_vectors: -1
mel_opts:
num_bins: 23
low_freq: 20
high_freq: 0
vtln_low: 100
vtln_high: -500
debug_mel: 0
htk_mode: 0
use_energy: 0
energy_floor: 0
raw_energy: 1
htk_compat: 0
use_log_fbank: 1
use_power: 1
device: cpu
>>> print(opts.dither)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
AttributeError: '_kaldifeat.FbankOptions' object has no attribute 'dither'
>>>
>>> print(opts.frame_opts.dither)
1.0
>>> opts.frame_opts.dither = 0 # disable dither
>>> print(opts.frame_opts.dither)
0.0
>>> import torch
>>> print(opts.device)
cpu
>>> opts.device = 'cuda:0'
>>> print(opts.device)
cuda:0
>>> opts.device = torch.device('cuda', 1)
>>> print(opts.device)
cuda:1
>>> opts.device = 'cpu'
>>> print(opts.device)
cpu
>>> print(opts.mel_opts.num_bins)
23
>>> opts.mel_opts.num_bins = 80
>>> print(opts.mel_opts.num_bins)
80

View File

@ -0,0 +1 @@
../../../../kaldifeat/python/tests/test_fbank_options.py

View File

@ -0,0 +1,3 @@
kaldifeat.Fbank
===============

View File

@ -0,0 +1,51 @@
kaldifeat.FbankOptions
======================
If you want to construct an instance of `kaldifeat.Fbank`_ or
`kaldifeat.OnlineFbank`_, you have to provide an instance of
`kaldifeat.FbankOptions`_.
The following code shows how to construct an instance of `kaldifeat.FbankOptions`_.
.. literalinclude:: ./code/fbank_options-1.txt
:caption: Usage of `kaldifeat.FbankOptions`_
:emphasize-lines: 6,8,22,37
Note that we reuse the same option names as `compute-fbank-feats`_ from `Kaldi`_:
.. code-block:: bash
$ compute-fbank-feats --help
.. literalinclude:: ./code/compute-fbank-feats-help.txt
:caption: Output of ``compute-fbank-feats --help``
Please refer to the output of ``compute-fbank-feats --help`` for the meaning
of each field of `kaldifeat.FbankOptions`_.
One thing worth noting is that `kaldifeat.FbankOptions`_ has a field ``device``,
which is an instance of ``torch.device``. You can assign it either a string, e.g.,
``"cpu"`` or ``"cuda:0"``, or an instance of ``torch.device``, e.g., ``torch.device("cpu")`` or
``torch.device("cuda", 1)``.
.. hint::
You can use this field to control whether the feature computer
constructed from it performs computation on CPU or CUDA.
.. caution::
If you use a CUDA device, make sure that you have installed a CUDA version
of `PyTorch`_.
Example usage
-------------
The following code from
`<https://github.com/csukuangfj/kaldifeat/blob/master/kaldifeat/python/tests/test_fbank_options.py>`_
demonstrates the usage of `kaldifeat.FbankOptions`_:
.. literalinclude:: ./code/test_fbank_options.py
:caption: Example usage of `kaldifeat.FbankOptions`_
:language: python

View File

@ -0,0 +1,11 @@
Usage
=====
This section describes how to use feature computers in `kaldifeat`_.
.. toctree::
:maxdepth: 2
fbank_options
fbank
online_fbank

View File

@ -0,0 +1,3 @@
kaldifeat.OnlineFbank
=====================

106
get_version.py Executable file
View File

@ -0,0 +1,106 @@
#!/usr/bin/env python3
import datetime
import os
import platform
import re
import shutil
import torch
def is_macos():
    """Return True iff the build host runs macOS (platform "Darwin")."""
    system_name = platform.system()
    return "Darwin" == system_name
def is_windows():
    """Return True iff the build host runs Windows."""
    system_name = platform.system()
    return "Windows" == system_name
def with_cuda():
    """Return True iff a CUDA build is possible: ``nvcc`` is on PATH and
    the host is not macOS (CUDA is not supported there)."""
    on_macos = platform.system() == "Darwin"
    nvcc_found = shutil.which("nvcc") is not None
    return nvcc_found and not on_macos
def get_pytorch_version():
    """Return ``torch.__version__`` with any local suffix removed,
    e.g. "1.7.1+cuda101" -> "1.7.1"."""
    version, _, _ = torch.__version__.partition("+")
    return version
def get_cuda_version():
    """Return the CUDA version PyTorch was built with (or None for a
    CPU-only build).

    When both are detectable, also sanity-check that the build-time CUDA
    version matches the CUDA toolkit found on this machine.
    """
    from torch.utils import collect_env

    running_cuda_version = collect_env.get_running_cuda_version(collect_env.run)
    cuda_version = torch.version.cuda
    if cuda_version is not None and running_cuda_version is not None:
        assert cuda_version in running_cuda_version, (
            f"PyTorch is built with CUDA version: {cuda_version}.\n"
            f"The current running CUDA version is: {running_cuda_version}"
        )
    return cuda_version
def is_for_pypi():
    """True iff the env var KALDIFEAT_IS_FOR_PYPI is set (to any value)."""
    return "KALDIFEAT_IS_FOR_PYPI" in os.environ
def is_stable():
    """True iff the env var KALDIFEAT_IS_STABLE is set (to any value)."""
    return "KALDIFEAT_IS_STABLE" in os.environ
def is_for_conda():
    """True iff the env var KALDIFEAT_IS_FOR_CONDA is set (to any value)."""
    return "KALDIFEAT_IS_FOR_CONDA" in os.environ
def get_package_version():
    """Compute the full version string for the kaldifeat wheel.

    The base version is parsed from CMakeLists.txt (must be run from the
    project root). A local suffix encoding the CUDA/CPU flavor and the
    PyTorch version is appended, except for conda builds, macOS PyPI
    builds, and PyPI builds using the default CUDA version. Non-stable
    builds additionally get a ``.devYYYYMMDD`` component.
    """
    # Set a default CUDA version here so that `pip install kaldifeat`
    # uses the default CUDA version.
    default_cuda_version = "10.1"  # CUDA 10.1

    if not with_cuda():
        # CPU-only build: encode only the PyTorch version.
        local_version = f"+cpu.torch{get_pytorch_version()}"
    else:
        cuda_version = get_cuda_version()
        if is_for_pypi() and default_cuda_version == cuda_version:
            # The PyPI wheel for the default CUDA carries no local suffix.
            local_version = ""
        else:
            local_version = f"+cuda{cuda_version}.torch{get_pytorch_version()}"

    if is_for_conda():
        local_version = ""
    if is_for_pypi() and is_macos():
        local_version = ""

    with open("CMakeLists.txt") as f:
        content = f.read()

    match = re.search(r"set\(kaldifeat_VERSION (.*)\)", content)
    latest_version = match.group(1).strip('"')

    if is_stable():
        return f"{latest_version}"

    dt = datetime.datetime.utcnow()
    return f"{latest_version}.dev{dt.year}{dt.month:02d}{dt.day:02d}{local_version}"
# Script entry point: print the computed package version so that
# setup.py / CI shell scripts can capture it from stdout.
if __name__ == "__main__":
    print(get_package_version())

View File

@ -1 +1,4 @@
add_subdirectory(csrc)
if(kaldifeat_BUILD_PYMODULE)
add_subdirectory(python)
endif()

View File

@ -1,7 +1,93 @@
# Copyright (c) 2021 Xiaomi Corporation (author: Fangjun Kuang)
set(kaldifeat_srcs
feature-fbank.cc
feature-functions.cc
feature-mfcc.cc
feature-plp.cc
feature-spectrogram.cc
feature-window.cc
matrix-functions.cc
mel-computations.cc
online-feature.cc
whisper-fbank.cc
)
add_library(kaldifeat_core SHARED ${kaldifeat_srcs})
add_library(kaldifeat_core ${kaldifeat_srcs})
target_link_libraries(kaldifeat_core PUBLIC ${TORCH_LIBRARIES})
target_compile_definitions(kaldifeat_core PUBLIC KALDIFEAT_TORCH_VERSION_MAJOR=${KALDIFEAT_TORCH_VERSION_MAJOR})
target_compile_definitions(kaldifeat_core PUBLIC KALDIFEAT_TORCH_VERSION_MINOR=${KALDIFEAT_TORCH_VERSION_MINOR})
if(APPLE)
execute_process(
COMMAND "${PYTHON_EXECUTABLE}" -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())"
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE PYTHON_SITE_PACKAGE_DIR
)
message(STATUS "PYTHON_SITE_PACKAGE_DIR: ${PYTHON_SITE_PACKAGE_DIR}")
target_link_libraries(kaldifeat_core PUBLIC "-L ${PYTHON_SITE_PACKAGE_DIR}/../..")
endif()
add_executable(test_kaldifeat test_kaldifeat.cc)
target_link_libraries(test_kaldifeat PRIVATE kaldifeat_core)
# Builds a gtest-based test executable from a single source file and
# registers it with CTest under the name "Test.<basename>".
function(kaldifeat_add_test source)
  # Use the file name without extension as the target name.
  get_filename_component(name ${source} NAME_WE)
  add_executable(${name} "${source}")
  target_link_libraries(${name}
    PRIVATE
      kaldifeat_core
      gtest
      gtest_main
  )

  # NOTE: We set the working directory here so that
  # it works also on windows. The reason is that
  # the required DLLs are inside ${TORCH_DIR}/lib
  # and they can be found by the exe if the current
  # working directory is ${TORCH_DIR}\lib
  add_test(NAME "Test.${name}"
    COMMAND
      $<TARGET_FILE:${name}>
    WORKING_DIRECTORY ${TORCH_DIR}/lib
  )
endfunction()
if(kaldifeat_BUILD_TESTS)
# please sort the source files alphabetically
set(test_srcs
feature-window-test.cc
online-feature-test.cc
)
foreach(source IN LISTS test_srcs)
kaldifeat_add_test(${source})
endforeach()
endif()
file(MAKE_DIRECTORY
DESTINATION
${PROJECT_BINARY_DIR}/include/kaldifeat/csrc
)
file(GLOB_RECURSE all_headers *.h)
message(STATUS "All headers: ${all_headers}")
file(COPY
${all_headers}
DESTINATION
${PROJECT_BINARY_DIR}/include/kaldifeat/csrc
)
if(BUILD_SHARED_LIBS AND WIN32)
install(TARGETS kaldifeat_core
DESTINATION ../
)
endif()
install(TARGETS kaldifeat_core
DESTINATION ${CMAKE_INSTALL_LIBDIR}
)
install(FILES ${all_headers}
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/kaldifeat/csrc
)

View File

@ -0,0 +1 @@
exclude_files=whisper-mel-bank.h,whisper-v3-mel-bank.h

View File

@ -0,0 +1,76 @@
// kaldifeat/csrc/feature-common-inl.h
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
// This file is copied/modified from kaldi/src/feat/feature-common-inl.h
#ifndef KALDIFEAT_CSRC_FEATURE_COMMON_INL_H_
#define KALDIFEAT_CSRC_FEATURE_COMMON_INL_H_
#include "kaldifeat/csrc/feature-window.h"
namespace kaldifeat {
// Turns a waveform into the final feature matrix: framing, dithering,
// DC removal, optional raw log-energy, preemphasis, windowing,
// zero-padding to the padded window size, and finally the
// computer-specific Compute() step.
template <class F>
torch::Tensor OfflineFeatureTpl<F>::ComputeFeatures(const torch::Tensor &wave,
                                                    float vtln_warp) {
  const FrameExtractionOptions &frame_opts = computer_.GetFrameOptions();
  torch::Tensor strided_input;
  if (wave.dim() == 1) {
    // 1-D input: a mono waveform; split it into overlapping frames.
    strided_input = GetStrided(wave, frame_opts);
  } else {
    // 2-D input: each row is assumed to already be one frame of
    // exactly WindowSize() samples.
    KALDIFEAT_ASSERT(wave.dim() == 2);
    KALDIFEAT_ASSERT(wave.size(1) == frame_opts.WindowSize());
    strided_input = wave;
  }

  if (frame_opts.dither != 0.0f) {
    strided_input = Dither(strided_input, frame_opts.dither);
  }

  if (frame_opts.remove_dc_offset) {
    // Subtract the per-frame mean (DC offset).
    torch::Tensor row_means = strided_input.mean(1).unsqueeze(1);
    strided_input = strided_input - row_means;
  }

  bool use_raw_log_energy = computer_.NeedRawLogEnergy();
  torch::Tensor log_energy_pre_window;

  // torch.finfo(torch.float32).eps
  constexpr float kEps = 1.1920928955078125e-07f;

  if (use_raw_log_energy) {
    // it is true iff use_energy==true and raw_energy==true;
    // clamp by kEps to avoid log(0).
    log_energy_pre_window =
        torch::clamp_min(strided_input.pow(2).sum(1), kEps).log();
  }

  if (frame_opts.preemph_coeff != 0.0f) {
    strided_input = Preemphasize(frame_opts.preemph_coeff, strided_input);
  }

  // Apply the window function (Povey/Hamming/etc.) to every frame.
  strided_input = feature_window_function_.Apply(strided_input);

  // Zero-pad each frame on the right up to the FFT size.
  int32_t padding = frame_opts.PaddedWindowSize() - strided_input.size(1);
  if (padding > 0) {
#ifdef __ANDROID__
    // torch::nn::functional::pad is not available on Android; emulate
    // it by concatenating a zero tensor.
    auto padding_value = torch::zeros(
        {strided_input.size(0), padding},
        torch::dtype(torch::kFloat).device(strided_input.device()));

    strided_input = torch::cat({strided_input, padding_value}, 1);
#else
    strided_input = torch::nn::functional::pad(
        strided_input, torch::nn::functional::PadFuncOptions({0, padding})
                           .mode(torch::kConstant)
                           .value(0));
#endif
  }

  return computer_.Compute(log_energy_pre_window, vtln_warp, strided_input);
}
} // namespace kaldifeat
#endif // KALDIFEAT_CSRC_FEATURE_COMMON_INL_H_

View File

@ -0,0 +1,82 @@
// kaldifeat/csrc/feature-common.h
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
// This file is copied/modified from kaldi/src/feat/feature-common.h
#ifndef KALDIFEAT_CSRC_FEATURE_COMMON_H_
#define KALDIFEAT_CSRC_FEATURE_COMMON_H_
#include "kaldifeat/csrc/feature-functions.h"
#include "kaldifeat/csrc/feature-window.h"
// See "The torch.fft module in PyTorch 1.7"
// https://github.com/pytorch/pytorch/wiki/The-torch.fft-module-in-PyTorch-1.7
#if KALDIFEAT_TORCH_VERSION_MAJOR > 1 || \
(KALDIFEAT_TORCH_VERSION_MAJOR == 1 && KALDIFEAT_TORCH_VERSION_MINOR > 6)
#include "torch/fft.h"
#define KALDIFEAT_HAS_FFT_NAMESPACE
// It uses torch::fft::rfft
// Its input shape is [x, N], output shape is [x, N/2]
// which is a complex tensor
#else
#include "ATen/Functions.h"
// It uses torch::fft
// Its input shape is [x, N], output shape is [x, N/2, 2]
// which contains the real part [..., ], and imaginary part [..., 1]
#endif
namespace kaldifeat {
// Template wrapper that turns a frame-level feature computer F
// (e.g. FbankComputer, MfccComputer) into an offline, whole-utterance
// feature extractor.
template <class F>
class OfflineFeatureTpl {
 public:
  using Options = typename F::Options;

  // Note: feature_window_function_ is the windowing function, which initialized
  // using the options class, that we cache at this level.
  explicit OfflineFeatureTpl(const Options &opts)
      : computer_(opts),
        feature_window_function_(computer_.GetFrameOptions(), opts.device) {}

  /**
     Computes the features for one file (one sequence of features).

       @param [in] wave   The input waveform. It can be either 1-D or 2-D.
                          If it is a 1-D tensor, we assume it contains
                          samples of a mono channel sound file.
                          If it is a 2-D tensor, we assume each row
                          is a frame of size opts.WindowSize().
       @param [in] vtln_warp  The VTLN warping factor (will normally
                          be 1.0).
       @return The matrix of features, where the row-index is the
               frame index.
   */
  torch::Tensor ComputeFeatures(const torch::Tensor &wave, float vtln_warp);

  // Feature dimension (number of columns of the output matrix).
  int32_t Dim() const { return computer_.Dim(); }

  const Options &GetOptions() const { return computer_.GetOptions(); }

  const FrameExtractionOptions &GetFrameOptions() const {
    return GetOptions().frame_opts;
  }

  // Copying and assignment are disallowed.
  OfflineFeatureTpl(const OfflineFeatureTpl<F> &) = delete;
  OfflineFeatureTpl<F> &operator=(const OfflineFeatureTpl<F> &) = delete;

 private:
  F computer_;
  FeatureWindowFunction feature_window_function_;
};
} // namespace kaldifeat
#include "kaldifeat/csrc/feature-common-inl.h"
#endif // KALDIFEAT_CSRC_FEATURE_COMMON_H_

View File

@ -0,0 +1,121 @@
// kaldifeat/csrc/feature-fbank.cc
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
// This file is copied/modified from kaldi/src/feat/feature-fbank.cc
#include "kaldifeat/csrc/feature-fbank.h"
#include <cmath>
namespace kaldifeat {
// Streams a human-readable description of the options (see ToString()).
std::ostream &operator<<(std::ostream &os, const FbankOptions &opts) {
  return os << opts.ToString();
}
// Precomputes log(energy_floor) when an energy floor is in effect and
// eagerly caches the MelBanks for the unwarped (VTLN warp == 1.0) case.
FbankComputer::FbankComputer(const FbankOptions &opts) : opts_(opts) {
  if (opts.energy_floor > 0.0f) log_energy_floor_ = logf(opts.energy_floor);

  // We'll definitely need the filterbanks info for VTLN warping factor 1.0.
  // [note: this call caches it.]
  GetMelBanks(1.0f);
}
FbankComputer::~FbankComputer() {
  // Free every MelBanks instance cached by GetMelBanks().
  for (auto &entry : mel_banks_) {
    delete entry.second;
  }
}
// Returns the MelBanks for the given VTLN warp factor, creating and
// caching a new instance on first use.
const MelBanks *FbankComputer::GetMelBanks(float vtln_warp) {
  auto iter = mel_banks_.find(vtln_warp);
  if (iter != mel_banks_.end()) {
    return iter->second;
  }

  MelBanks *created =
      new MelBanks(opts_.mel_opts, opts_.frame_opts, vtln_warp, opts_.device);
  mel_banks_[vtln_warp] = created;
  return created;
}
// Computes fbank features for a batch of windowed frames.
//
// signal_raw_log_energy: per-frame raw log-energy; only used when
//   NeedRawLogEnergy() returned true (i.e. use_energy && raw_energy).
// vtln_warp: VTLN warp factor selecting the cached MelBanks.
// signal_frame: [num_frames, PaddedWindowSize()] windowed frames.
//
// ans.shape [signal_frame.size(0), this->Dim()]
torch::Tensor FbankComputer::Compute(torch::Tensor signal_raw_log_energy,
                                     float vtln_warp,
                                     const torch::Tensor &signal_frame) {
  const MelBanks &mel_banks = *(GetMelBanks(vtln_warp));

  KALDIFEAT_ASSERT(signal_frame.dim() == 2);
  KALDIFEAT_ASSERT(signal_frame.size(1) == opts_.frame_opts.PaddedWindowSize());

  // torch.finfo(torch.float32).eps
  constexpr float kEps = 1.1920928955078125e-07f;

  // Compute energy after window function (not the raw one).
  if (opts_.use_energy && !opts_.raw_energy) {
    signal_raw_log_energy =
        torch::clamp_min(signal_frame.pow(2).sum(1), kEps).log();
  }

  // note spectrum is in magnitude, not power, because of `abs()`
#if defined(KALDIFEAT_HAS_FFT_NAMESPACE)
  // signal_frame shape: [x, 512]
  // spectrum shape [x, 257]
  torch::Tensor spectrum = torch::fft::rfft(signal_frame).abs();
#else
  // signal_frame shape [x, 512]
  // real_imag shape [x, 257, 2],
  // where [..., 0] is the real part
  //       [..., 1] is the imaginary part
  torch::Tensor real_imag = torch::rfft(signal_frame, 1);
  torch::Tensor real = real_imag.index({"...", 0});
  torch::Tensor imag = real_imag.index({"...", 1});
  torch::Tensor spectrum = (real.square() + imag.square()).sqrt();
#endif

  // remove the last column, i.e., the highest fft bin
  spectrum = spectrum.index(
      {"...", torch::indexing::Slice(0, -1, torch::indexing::None)});

  // Use power instead of magnitude if requested.
  if (opts_.use_power) {
    spectrum = spectrum.pow(2);
  }

  torch::Tensor mel_energies = mel_banks.Compute(spectrum);
  if (opts_.use_log_fbank) {
    // Avoid log of zero (which should be prevented anyway by dithering).
    mel_energies = torch::clamp_min(mel_energies, kEps).log();
  }

  // if use_energy is true, then we get an extra bin. That is,
  // if num_mel_bins is 23, the feature will contain 24 bins.
  //
  // if htk_compat is false, then the 0th bin is the log energy
  // if htk_compat is true, then the last bin is the log energy
  // Copy energy as first value (or the last, if htk_compat == true).
  if (opts_.use_energy) {
    if (opts_.energy_floor > 0.0f) {
      signal_raw_log_energy =
          torch::clamp_min(signal_raw_log_energy, log_energy_floor_);
    }
    // [num_frames] -> [num_frames, 1] so it can be concatenated as a column.
    signal_raw_log_energy.unsqueeze_(1);
    if (opts_.htk_compat) {
      mel_energies = torch::cat({mel_energies, signal_raw_log_energy}, 1);
    } else {
      mel_energies = torch::cat({signal_raw_log_energy, mel_energies}, 1);
    }
  }

  return mel_energies;
}
} // namespace kaldifeat

View File

@ -0,0 +1,105 @@
// kaldifeat/csrc/feature-fbank.h
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
// This file is copied/modified from kaldi/src/feat/feature-fbank.h
#ifndef KALDIFEAT_CSRC_FEATURE_FBANK_H_
#define KALDIFEAT_CSRC_FEATURE_FBANK_H_
#include <map>
#include <string>
#include "kaldifeat/csrc/feature-common.h"
#include "kaldifeat/csrc/feature-window.h"
#include "kaldifeat/csrc/mel-computations.h"
namespace kaldifeat {
// Options for computing fbank (log mel filter bank) features.
struct FbankOptions {
  FrameExtractionOptions frame_opts;
  MelBanksOptions mel_opts;
  // append an extra dimension with energy to the filter banks
  bool use_energy = false;
  float energy_floor = 0.0f;  // active iff use_energy==true

  // If true, compute log_energy before preemphasis and windowing
  // If false, compute log_energy after preemphasis and windowing
  bool raw_energy = true;  // active iff use_energy==true

  // If true, put energy last (if using energy)
  // If false, put energy first
  bool htk_compat = false;  // active iff use_energy==true

  // if true (default), produce log-filterbank, else linear
  bool use_log_fbank = true;

  // if true (default), use power in filterbank
  // analysis, else magnitude.
  bool use_power = true;

  // Device on which features are computed, e.g. "cpu" or "cuda:0".
  torch::Device device{"cpu"};

  FbankOptions() { mel_opts.num_bins = 23; }

  // Returns a Python-repr-like, human-readable description of the options.
  std::string ToString() const {
    std::ostringstream os;
    os << "FbankOptions(";
    os << "frame_opts=" << frame_opts.ToString() << ", ";
    os << "mel_opts=" << mel_opts.ToString() << ", ";
    os << "use_energy=" << (use_energy ? "True" : "False") << ", ";
    os << "energy_floor=" << energy_floor << ", ";
    os << "raw_energy=" << (raw_energy ? "True" : "False") << ", ";
    os << "htk_compat=" << (htk_compat ? "True" : "False") << ", ";
    os << "use_log_fbank=" << (use_log_fbank ? "True" : "False") << ", ";
    os << "use_power=" << (use_power ? "True" : "False") << ", ";
    os << "device=\"" << device << "\")";
    return os.str();
  }
};
std::ostream &operator<<(std::ostream &os, const FbankOptions &opts);
// Frame-level fbank computer; wrap it in OfflineFeatureTpl (see `Fbank`
// alias) to process whole waveforms.
class FbankComputer {
 public:
  using Options = FbankOptions;

  explicit FbankComputer(const FbankOptions &opts);
  ~FbankComputer();

  FbankComputer &operator=(const FbankComputer &) = delete;
  FbankComputer(const FbankComputer &) = delete;

  // Output dimension: num_mel_bins, plus one extra bin if use_energy.
  int32_t Dim() const {
    return opts_.mel_opts.num_bins + (opts_.use_energy ? 1 : 0);
  }

  // if true, compute log_energy_pre_window but after dithering and dc removal
  bool NeedRawLogEnergy() const { return opts_.use_energy && opts_.raw_energy; }

  const FrameExtractionOptions &GetFrameOptions() const {
    return opts_.frame_opts;
  }

  const FbankOptions &GetOptions() const { return opts_; }

  // signal_raw_log_energy is log_energy_pre_window, which is not empty
  // iff NeedRawLogEnergy() returns true.
  torch::Tensor Compute(torch::Tensor signal_raw_log_energy, float vtln_warp,
                        const torch::Tensor &signal_frame);

 private:
  const MelBanks *GetMelBanks(float vtln_warp);

  FbankOptions opts_;
  // log(opts_.energy_floor); only meaningful when energy_floor > 0.
  // Default-initialized so the member is never read uninitialized
  // (the constructor assigns it only when energy_floor > 0).
  float log_energy_floor_ = 0.0f;
  std::map<float, MelBanks *> mel_banks_;  // float is VTLN coefficient.
};
using Fbank = OfflineFeatureTpl<FbankComputer>;
} // namespace kaldifeat
#endif // KALDIFEAT_CSRC_FEATURE_FBANK_H_

View File

@ -0,0 +1,33 @@
// kaldifeat/csrc/feature-functions.cc
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
// This file is copied/modified from kaldi/src/feat/feature-functions.cc
#include "kaldifeat/csrc/feature-functions.h"
#include <cmath>
namespace kaldifeat {
// Fills *mat_out with an (n_bases x dimension) matrix of scaled cosine
// IDFT bases; row i holds the basis for quefrency i. The first and last
// columns use half weight, interior columns a factor of 2.
void InitIdftBases(int32_t n_bases, int32_t dimension, torch::Tensor *mat_out) {
  const float angle = M_PI / (dimension - 1);
  const float scale = 1.0f / (2 * (dimension - 1));

  *mat_out = torch::empty({n_bases, dimension}, torch::kFloat);
  float *base = mat_out->data_ptr<float>();
  const int32_t row_stride = mat_out->stride(0);

  for (int32_t i = 0; i != n_bases; ++i) {
    float *row = base + i * row_stride;
    row[0] = scale;
    for (int32_t j = 1; j + 1 < dimension; ++j) {
      row[j] = 2 * scale * std::cos(angle * i * j);
    }
    row[dimension - 1] = scale * std::cos(angle * i * (dimension - 1));
  }
}
} // namespace kaldifeat

View File

@ -0,0 +1,18 @@
// kaldifeat/csrc/feature-functions.h
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
// This file is copied/modified from kaldi/src/feat/feature-functions.h
#ifndef KALDIFEAT_CSRC_FEATURE_FUNCTIONS_H_
#define KALDIFEAT_CSRC_FEATURE_FUNCTIONS_H_
#include "torch/script.h"
namespace kaldifeat {

// Fills *mat_out with an (n_bases x dimension) matrix of cosine IDFT
// bases; used by the PLP computation to convert mel energies into
// autocorrelation coefficients.
void InitIdftBases(int32_t n_bases, int32_t dimension, torch::Tensor *mat_out);

}  // namespace kaldifeat
#endif // KALDIFEAT_CSRC_FEATURE_FUNCTIONS_H_

View File

@ -0,0 +1,163 @@
// kaldifeat/csrc/feature-mfcc.cc
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
// This file is copied/modified from kaldi/src/feat/feature-mfcc.cc
#include "kaldifeat/csrc/feature-mfcc.h"
#include "kaldifeat/csrc/matrix-functions.h"
namespace kaldifeat {
// Streams a human-readable description of the options (see ToString()).
std::ostream &operator<<(std::ostream &os, const MfccOptions &opts) {
  return os << opts.ToString();
}
// Validates num_ceps against num_mel_bins, precomputes the (transposed)
// DCT matrix and optional lifter coefficients on opts.device, and caches
// the MelBanks for VTLN warp 1.0.
MfccComputer::MfccComputer(const MfccOptions &opts) : opts_(opts) {
  int32_t num_bins = opts.mel_opts.num_bins;
  if (opts.num_ceps > num_bins) {
    KALDIFEAT_ERR << "num-ceps cannot be larger than num-mel-bins."
                  << " It should be smaller or equal. You provided num-ceps: "
                  << opts.num_ceps << " and num-mel-bins: " << num_bins;
  }

  torch::Tensor dct_matrix = torch::empty({num_bins, num_bins}, torch::kFloat);
  ComputeDctMatrix(&dct_matrix);

  // Note that we include zeroth dct in either case. If using the
  // energy we replace this with the energy. This means a different
  // ordering of features than HTK.
  using namespace torch::indexing;  // It imports: Slice, None // NOLINT

  // dct_matrix[:opts.num_ceps, :]
  torch::Tensor dct_rows =
      dct_matrix.index({Slice(0, opts.num_ceps, None), "..."});

  // Stored transposed so Compute() can right-multiply mel energies by it.
  dct_matrix_ = dct_rows.clone().t().to(opts.device);

  if (opts.cepstral_lifter != 0.0) {
    lifter_coeffs_ = torch::empty({1, opts.num_ceps}, torch::kFloat32);
    ComputeLifterCoeffs(opts.cepstral_lifter, &lifter_coeffs_);
    lifter_coeffs_ = lifter_coeffs_.to(opts.device);
  }
  if (opts.energy_floor > 0.0) log_energy_floor_ = logf(opts.energy_floor);

  // We'll definitely need the filterbanks info for VTLN warping factor 1.0.
  // [note: this call caches it.]
  GetMelBanks(1.0);
}
// Returns the MelBanks for the given VTLN warp factor, creating and
// caching a new instance on first use.
const MelBanks *MfccComputer::GetMelBanks(float vtln_warp) {
  auto iter = mel_banks_.find(vtln_warp);
  if (iter != mel_banks_.end()) {
    return iter->second;
  }

  MelBanks *created =
      new MelBanks(opts_.mel_opts, opts_.frame_opts, vtln_warp, opts_.device);
  mel_banks_[vtln_warp] = created;
  return created;
}
MfccComputer::~MfccComputer() {
  // Free every MelBanks instance cached by GetMelBanks().
  for (auto &entry : mel_banks_) {
    delete entry.second;
  }
}
// Computes MFCC features for a batch of windowed frames.
//
// signal_raw_log_energy: per-frame raw log-energy; only used when
//   NeedRawLogEnergy() returned true (i.e. use_energy && raw_energy).
// vtln_warp: VTLN warp factor selecting the cached MelBanks.
// signal_frame: [num_frames, PaddedWindowSize()] windowed frames.
//
// ans.shape [signal_frame.size(0), this->Dim()]
torch::Tensor MfccComputer::Compute(torch::Tensor signal_raw_log_energy,
                                    float vtln_warp,
                                    const torch::Tensor &signal_frame) {
  const MelBanks &mel_banks = *(GetMelBanks(vtln_warp));

  KALDIFEAT_ASSERT(signal_frame.dim() == 2);
  KALDIFEAT_ASSERT(signal_frame.size(1) == opts_.frame_opts.PaddedWindowSize());

  // torch.finfo(torch.float32).eps
  constexpr float kEps = 1.1920928955078125e-07f;

  // Compute energy after window function (not the raw one).
  if (opts_.use_energy && !opts_.raw_energy) {
    signal_raw_log_energy =
        torch::clamp_min(signal_frame.pow(2).sum(1), kEps).log();
  }

  // note spectrum is in magnitude, not power, because of `abs()`
#if defined(KALDIFEAT_HAS_FFT_NAMESPACE)
  // signal_frame shape: [x, 512]
  // spectrum shape [x, 257]
  torch::Tensor spectrum = torch::fft::rfft(signal_frame).abs();
#else
  // signal_frame shape [x, 512]
  // real_imag shape [x, 257, 2],
  // where [..., 0] is the real part
  //       [..., 1] is the imaginary part
  torch::Tensor real_imag = torch::rfft(signal_frame, 1);
  torch::Tensor real = real_imag.index({"...", 0});
  torch::Tensor imag = real_imag.index({"...", 1});
  torch::Tensor spectrum = (real.square() + imag.square()).sqrt();
#endif

  // remove the last column, i.e., the highest fft bin
  spectrum = spectrum.index(
      {"...", torch::indexing::Slice(0, -1, torch::indexing::None)});

  // Use power instead of magnitude
  spectrum = spectrum.pow(2);

  torch::Tensor mel_energies = mel_banks.Compute(spectrum);

  // Avoid log of zero (which should be prevented anyway by dithering).
  mel_energies = torch::clamp_min(mel_energies, kEps).log();

  // DCT of the log mel energies; dct_matrix_ is stored transposed.
  torch::Tensor features = torch::mm(mel_energies, dct_matrix_);

  if (opts_.cepstral_lifter != 0.0) {
    features = torch::mul(features, lifter_coeffs_);
  }

  if (opts_.use_energy) {
    if (opts_.energy_floor > 0.0f) {
      signal_raw_log_energy =
          torch::clamp_min(signal_raw_log_energy, log_energy_floor_);
    }
    // column 0 is replaced by signal_raw_log_energy
    //
    // features[:, 0] = signal_raw_log_energy
    //
    features.index({"...", 0}) = signal_raw_log_energy;
  }

  if (opts_.htk_compat) {
    // energy = features[:, 0]
    // features[:, :-1] = features[:, 1:]
    // features[:, -1] = energy * sqrt(2)
    //
    // shift left, so the original 0th column
    // becomes the last column;
    // the original first column becomes the 0th column
    features = torch::roll(features, -1, 1);

    if (!opts_.use_energy) {
      // TODO(fangjun): change the DCT matrix so that we don't need
      // to do an extra multiplication here.
      //
      // scale on C0 (actually removing a scale
      // we previously added that's part of one common definition of
      // the cosine transform.)
      features.index({"...", -1}) *= M_SQRT2;
    }
  }

  return features;
}
} // namespace kaldifeat

View File

@ -0,0 +1,117 @@
// kaldifeat/csrc/feature-mfcc.h
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
// This file is copied/modified from kaldi/src/feat/feature-mfcc.h
#ifndef KALDIFEAT_CSRC_FEATURE_MFCC_H_
#define KALDIFEAT_CSRC_FEATURE_MFCC_H_
#include <map>
#include <string>
#include "kaldifeat/csrc/feature-common.h"
#include "kaldifeat/csrc/feature-window.h"
#include "kaldifeat/csrc/mel-computations.h"
#include "torch/script.h"
namespace kaldifeat {
/// MfccOptions contains basic options for computing MFCC features.
// (this class is copied from kaldi)
/// MfccOptions contains basic options for computing MFCC features.
// (this class is copied from kaldi)
struct MfccOptions {
  FrameExtractionOptions frame_opts;
  MelBanksOptions mel_opts;

  // Number of cepstra in MFCC computation (including C0)
  int32_t num_ceps = 13;

  // Use energy (not C0) in MFCC computation
  bool use_energy = true;

  // Floor on energy (absolute, not relative) in MFCC
  // computation. Only makes a difference if use_energy=true;
  // only necessary if dither=0.0.
  // Suggested values: 0.1 or 1.0
  float energy_floor = 0.0;

  // If true, compute energy before preemphasis and windowing
  bool raw_energy = true;

  // Constant that controls scaling of MFCCs
  float cepstral_lifter = 22.0;

  // If true, put energy or C0 last and use a factor of
  // sqrt(2) on C0.
  // Warning: not sufficient to get HTK compatible features
  // (need to change other parameters)
  bool htk_compat = false;

  // Device on which features are computed, e.g. "cpu" or "cuda:0".
  torch::Device device{"cpu"};

  MfccOptions() { mel_opts.num_bins = 23; }

  // Returns a Python-repr-like, human-readable description of the options.
  std::string ToString() const {
    std::ostringstream os;
    os << "MfccOptions(";
    os << "frame_opts=" << frame_opts.ToString() << ", ";
    os << "mel_opts=" << mel_opts.ToString() << ", ";
    os << "num_ceps=" << num_ceps << ", ";
    os << "use_energy=" << (use_energy ? "True" : "False") << ", ";
    os << "energy_floor=" << energy_floor << ", ";
    os << "raw_energy=" << (raw_energy ? "True" : "False") << ", ";
    os << "cepstral_lifter=" << cepstral_lifter << ", ";
    os << "htk_compat=" << (htk_compat ? "True" : "False") << ", ";
    os << "device=\"" << device << "\")";
    return os.str();
  }
};
std::ostream &operator<<(std::ostream &os, const MfccOptions &opts);
// Frame-level MFCC computer; wrap it in OfflineFeatureTpl (see `Mfcc`
// alias) to process whole waveforms.
class MfccComputer {
 public:
  using Options = MfccOptions;

  explicit MfccComputer(const MfccOptions &opts);
  ~MfccComputer();

  MfccComputer &operator=(const MfccComputer &) = delete;
  MfccComputer(const MfccComputer &) = delete;

  // Output dimension: number of cepstra (including C0 / energy).
  int32_t Dim() const { return opts_.num_ceps; }

  bool NeedRawLogEnergy() const { return opts_.use_energy && opts_.raw_energy; }

  const FrameExtractionOptions &GetFrameOptions() const {
    return opts_.frame_opts;
  }

  const MfccOptions &GetOptions() const { return opts_; }

  // signal_raw_log_energy is log_energy_pre_window, which is not empty
  // iff NeedRawLogEnergy() returns true.
  torch::Tensor Compute(torch::Tensor signal_raw_log_energy, float vtln_warp,
                        const torch::Tensor &signal_frame);

 private:
  const MelBanks *GetMelBanks(float vtln_warp);

  MfccOptions opts_;
  torch::Tensor lifter_coeffs_;  // 1-D tensor
  // Note we save a transposed version of dct_matrix_
  // dct_matrix_.rows is num_mel_bins
  // dct_matrix_.cols is num_ceps
  torch::Tensor dct_matrix_;  // matrix we right-multiply by to perform DCT.
  // log(opts_.energy_floor); only meaningful when energy_floor > 0.
  // Default-initialized so the member is never read uninitialized
  // (the constructor assigns it only when energy_floor > 0).
  float log_energy_floor_ = 0.0f;
  std::map<float, MelBanks *> mel_banks_;  // float is VTLN coefficient.
};
using Mfcc = OfflineFeatureTpl<MfccComputer>;
} // namespace kaldifeat
#endif // KALDIFEAT_CSRC_FEATURE_MFCC_H_

View File

@ -0,0 +1,185 @@
// kaldifeat/csrc/feature-plp.cc
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
// This file is copied/modified from kaldi/src/feat/feature-plp.cc
#include "kaldifeat/csrc/feature-plp.h"
#include "kaldifeat/csrc/feature-functions.h"
namespace kaldifeat {
// Streams a human-readable description of the options (see ToString()).
std::ostream &operator<<(std::ostream &os, const PlpOptions &opts) {
  return os << opts.ToString();
}
// Builds optional lifter coefficients and the transposed IDFT bases
// (used to turn mel energies into autocorrelation coefficients),
// precomputes log(energy_floor), and caches MelBanks for VTLN warp 1.0.
PlpComputer::PlpComputer(const PlpOptions &opts) : opts_(opts) {
  // our num-ceps includes C0.
  KALDIFEAT_ASSERT(opts_.num_ceps <= opts_.lpc_order + 1);

  if (opts.cepstral_lifter != 0.0) {
    lifter_coeffs_ = torch::empty({1, opts.num_ceps}, torch::kFloat32);
    ComputeLifterCoeffs(opts.cepstral_lifter, &lifter_coeffs_);
    lifter_coeffs_ = lifter_coeffs_.to(opts.device);
  }

  // num_bins + 2 accounts for the duplicated first/last mel bins added
  // in Compute().
  InitIdftBases(opts_.lpc_order + 1, opts_.mel_opts.num_bins + 2, &idft_bases_);

  // CAUTION: we save a transposed version of idft_bases_
  idft_bases_ = idft_bases_.to(opts.device).t();

  if (opts.energy_floor > 0.0) log_energy_floor_ = logf(opts.energy_floor);

  // We'll definitely need the filterbanks info for VTLN warping factor 1.0.
  // [note: this call caches it.]
  GetMelBanks(1.0);
}
PlpComputer::~PlpComputer() {
  // Free every cached MelBanks and equal-loudness tensor.
  for (auto &entry : mel_banks_) {
    delete entry.second;
  }
  for (auto &entry : equal_loudness_) {
    delete entry.second;
  }
}
// Returns the MelBanks for the given VTLN warp factor, creating and
// caching a new instance on first use.
const MelBanks *PlpComputer::GetMelBanks(float vtln_warp) {
  auto iter = mel_banks_.find(vtln_warp);
  if (iter != mel_banks_.end()) {
    return iter->second;
  }

  MelBanks *created =
      new MelBanks(opts_.mel_opts, opts_.frame_opts, vtln_warp, opts_.device);
  mel_banks_[vtln_warp] = created;
  return created;
}
// Returns the equal-loudness weighting vector for the given VTLN warp
// factor, computing and caching it (on opts_.device) on first use.
const torch::Tensor *PlpComputer::GetEqualLoudness(float vtln_warp) {
  const MelBanks *this_mel_banks = GetMelBanks(vtln_warp);

  auto iter = equal_loudness_.find(vtln_warp);
  if (iter != equal_loudness_.end()) {
    return iter->second;
  }

  // Use nullptr-free, early-return style; `new` result is owned by the
  // equal_loudness_ map and freed in the destructor.
  auto *ans = new torch::Tensor;
  GetEqualLoudnessVector(*this_mel_banks, ans);
  *ans = ans->to(opts_.device);
  equal_loudness_[vtln_warp] = ans;
  return ans;
}
// Compute PLP features for a batch of frames.
//
// signal_raw_log_energy: 1-D tensor of per-frame log energies computed
//   before windowing; only used when opts_.use_energy is true (see below).
// vtln_warp: VTLN warp factor; selects the cached mel banks /
//   equal-loudness vector.
// signal_frame: 2-D tensor [num_frames, PaddedWindowSize()].
//
// ans.shape [signal_frame.size(0), this->Dim()]
torch::Tensor PlpComputer::Compute(torch::Tensor signal_raw_log_energy,
                                   float vtln_warp,
                                   const torch::Tensor &signal_frame) {
  KALDIFEAT_ASSERT(signal_frame.dim() == 2);
  KALDIFEAT_ASSERT(signal_frame.size(1) == opts_.frame_opts.PaddedWindowSize());
  const MelBanks &mel_banks = *GetMelBanks(vtln_warp);
  const torch::Tensor &equal_loudness = *GetEqualLoudness(vtln_warp);
  // torch.finfo(torch.float32).eps, used to avoid log(0)
  constexpr float kEps = 1.1920928955078125e-07f;
  // Compute energy after window function (not the raw one).
  if (opts_.use_energy && !opts_.raw_energy) {
    signal_raw_log_energy =
        torch::clamp_min(signal_frame.pow(2).sum(1), kEps).log();
  }
  // note spectrum is in magnitude, not power, because of `abs()`
#if defined(KALDIFEAT_HAS_FFT_NAMESPACE)
  // signal_frame shape: [x, 512]
  // spectrum shape: [x, 257]
  torch::Tensor spectrum = torch::fft::rfft(signal_frame).abs();
#else
  // signal_frame shape [x, 512]
  // real_imag shape [x, 257, 2],
  // where [..., 0] is the real part
  //       [..., 1] is the imaginary part
  torch::Tensor real_imag = torch::rfft(signal_frame, 1);
  torch::Tensor real = real_imag.index({"...", 0});
  torch::Tensor imag = real_imag.index({"...", 1});
  torch::Tensor spectrum = (real.square() + imag.square()).sqrt();
#endif
  // remove the last column, i.e., the highest fft bin
  spectrum = spectrum.index(
      {"...", torch::indexing::Slice(0, -1, torch::indexing::None)});
  // Use power instead of magnitude
  spectrum = spectrum.pow(2);
  torch::Tensor mel_energies = mel_banks.Compute(spectrum);
  // Apply the equal-loudness weighting (1-D, broadcast over frames).
  mel_energies = torch::mul(mel_energies, equal_loudness);
  // Amplitude compression (compress_factor defaults to 1/3).
  mel_energies = mel_energies.pow(opts_.compress_factor);
  // duplicate first and last elements before the IDFT
  //
  // first = mel_energies[:, 0]
  // first.shape [num_frames, 1]
  torch::Tensor first = mel_energies.index({"...", 0}).unsqueeze(-1);
  // last = mel_energies[:, -1]
  // last.shape [num_frames, 1]
  torch::Tensor last = mel_energies.index({"...", -1}).unsqueeze(-1);
  mel_energies = torch::cat({first, mel_energies, last}, 1);
  // idft_bases_ is stored transposed (see the constructor), hence it is
  // the right operand of the matmul.
  torch::Tensor autocorr_coeffs = torch::mm(mel_energies, idft_bases_);
  torch::Tensor lpc_coeffs;
  torch::Tensor residual_log_energy = ComputeLpc(autocorr_coeffs, &lpc_coeffs);
  residual_log_energy = torch::clamp_min(residual_log_energy, kEps);
  torch::Tensor raw_cepstrum = Lpc2Cepstrum(lpc_coeffs);
  // torch.cat((residual_log_energy.unsqueeze(-1),
  //            raw_cepstrum[:opts.num_ceps-1]), 1)
  //
  using namespace torch::indexing;  // It imports: Slice, None  // NOLINT
  torch::Tensor features = torch::cat(
      {residual_log_energy.unsqueeze(-1),
       raw_cepstrum.index({"...", Slice(0, opts_.num_ceps - 1, None)})},
      1);
  if (opts_.cepstral_lifter != 0.0) {
    features = torch::mul(features, lifter_coeffs_);
  }
  if (opts_.cepstral_scale != 1.0) {
    features = features * opts_.cepstral_scale;
  }
  if (opts_.use_energy) {
    if (opts_.energy_floor > 0.0f) {
      signal_raw_log_energy =
          torch::clamp_min(signal_raw_log_energy, log_energy_floor_);
    }
    // column 0 is replaced by signal_raw_log_energy
    //
    // features[:, 0] = signal_raw_log_energy
    //
    features.index({"...", 0}) = signal_raw_log_energy;
  }
  if (opts_.htk_compat) {  // reorder the features.
    // shift left, so the original 0th column
    // becomes the last column;
    // the original first column becomes the 0th column
    features = torch::roll(features, -1, 1);
  }
  return features;
}
} // namespace kaldifeat

View File

@ -0,0 +1,129 @@
// kaldifeat/csrc/feature-plp.h
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
// This file is copied/modified from kaldi/src/feat/feature-plp.h
#ifndef KALDIFEAT_CSRC_FEATURE_PLP_H_
#define KALDIFEAT_CSRC_FEATURE_PLP_H_
#include <map>
#include <string>
#include "kaldifeat/csrc/feature-common.h"
#include "kaldifeat/csrc/feature-window.h"
#include "kaldifeat/csrc/mel-computations.h"
#include "torch/script.h"
namespace kaldifeat {
/// PlpOptions contains basic options for computing PLP features.
/// It only includes things that can be done in a "stateless" way, i.e.
/// it does not include energy max-normalization.
/// It does not include delta computation.
// Basic, stateless options for computing PLP features (no energy
// max-normalization, no delta computation).
struct PlpOptions {
  FrameExtractionOptions frame_opts;
  MelBanksOptions mel_opts;
  // Order of LPC analysis in PLP computation
  //
  // 12 seems to be common for 16kHz-sampled data. For 8kHz-sampled
  // data, 15 may be better.
  int32_t lpc_order = 12;
  // Number of cepstra in PLP computation (including C0)
  int32_t num_ceps = 13;
  bool use_energy = true;  // use energy; else C0
  // Floor on energy (absolute, not relative) in PLP computation.
  // Only makes a difference if --use-energy=true; only necessary if
  // dither is 0.0. Suggested values: 0.1 or 1.0
  float energy_floor = 0.0;
  // If true, compute energy before preemphasis and windowing
  bool raw_energy = true;
  // Compression factor in PLP computation
  float compress_factor = 0.33333;
  // Constant that controls scaling of PLPs
  //
  // NOTE(review): declared as int32_t but compared against 0.0 and passed
  // to ComputeLifterCoeffs() as a float in feature-plp.cc -- confirm the
  // integer type (vs float, as upstream Kaldi) is intentional.
  int32_t cepstral_lifter = 22;
  // Scaling constant in PLP computation
  float cepstral_scale = 1.0;
  bool htk_compat = false;  // if true, put energy/C0 last and introduce a
                            // factor of sqrt(2) on C0 to be the same as HTK.
  //
  // Device on which the features (and cached tensors) are computed/stored.
  torch::Device device{"cpu"};
  // PLP conventionally uses 23 mel bins by default.
  PlpOptions() { mel_opts.num_bins = 23; }
  // Python-repr-like string, used by operator<< and the Python bindings.
  std::string ToString() const {
    std::ostringstream os;
    os << "PlpOptions(";
    os << "frame_opts=" << frame_opts.ToString() << ", ";
    os << "mel_opts=" << mel_opts.ToString() << ", ";
    os << "lpc_order=" << lpc_order << ", ";
    os << "num_ceps=" << num_ceps << ", ";
    os << "use_energy=" << (use_energy ? "True" : "False") << ", ";
    os << "energy_floor=" << energy_floor << ", ";
    os << "raw_energy=" << (raw_energy ? "True" : "False") << ", ";
    os << "compress_factor=" << compress_factor << ", ";
    os << "cepstral_lifter=" << cepstral_lifter << ", ";
    os << "cepstral_scale=" << cepstral_scale << ", ";
    os << "htk_compat=" << (htk_compat ? "True" : "False") << ", ";
    os << "device=\"" << device << "\")";
    return os.str();
  }
};
std::ostream &operator<<(std::ostream &os, const PlpOptions &opts);
// Stateless-per-frame PLP computer, used via OfflineFeatureTpl<PlpComputer>.
class PlpComputer {
 public:
  using Options = PlpOptions;

  explicit PlpComputer(const PlpOptions &opts);
  ~PlpComputer();

  // Not copyable: mel_banks_ / equal_loudness_ own raw pointers.
  PlpComputer &operator=(const PlpComputer &) = delete;
  PlpComputer(const PlpComputer &) = delete;

  // Feature dimension per frame.
  int32_t Dim() const { return opts_.num_ceps; }

  bool NeedRawLogEnergy() const { return opts_.use_energy && opts_.raw_energy; }

  const FrameExtractionOptions &GetFrameOptions() const {
    return opts_.frame_opts;
  }

  const PlpOptions &GetOptions() const { return opts_; }

  // signal_raw_log_energy is log_energy_pre_window, which is not empty
  // iff NeedRawLogEnergy() returns true.
  torch::Tensor Compute(torch::Tensor signal_raw_log_energy, float vtln_warp,
                        const torch::Tensor &signal_frame);

 private:
  const MelBanks *GetMelBanks(float vtln_warp);
  const torch::Tensor *GetEqualLoudness(float vtln_warp);

  PlpOptions opts_;
  torch::Tensor lifter_coeffs_;
  torch::Tensor idft_bases_;  // 2-D tensor, kFloat. Caution: it is transposed

  // log(energy_floor) when opts_.energy_floor > 0. Default-initialized so
  // the member is never left indeterminate: the constructor only assigns
  // it when energy_floor > 0.
  float log_energy_floor_ = 0.0f;

  std::map<float, MelBanks *> mel_banks_;  // float is VTLN coefficient.
  // value is a 1-D torch.Tensor
  std::map<float, torch::Tensor *> equal_loudness_;
};
using Plp = OfflineFeatureTpl<PlpComputer>;
} // namespace kaldifeat
#endif // KALDIFEAT_CSRC_FEATURE_PLP_H_

View File

@ -0,0 +1,78 @@
// kaldifeat/csrc/feature-spectrogram.cc
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
// This file is copied/modified from kaldi/src/feat/feature-spectrogram.cc
#include "kaldifeat/csrc/feature-spectrogram.h"
namespace kaldifeat {
// Stream a SpectrogramOptions via its Python-repr-like ToString().
std::ostream &operator<<(std::ostream &os, const SpectrogramOptions &opts) {
  return os << opts.ToString();
}
// Caches the options and precomputes log(energy_floor).
//
// log_energy_floor_ is explicitly initialized so it is never read as an
// indeterminate value; the original only assigned it when
// energy_floor > 0 (it is only *used* in that case, but an uninitialized
// member is a latent hazard).
SpectrogramComputer::SpectrogramComputer(const SpectrogramOptions &opts)
    : opts_(opts), log_energy_floor_(0.0f) {
  if (opts.energy_floor > 0.0) log_energy_floor_ = logf(opts.energy_floor);
}
// Compute log-power-spectrogram features for a batch of frames.
//
// signal_raw_log_energy: 1-D tensor of per-frame log energies computed
//   before windowing (used when opts_.raw_energy is true).
// vtln_warp: ignored; present only for interface compatibility with the
//   other computers.
// signal_frame: 2-D tensor [num_frames, PaddedWindowSize()].
//
// ans.shape [signal_frame.size(0), this->Dim()]
torch::Tensor SpectrogramComputer::Compute(torch::Tensor signal_raw_log_energy,
                                           float vtln_warp,
                                           const torch::Tensor &signal_frame) {
  KALDIFEAT_ASSERT(signal_frame.dim() == 2);
  KALDIFEAT_ASSERT(signal_frame.size(1) == opts_.frame_opts.PaddedWindowSize());
  // torch.finfo(torch.float32).eps, used to avoid log(0)
  constexpr float kEps = 1.1920928955078125e-07f;
  // Compute energy after window function (not the raw one).
  if (!opts_.raw_energy) {
    signal_raw_log_energy =
        torch::clamp_min(signal_frame.pow(2).sum(1), kEps).log();
  }
  // note spectrum is in magnitude, not power, because of `abs()`
#if defined(KALDIFEAT_HAS_FFT_NAMESPACE)
  // signal_frame shape: [x, 512]
  // spectrum shape: [x, 257]
  torch::Tensor spectrum = torch::fft::rfft(signal_frame).abs();
#else
  // signal_frame shape [x, 512]
  // real_imag shape [x, 257, 2],
  // where [..., 0] is the real part
  //       [..., 1] is the imaginary part
  torch::Tensor real_imag = torch::rfft(signal_frame, 1);
  torch::Tensor real = real_imag.index({"...", 0});
  torch::Tensor imag = real_imag.index({"...", 1});
  torch::Tensor spectrum = (real.square() + imag.square()).sqrt();
#endif
  if (opts_.return_raw_fft) {
    KALDIFEAT_ERR << "return raw fft is not supported yet";
  }
  // compute power spectrum
  spectrum = spectrum.pow(2);
  // NOTE: take the log
  spectrum = torch::clamp_min(spectrum, kEps).log();
  if (opts_.energy_floor > 0.0f) {
    signal_raw_log_energy =
        torch::clamp_min(signal_raw_log_energy, log_energy_floor_);
  }
  // The zeroth spectrogram component is always set to the signal energy,
  // instead of the square of the constant component of the signal.
  //
  // spectrum[:,0] = signal_raw_log_energy
  spectrum.index({"...", 0}) = signal_raw_log_energy;
  return spectrum;
}
} // namespace kaldifeat

View File

@ -0,0 +1,92 @@
// kaldifeat/csrc/feature-spectrogram.h
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
// This file is copied/modified from kaldi/src/feat/feature-spectrogram.h
#ifndef KALDIFEAT_CSRC_FEATURE_SPECTROGRAM_H_
#define KALDIFEAT_CSRC_FEATURE_SPECTROGRAM_H_
#include <string>
#include "kaldifeat/csrc/feature-common.h"
#include "kaldifeat/csrc/feature-window.h"
#include "torch/script.h"
namespace kaldifeat {
// Options for spectrogram computation (stateless, per-frame).
struct SpectrogramOptions {
  FrameExtractionOptions frame_opts;
  // Floor on energy (absolute, not relative) in Spectrogram
  // computation. Caution: this floor is applied to the
  // zeroth component, representing the total signal energy.
  // The floor on the individual spectrogram elements is fixed at
  // std::numeric_limits<float>::epsilon()
  float energy_floor = 0.0;
  // If true, compute energy before preemphasis and windowing
  bool raw_energy = true;
  // If true, return raw FFT complex numbers instead of log magnitudes
  // Not implemented yet
  bool return_raw_fft = false;
  // Device on which the features are computed/stored.
  torch::Device device{"cpu"};
  // Python-repr-like string, used by operator<< and the Python bindings.
  std::string ToString() const {
    std::ostringstream os;
    os << "SpectrogramOptions(";
    os << "frame_opts=" << frame_opts.ToString() << ", ";
    os << "energy_floor=" << energy_floor << ", ";
    os << "raw_energy=" << (raw_energy ? "True" : "False") << ", ";
    os << "return_raw_fft=" << (return_raw_fft ? "True" : "False") << ", ";
    os << "device=\"" << device << "\")";
    return os.str();
  }
};
std::ostream &operator<<(std::ostream &os, const SpectrogramOptions &opts);
// Per-frame spectrogram computer, used via OfflineFeatureTpl.
class SpectrogramComputer {
 public:
  using Options = SpectrogramOptions;

  explicit SpectrogramComputer(const SpectrogramOptions &opts);
  ~SpectrogramComputer() = default;

  const FrameExtractionOptions &GetFrameOptions() const {
    return opts_.frame_opts;
  }

  const SpectrogramOptions &GetOptions() const { return opts_; }

  // Output dimension per frame: the full padded window for raw FFT,
  // otherwise the one-sided spectrum size (n / 2 + 1).
  int32_t Dim() const {
    if (opts_.return_raw_fft) {
      return opts_.frame_opts.PaddedWindowSize();
    } else {
      return opts_.frame_opts.PaddedWindowSize() / 2 + 1;
    }
  }

  bool NeedRawLogEnergy() const { return opts_.raw_energy; }

  // signal_raw_log_energy is log_energy_pre_window, which is not empty
  // iff NeedRawLogEnergy() returns true.
  //
  // vtln_warp is ignored by this function, it's only
  // needed for interface compatibility.
  torch::Tensor Compute(torch::Tensor signal_raw_log_energy, float vtln_warp,
                        const torch::Tensor &signal_frame);

 private:
  SpectrogramOptions opts_;

  // log(energy_floor) when opts_.energy_floor > 0. Default-initialized so
  // the member is never left indeterminate: the constructor only assigns
  // it when energy_floor > 0.
  float log_energy_floor_ = 0.0f;
};
using Spectrogram = OfflineFeatureTpl<SpectrogramComputer>;
} // namespace kaldifeat
#endif // KALDIFEAT_CSRC_FEATURE_SPECTROGRAM_H_

View File

@ -0,0 +1,82 @@
// kaldifeat/csrc/feature-window-test.h
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#include "kaldifeat/csrc/feature-window.h"
#include "gtest/gtest.h"
namespace kaldifeat {
// Check NumFrames() against the closed-form frame counts for both
// snip_edges modes over a range of signal lengths.
TEST(FeatureWindow, NumFrames) {
  FrameExtractionOptions opts;
  opts.samp_freq = 1000;  // 1 sample per millisecond
  opts.frame_length_ms = 4;
  opts.frame_shift_ms = 2;

  int32_t window_size = opts.samp_freq / 1000 * opts.frame_length_ms;
  int32_t window_shift = opts.samp_freq / 1000 * opts.frame_shift_ms;

  for (int32_t n = 10; n < 1000; ++n) {
    // snip_edges == true: every frame fits entirely inside the signal.
    opts.snip_edges = true;
    ASSERT_EQ(NumFrames(n, opts), (n - window_size) / window_shift + 1);

    // snip_edges == false: round(n / shift) frames.
    opts.snip_edges = false;
    ASSERT_EQ(NumFrames(n, opts), (n + window_shift / 2) / window_shift);
  }
}
// Check frame start indices for both snip_edges modes on a 6-sample wave.
TEST(FeatureWindow, FirstSampleOfFrame) {
  FrameExtractionOptions opts;
  opts.samp_freq = 1000;  // 1 sample per millisecond
  opts.frame_length_ms = 4;
  opts.frame_shift_ms = 2;
  // samples: [a, b, c, d, e, f]
  int32_t num_samples = 6;
  // snip_edges == true: frame f starts at f * frame_shift
  opts.snip_edges = true;
  ASSERT_EQ(NumFrames(num_samples, opts), 2);
  EXPECT_EQ(FirstSampleOfFrame(0, opts), 0);
  EXPECT_EQ(FirstSampleOfFrame(1, opts), 2);
  // now for snip edges if false: frames are centered on their midpoints,
  // so the first frame may start at a negative index
  opts.snip_edges = false;
  ASSERT_EQ(NumFrames(num_samples, opts), 3);
  EXPECT_EQ(FirstSampleOfFrame(0, opts), -1);
  EXPECT_EQ(FirstSampleOfFrame(1, opts), 1);
  EXPECT_EQ(FirstSampleOfFrame(2, opts), 3);
}
// Check the framing produced by GetStrided() for both snip_edges modes.
TEST(FeatureWindow, GetStrided) {
  FrameExtractionOptions opts;
  opts.samp_freq = 1000;  // 1 sample per millisecond
  opts.frame_length_ms = 4;
  opts.frame_shift_ms = 2;
  // [0 1 2 3 4 5]
  torch::Tensor samples = torch::arange(0, 6).to(torch::kFloat);
  opts.snip_edges = true;
  auto frames = GetStrided(samples, opts);
  // expected frames:
  // 0 1 2 3
  // 2 3 4 5
  std::vector<float> v = {0, 1, 2, 3, 2, 3, 4, 5};
  // NOTE: from_blob does not copy -- `expected` borrows v's storage, so v
  // must stay alive (and unmodified) until the comparison below.
  torch::Tensor expected =
      torch::from_blob(v.data(), {int64_t(v.size())}, torch::kFloat32);
  EXPECT_TRUE(frames.flatten().allclose(expected));
  // snip_edges == false pads by reflection, yielding:
  // 0 0 1 2
  // 1 2 3 4
  // 3 4 5 5
  opts.snip_edges = false;
  frames = GetStrided(samples, opts);
  v = {0, 0, 1, 2, 1, 2, 3, 4, 3, 4, 3, 4, 5, 5};
  expected = torch::from_blob(v.data(), {int64_t(v.size())}, torch::kFloat32);
  EXPECT_TRUE(frames.flatten().allclose(expected));
}
} // namespace kaldifeat

View File

@ -7,8 +7,7 @@
#include "kaldifeat/csrc/feature-window.h"
#include <cmath>
#include "torch/torch.h"
#include <vector>
#ifndef M_2PI
#define M_2PI 6.283185307179586476925286766559005
@ -16,8 +15,13 @@
namespace kaldifeat {
FeatureWindowFunction::FeatureWindowFunction(
const FrameExtractionOptions &opts) {
// Stream a FrameExtractionOptions via its Python-repr-like ToString().
std::ostream &operator<<(std::ostream &os, const FrameExtractionOptions &opts) {
  return os << opts.ToString();
}
FeatureWindowFunction::FeatureWindowFunction(const FrameExtractionOptions &opts,
torch::Device device) {
int32_t frame_length = opts.WindowSize();
KALDIFEAT_ASSERT(frame_length > 0);
@ -25,6 +29,13 @@ FeatureWindowFunction::FeatureWindowFunction(
float *window_data = window.data_ptr<float>();
double a = M_2PI / (frame_length - 1);
if (opts.window_type == "hann") {
// see https://pytorch.org/docs/stable/generated/torch.hann_window.html
// We assume periodic is true
a = M_2PI / frame_length;
}
for (int32_t i = 0; i < frame_length; i++) {
double i_fl = static_cast<double>(i);
if (opts.window_type == "hanning") {
@ -35,6 +46,8 @@ FeatureWindowFunction::FeatureWindowFunction(
window_data[i] = sin(0.5 * a * i_fl);
} else if (opts.window_type == "hamming") {
window_data[i] = 0.54 - 0.46 * cos(a * i_fl);
} else if (opts.window_type == "hann") {
window_data[i] = 0.50 - 0.50 * cos(a * i_fl);
} else if (opts.window_type ==
"povey") { // like hamming but goes to zero at edges.
window_data[i] = pow(0.5 - 0.5 * cos(a * i_fl), 0.85);
@ -47,6 +60,197 @@ FeatureWindowFunction::FeatureWindowFunction(
KALDIFEAT_ERR << "Invalid window type " << opts.window_type;
}
}
window = window.unsqueeze(0);
if (window.device() != device) {
window = window.to(device);
}
}
// Multiply every frame (row) of `wave` by the window function.
//
// wave: 2-D tensor [num_frames, frame_length]. `window` has shape
// [1, frame_length] (see the constructor), so the multiply broadcasts
// across frames.
torch::Tensor FeatureWindowFunction::Apply(const torch::Tensor &wave) const {
  KALDIFEAT_ASSERT(wave.dim() == 2);
  KALDIFEAT_ASSERT(wave.size(1) == window.size(1));
  return wave.mul(window);
}
// Index of the first sample of frame `frame`.
//
// With snip_edges == true frames start at multiples of the frame shift;
// otherwise each frame is centered around its midpoint, so the result may
// be negative (callers handle that by reflecting the signal).
int64_t FirstSampleOfFrame(int32_t frame, const FrameExtractionOptions &opts) {
  const int64_t frame_shift = opts.WindowShift();

  if (opts.snip_edges) {
    return frame * frame_shift;
  }

  int64_t midpoint_of_frame = frame * frame_shift + frame_shift / 2;
  return midpoint_of_frame - opts.WindowSize() / 2;
}
// Number of frames extractable from `num_samples` samples; see
// feature-window.h for the meaning of `flush`.
int32_t NumFrames(int64_t num_samples, const FrameExtractionOptions &opts,
                  bool flush /*= true*/) {
  int64_t frame_shift = opts.WindowShift();
  int64_t frame_length = opts.WindowSize();
  if (opts.snip_edges) {
    // with --snip-edges=true (the default), we use a HTK-like approach to
    // determining the number of frames-- all frames have to fit completely into
    // the waveform, and the first frame begins at sample zero.
    if (num_samples < frame_length)
      return 0;
    else
      return (1 + ((num_samples - frame_length) / frame_shift));
    // You can understand the expression above as follows: 'num_samples -
    // frame_length' is how much room we have to shift the frame within the
    // waveform; 'frame_shift' is how much we shift it each time; and the ratio
    // is how many times we can shift it (integer arithmetic rounds down).
  } else {
    // if --snip-edges=false, the number of frames is determined by rounding the
    // (file-length / frame-shift) to the nearest integer. The point of this
    // formula is to make the number of frames an obvious and predictable
    // function of the frame shift and signal length, which makes many
    // segmentation-related questions simpler.
    //
    // Because integer division in C++ rounds toward zero, we add (half the
    // frame-shift minus epsilon) before dividing, to have the effect of
    // rounding towards the closest integer.
    int32_t num_frames = (num_samples + (frame_shift / 2)) / frame_shift;
    if (flush) return num_frames;
    // note: 'end' always means the last plus one, i.e. one past the last.
    int64_t end_sample_of_last_frame =
        FirstSampleOfFrame(num_frames - 1, opts) + frame_length;
    // the following code is optimized more for clarity than efficiency.
    // If flush == false, we can't output frames that extend past the end
    // of the signal.
    while (num_frames > 0 && end_sample_of_last_frame > num_samples) {
      num_frames--;
      end_sample_of_last_frame -= frame_shift;
    }
    return num_frames;
  }
}
// Turn a 1-D wave into a 2-D view of frames with shape
// [num_frames, opts.WindowSize()] using as_strided (no copy when
// snip_edges is true). See feature-window.h for worked examples.
torch::Tensor GetStrided(const torch::Tensor &wave,
                         const FrameExtractionOptions &opts) {
  KALDIFEAT_ASSERT(wave.dim() == 1);
  std::vector<int64_t> strides = {opts.WindowShift() * wave.strides()[0],
                                  wave.strides()[0]};
  int64_t num_samples = wave.size(0);
  int32_t num_frames = NumFrames(num_samples, opts);
  std::vector<int64_t> sizes = {num_frames, opts.WindowSize()};
  if (opts.snip_edges) {
    return wave.as_strided(sizes, strides);
  }
  // snip_edges == false: pad both ends by reflection so every frame is
  // fully covered, then stride over the padded wave.
  //
  // NOTE(review): frame_length/frame_shift are recomputed from the raw
  // options -- presumably always equal to WindowSize()/WindowShift() used
  // above; confirm.
  int32_t frame_length = opts.samp_freq / 1000 * opts.frame_length_ms;
  int32_t frame_shift = opts.samp_freq / 1000 * opts.frame_shift_ms;
  int64_t num_new_samples = (num_frames - 1) * frame_shift + frame_length;
  int32_t num_padding = num_new_samples - num_samples;
  int32_t num_left_padding = (frame_length - frame_shift) / 2;
  int32_t num_right_padding = num_padding - num_left_padding;
  // left_padding = wave[:num_left_padding].flip(dims=(0,))
  torch::Tensor left_padding =
      wave.index({torch::indexing::Slice(0, num_left_padding, 1)}).flip({0});
  // right_padding = wave[-num_right_padding:].flip(dims=(0,))
  //
  // NOTE(review): when num_right_padding == 0 this slice is [0:], i.e. the
  // whole flipped wave is appended. Harmless -- as_strided only reads the
  // first num_new_samples entries -- but wasteful; verify.
  torch::Tensor right_padding =
      wave.index({torch::indexing::Slice(-num_right_padding,
                                         torch::indexing::None, 1)})
          .flip({0});
  torch::Tensor new_wave = torch::cat({left_padding, wave, right_padding}, 0);
  return new_wave.as_strided(sizes, strides);
}
// Add Gaussian dithering noise: returns wave + N(0, 1) * dither_value.
//
// When dither_value == 0 the input tensor is returned as-is (no copy).
//
// The previous `#if 1 / #else` dead branch referenced an undeclared
// identifier (`wave_`) and could never compile if selected; it has been
// removed.
torch::Tensor Dither(const torch::Tensor &wave, float dither_value) {
  if (dither_value == 0.0f) return wave;

  torch::Tensor rand_gauss = torch::randn_like(wave);
  return wave + rand_gauss * dither_value;
}
// Apply pre-emphasis row-wise on a 2-D tensor of frames:
//   ans[:, t] = wave[:, t] - preemph_coeff * wave[:, t - 1]
// with the first column emphasized against itself.
torch::Tensor Preemphasize(float preemph_coeff, const torch::Tensor &wave) {
  if (preemph_coeff == 0.0f) return wave;
  KALDIFEAT_ASSERT(preemph_coeff >= 0.0f && preemph_coeff <= 1.0f);
  torch::Tensor ans = torch::empty_like(wave);
  using torch::indexing::None;
  using torch::indexing::Slice;
  // right = wave[:, 1:]
  torch::Tensor right = wave.index({"...", Slice(1, None, None)});
  // current = wave[:, 0:-1]
  torch::Tensor current = wave.index({"...", Slice(0, -1, None)});
  // ans[:, 1:] = wave[:, 1:] - preemph_coeff * wave[:, 0:-1]
  // (the original comment said ans[1:, :], but the "..." index selects the
  // last dimension, i.e. columns)
  ans.index({"...", Slice(1, None, None)}) = right - preemph_coeff * current;
  // ans[:, 0] = wave[:, 0] * (1 - preemph_coeff)
  ans.index({"...", 0}) = wave.index({"...", 0}) * (1 - preemph_coeff);
  return ans;
}
// Extract frame `f` (frame_length samples) from `wave`. Frames that extend
// past either end of the wave are completed by reflecting the signal at
// the edges. See feature-window.h for the meaning of `sample_offset`.
torch::Tensor ExtractWindow(int64_t sample_offset, const torch::Tensor &wave,
                            int32_t f, const FrameExtractionOptions &opts) {
  KALDIFEAT_ASSERT(sample_offset >= 0 && wave.numel() != 0);
  int32_t frame_length = opts.WindowSize();
  int64_t num_samples = sample_offset + wave.numel();
  int64_t start_sample = FirstSampleOfFrame(f, opts);
  int64_t end_sample = start_sample + frame_length;
  if (opts.snip_edges) {
    // with snip_edges the frame must lie entirely inside the signal
    KALDIFEAT_ASSERT(start_sample >= sample_offset &&
                     end_sample <= num_samples);
  } else {
    KALDIFEAT_ASSERT(sample_offset == 0 || start_sample >= sample_offset);
  }
  // wave_start and wave_end are start and end indexes into 'wave', for the
  // piece of wave that we're trying to extract.
  int32_t wave_start = static_cast<int32_t>(start_sample - sample_offset);
  int32_t wave_end = wave_start + frame_length;
  if (wave_start >= 0 && wave_end <= wave.numel()) {
    // the normal case -- no edge effects to consider.
    // return wave[wave_start:wave_end]
    return wave.index({torch::indexing::Slice(wave_start, wave_end)});
  } else {
    // NOTE(review): this branch uses accessor<float, 1>, which assumes a
    // CPU float tensor -- confirm callers never reach here with CUDA input.
    torch::Tensor window = torch::empty({frame_length}, torch::kFloat);
    auto p_window = window.accessor<float, 1>();
    auto p_wave = wave.accessor<float, 1>();
    // Deal with any end effects by reflection, if needed. This code will only
    // be reached for about two frames per utterance, so we don't concern
    // ourselves excessively with efficiency.
    int32_t wave_dim = wave.numel();
    for (int32_t s = 0; s != frame_length; ++s) {
      int32_t s_in_wave = s + wave_start;
      while (s_in_wave < 0 || s_in_wave >= wave_dim) {
        // reflect around the beginning or end of the wave.
        // e.g. -1 -> 0, -2 -> 1.
        // dim -> dim - 1, dim + 1 -> dim - 2.
        // the code supports repeated reflections, although this
        // would only be needed in pathological cases.
        if (s_in_wave < 0) {
          s_in_wave = -s_in_wave - 1;
        } else {
          s_in_wave = 2 * wave_dim - 1 - s_in_wave;
        }
      }
      p_window[s] = p_wave[s_in_wave];
    }
    return window;
  }
}
} // namespace kaldifeat

View File

@ -4,8 +4,11 @@
// This file is copied/modified from kaldi/src/feat/feature-window.h
#include <string>
#include "kaldifeat/csrc/log.h"
#include "torch/torch.h"
#include "torch/all.h"
#include "torch/script.h"
#ifndef KALDIFEAT_CSRC_FEATURE_WINDOW_H_
#define KALDIFEAT_CSRC_FEATURE_WINDOW_H_
@ -39,8 +42,12 @@ struct FrameExtractionOptions {
bool round_to_power_of_two = true;
float blackman_coeff = 0.42f;
bool snip_edges = true;
bool allow_downsample = false;
bool allow_upsample = false;
// bool allow_downsample = false;
// bool allow_upsample = false;
// Used for streaming feature extraction. It indicates the number
// of feature frames to keep in the recycling vector. -1 means to
// keep all feature frames.
int32_t max_feature_vectors = -1;
int32_t WindowShift() const {
@ -53,14 +60,105 @@ struct FrameExtractionOptions {
return (round_to_power_of_two ? RoundUpToNearestPowerOfTwo(WindowSize())
: WindowSize());
}
std::string ToString() const {
std::ostringstream os;
os << "FrameExtractionOptions(";
os << "samp_freq=" << samp_freq << ", ";
os << "frame_shift_ms=" << frame_shift_ms << ", ";
os << "frame_length_ms=" << frame_length_ms << ", ";
os << "dither=" << dither << ", ";
os << "preemph_coeff=" << preemph_coeff << ", ";
os << "remove_dc_offset=" << (remove_dc_offset ? "True" : "False") << ", ";
os << "window_type=" << '"' << window_type << '"' << ", ";
os << "round_to_power_of_two=" << (round_to_power_of_two ? "True" : "False")
<< ", ";
os << "blackman_coeff=" << blackman_coeff << ", ";
os << "snip_edges=" << (snip_edges ? "True" : "False") << ", ";
os << "max_feature_vectors=" << max_feature_vectors << ")";
return os.str();
}
};
struct FeatureWindowFunction {
std::ostream &operator<<(std::ostream &os, const FrameExtractionOptions &opts);
class FeatureWindowFunction {
public:
FeatureWindowFunction() = default;
explicit FeatureWindowFunction(const FrameExtractionOptions &opts);
FeatureWindowFunction(const FrameExtractionOptions &opts,
torch::Device device);
torch::Tensor Apply(const torch::Tensor &wave) const;
private:
torch::Tensor window;
};
/**
This function returns the number of frames that we can extract from a wave
file with the given number of samples in it (assumed to have the same
sampling rate as specified in 'opts').
@param [in] num_samples The number of samples in the wave file.
@param [in] opts The frame-extraction options class
@param [in] flush True if we are asserting that this number of samples
is 'all there is', false if we expecting more data to possibly come in. This
only makes a difference to the answer
if opts.snips_edges== false. For offline feature extraction you always want
flush == true. In an online-decoding context, once you know (or decide) that
no more data is coming in, you'd call it with flush == true at the end to
flush out any remaining data.
*/
int32_t NumFrames(int64_t num_samples, const FrameExtractionOptions &opts,
bool flush = true);
int64_t FirstSampleOfFrame(int32_t frame, const FrameExtractionOptions &opts);
// return a 2-d tensor with shape [num_frames, opts.WindowSize()]
//
// Suppose the wave contains samples [a, b, c, d, e, f],
// windows size is 4 and window shift is 2
//
// if opt.snip_edges is true, it returns:
// a b c d
// c d e f
//
// if opt.snip_edges is false, it returns
// a a b c
// b c d e
// d e f f
// (Note, it use reflections at the end. That is
// abcdef is reflected to fedcba|abcdef|fedcba)
torch::Tensor GetStrided(const torch::Tensor &wave,
const FrameExtractionOptions &opts);
torch::Tensor Dither(const torch::Tensor &wave, float dither_value);
torch::Tensor Preemphasize(float preemph_coeff, const torch::Tensor &wave);
/*
ExtractWindow() extracts "frame_length" samples from the given waveform.
Note: This function only extracts "frame_length" samples
from the input waveform, without any further processing.
@param [in] sample_offset If 'wave' is not the entire waveform, but
part of it to the left has been discarded, then the
number of samples prior to 'wave' that we have
already discarded. Set this to zero if you are
processing the entire waveform in one piece, or
if you get 'no matching function' compilation
errors when updating the code.
@param [in] wave The waveform
@param [in] f The frame index to be extracted, with
0 <= f < NumFrames(sample_offset + wave.numel(), opts, true)
@param [in] opts The options class to be used
@return Return a tensor containing "frame_length" samples extracted from
`wave`, without any further processing. Its shape is
(1, frame_length).
*/
torch::Tensor ExtractWindow(int64_t sample_offset, const torch::Tensor &wave,
int32_t f, const FrameExtractionOptions &opts);
} // namespace kaldifeat
#endif // KALDIFEAT_CSRC_FEATURE_WINDOW_H_

View File

@ -0,0 +1,39 @@
#!/usr/bin/env python3
# Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)
import librosa
import numpy as np
def main():
    """Generate whisper-v3-mel-bank.h: a C++ header embedding the 128-bin
    Whisper v3 mel filter bank (sr=16000, n_fft=400) as a flat constexpr
    float array."""
    mel = librosa.filters.mel(sr=16000, n_fft=400, n_mels=128)
    assert mel.shape == (128, 201)

    num_rows, num_cols = mel.shape
    pieces = [
        "// Auto-generated. Do NOT edit!\n\n",
        "// Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)\n\n",
        "\n",
        "#ifndef KALDIFEAT_CSRC_WHISPER_V3_MEL_BANK_H_\n",
        "#define KALDIFEAT_CSRC_WHISPER_V3_MEL_BANK_H_\n",
        "namespace kaldifeat {\n\n",
        f"constexpr int32_t kWhisperV3MelRows = {num_rows};\n",
        f"constexpr int32_t kWhisperV3MelCols = {num_cols};\n",
        "\n",
        "constexpr float kWhisperV3MelArray[] = {\n",
    ]

    sep = ""
    for i, value in enumerate(mel.reshape(-1).tolist()):
        pieces.append(f"{sep}{value:.8f}")
        if i and i % 7 == 0:
            # break the line after every index divisible by 7
            pieces.append(",\n")
            sep = ""
        else:
            sep = ", "

    pieces += [
        "};\n\n",
        "} // namespace kaldifeat\n\n",
        "#endif // KALDIFEAT_CSRC_WHISPER_V3_MEL_BANK_H_\n",
    ]

    with open("whisper-v3-mel-bank.h", "w") as f:
        f.write("".join(pieces))
if __name__ == "__main__":
main()

View File

@ -0,0 +1,39 @@
#!/usr/bin/env python3
# Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)
import librosa
import numpy as np
def main():
    """Generate whisper-mel-bank.h: a C++ header embedding the 80-bin
    Whisper mel filter bank (sr=16000, n_fft=400) as a flat constexpr
    float array."""
    mel = librosa.filters.mel(sr=16000, n_fft=400, n_mels=80)
    assert mel.shape == (80, 201)

    num_rows, num_cols = mel.shape
    pieces = [
        "// Auto-generated. Do NOT edit!\n\n",
        "// Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)\n\n",
        "\n",
        "#ifndef KALDIFEAT_CSRC_WHISPER_MEL_BANK_H_\n",
        "#define KALDIFEAT_CSRC_WHISPER_MEL_BANK_H_\n",
        "namespace kaldifeat {\n\n",
        f"constexpr int32_t kWhisperMelRows = {num_rows};\n",
        f"constexpr int32_t kWhisperMelCols = {num_cols};\n",
        "\n",
        "constexpr float kWhisperMelArray[] = {\n",
    ]

    sep = ""
    for i, value in enumerate(mel.reshape(-1).tolist()):
        pieces.append(f"{sep}{value:.8f}")
        if i and i % 7 == 0:
            # break the line after every index divisible by 7
            pieces.append(",\n")
            sep = ""
        else:
            sep = ", "

    pieces += [
        "};\n\n",
        "} // namespace kaldifeat\n\n",
        "#endif // KALDIFEAT_CSRC_WHISPER_MEL_BANK_H_\n",
    ]

    with open("whisper-mel-bank.h", "w") as f:
        f.write("".join(pieces))
if __name__ == "__main__":
main()

View File

@ -5,6 +5,7 @@
#ifndef KALDIFEAT_CSRC_LOG_H_
#define KALDIFEAT_CSRC_LOG_H_
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <sstream>

View File

@ -0,0 +1,45 @@
// kaldifeat/csrc/matrix-functions.cc
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
// This file is copied/modified from kaldi/src/matrix/matrix-functions.cc
#include "kaldifeat/csrc/matrix-functions.h"
#include <cmath>
#include "kaldifeat/csrc/log.h"
namespace kaldifeat {
// Fill the square float matrix `mat` in place with the normalized type-II
// DCT matrix (see the formulas in matrix-functions.h).
void ComputeDctMatrix(torch::Tensor *mat) {
  KALDIFEAT_ASSERT(mat->dim() == 2);
  int32_t num_rows = mat->size(0);
  int32_t num_cols = mat->size(1);
  KALDIFEAT_ASSERT(num_rows == num_cols);
  KALDIFEAT_ASSERT(num_rows > 0);
  // row stride in elements, for the raw-pointer writes below
  int32_t stride = mat->stride(0);
  // normalizer for X_0
  float normalizer = std::sqrt(1.0f / num_cols);
  // mat[0, :] = normalizer
  mat->index({0, "..."}) = normalizer;
  // normalizer for other elements
  normalizer = std::sqrt(2.0f / num_cols);
  // NOTE(review): the pointer arithmetic below assumes `mat` is a CPU
  // float tensor -- confirm callers always pass one.
  float *data = mat->data_ptr<float>();
  for (int32_t r = 1; r < num_rows; ++r) {
    float *this_row = data + r * stride;
    for (int32_t c = 0; c < num_cols; ++c) {
      // X_k(c) = cos(pi / N * (c + 0.5) * k), computed in double then
      // narrowed to float
      float v = std::cos(static_cast<double>(M_PI) / num_cols * (c + 0.5) * r);
      this_row[c] = normalizer * v;
    }
  }
}
} // namespace kaldifeat

View File

@ -0,0 +1,26 @@
// kaldifeat/csrc/matrix-functions.h
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
// This file is copied/modified from kaldi/src/matrix/matrix-functions.h
#ifndef KALDIFEAT_CSRC_MATRIX_FUNCTIONS_H_
#define KALDIFEAT_CSRC_MATRIX_FUNCTIONS_H_
#include "torch/script.h"
namespace kaldifeat {
/// ComputeDctMatrix computes a matrix corresponding to the DCT, such that
/// M * v equals the DCT of vector v. M must be square at input.
/// This is the type = II DCT with normalization, corresponding to the
/// following equations, where x is the signal and X is the DCT:
/// X_0 = sqrt(1/N) \sum_{n = 0}^{N-1} x_n
/// X_k = sqrt(2/N) \sum_{n = 0}^{N-1} x_n cos( \pi/N (n + 1/2) k )
/// See also
/// https://docs.scipy.org/doc/scipy/reference/generated/scipy.fftpack.dct.html
void ComputeDctMatrix(torch::Tensor *M);
} // namespace kaldifeat
#endif // KALDIFEAT_CSRC_MATRIX_FUNCTIONS_H_

View File

@ -3,13 +3,20 @@
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
//
// This file is copied/modified from kaldi/src/feat/mel-computations.cc
//
#include "kaldifeat/csrc/mel-computations.h"
#include <algorithm>
#include "kaldifeat/csrc/feature-window.h"
namespace kaldifeat {
// Stream a MelBanksOptions via its Python-repr-like ToString().
std::ostream &operator<<(std::ostream &os, const MelBanksOptions &opts) {
  return os << opts.ToString();
}
float MelBanks::VtlnWarpFreq(
float vtln_low_cutoff, // upper+lower frequency cutoffs for VTLN.
float vtln_high_cutoff,
@ -83,7 +90,7 @@ float MelBanks::VtlnWarpMelFreq(
MelBanks::MelBanks(const MelBanksOptions &opts,
const FrameExtractionOptions &frame_opts,
float vtln_warp_factor)
float vtln_warp_factor, torch::Device device)
: htk_mode_(opts.htk_mode) {
int32_t num_bins = opts.num_bins;
if (num_bins < 3) KALDIFEAT_ERR << "Must have at least 3 mel bins";
@ -131,9 +138,14 @@ MelBanks::MelBanks(const MelBanksOptions &opts,
<< " and vtln-high " << vtln_high << ", versus "
<< "low-freq " << low_freq << " and high-freq " << high_freq;
// we will transpose bins_mat_ at the end of this function
bins_mat_ = torch::zeros({num_bins, num_fft_bins}, torch::kFloat);
int32_t stride = bins_mat_.strides()[0];
center_freqs_ = torch::empty({num_bins}, torch::kFloat);
float *center_freqs_data = center_freqs_.data_ptr<float>();
for (int32_t bin = 0; bin < num_bins; ++bin) {
float left_mel = mel_low_freq + bin * mel_freq_delta,
center_mel = mel_low_freq + (bin + 1) * mel_freq_delta,
@ -147,6 +159,7 @@ MelBanks::MelBanks(const MelBanksOptions &opts,
right_mel = VtlnWarpMelFreq(vtln_low, vtln_high, low_freq, high_freq,
vtln_warp_factor, right_mel);
}
center_freqs_data[bin] = InverseMelScale(center_mel);
// this_bin will be a vector of coefficients that is only
// nonzero where this mel bin is active.
float *this_bin = bins_mat_.data_ptr<float>() + bin * stride;
@ -166,15 +179,184 @@ MelBanks::MelBanks(const MelBanksOptions &opts,
last_index = i;
}
}
KALDIFEAT_ASSERT(first_index != -1 && last_index >= first_index &&
"You may have set num_mel_bins too large.");
// Note: It is possible that first_index == last_index == -1 at this line.
// Replicate a bug in HTK, for testing purposes.
if (opts.htk_mode && bin == 0 && mel_low_freq != 0.0f)
if (opts.htk_mode && bin == 0 && mel_low_freq != 0.0f &&
first_index != -1) {
this_bin[first_index] = 0.0f;
}
}
if (debug_) KALDIFEAT_LOG << bins_mat_;
bins_mat_.t_();
if (bins_mat_.device() != device) {
bins_mat_ = bins_mat_.to(device);
}
}
// Builds a MelBanks directly from a precomputed [num_rows, num_cols]
// row-major weight matrix (used for Whisper; center_freqs_ stays empty).
//
// NOTE(review): torch::from_blob does NOT take ownership of `weights`.
// When `device` is the CPU, `.to(device)` is a no-op, so bins_mat_ keeps
// aliasing `weights`; the caller must keep that array alive for the
// lifetime of this object (true for the static kWhisper*MelArray tables).
MelBanks::MelBanks(const float *weights, int32_t num_rows, int32_t num_cols,
                   torch::Device device)
    : debug_(false), htk_mode_(false) {
  // Transpose so Compute() can do spectrum * bins_mat_ directly.
  bins_mat_ = torch::from_blob(const_cast<float *>(weights),
                               {num_rows, num_cols}, torch::kFloat)
                  .t()
                  .to(device);
}
// Projects each power-spectrum frame (row of `spectrum`) onto the mel
// bins. bins_mat_ is stored pre-transposed, so this is a plain matmul.
torch::Tensor MelBanks::Compute(const torch::Tensor &spectrum) const {
  return spectrum.mm(bins_mat_);
}
// Computes liftering coefficients (scaling on cepstral coeffs).
//
// Cepstral coefficient i is scaled by 1 + 0.5*Q*sin(pi*i/Q), which boosts
// higher-order coefficients. Coeffs are numbered slightly differently from
// HTK: the zeroth index is C0, whose coefficient is exactly 1 (unaffected).
//
// @param Q       Liftering constant (e.g. 22 for standard MFCC); assumed
//                nonzero — callers only invoke this when liftering is on.
// @param coeffs  Pre-allocated 1-D float CPU tensor, filled in place.
void ComputeLifterCoeffs(float Q, torch::Tensor *coeffs) {
  float *data = coeffs->data_ptr<float>();
  int32_t n = coeffs->numel();
  for (int32_t i = 0; i < n; ++i) {
    // std::sin (not the transitively-included C `sin`); the expression
    // promotes to double, matching the original arithmetic exactly.
    data[i] = 1.0 + 0.5 * Q * std::sin(M_PI * i / Q);
  }
}
// Fills *ans with the PLP equal-loudness weighting, one value per mel bin,
// as a (1, num_bins) tensor so it broadcasts across frames.
void GetEqualLoudnessVector(const MelBanks &mel_banks, torch::Tensor *ans) {
  int32_t n = mel_banks.NumBins();
  // Central frequency of each mel bin (Hz); this tensor is always on CPU.
  const torch::Tensor &f0 = mel_banks.GetCenterFreqs();
  const float *f0_data = f0.data_ptr<float>();
  *ans = torch::empty({1, n}, torch::kFloat);
  float *ans_data = ans->data_ptr<float>();
  for (int32_t i = 0; i < n; ++i) {
    float fsq = f0_data[i] * f0_data[i];
    float fsub = fsq / (fsq + 1.6e5);
    // Rational approximation of the equal-loudness curve; constants are
    // inherited verbatim from kaldi's mel-computations.cc (presumably the
    // classic PLP formulation — see that file's references).
    ans_data[i] = fsub * fsub * ((fsq + 1.44e6) / (fsq + 9.61e6));
  }
}
// Levinson-Durbin recursion: converts autocorrelation coefficients into
// linear-prediction (LPC) coefficients.
//
// pTmp - scratch space [n]
// pAC  - autocorrelation coefficients [n + 1]
// pLP  - linear prediction coefficients [n]
//        (predicted_sn = sum_1^P{a[i-1] * s[n-i]}})
//        F(z) = 1 / (1 - A(z)), 1 is not stored in the denominator
//
// Returns the final prediction-error energy.
static float Durbin(int n, const float *pAC, float *pLP, float *pTmp) {
  float err = pAC[0];  // prediction error so far

  for (int order = 0; order < n; ++order) {
    // Next reflection coefficient, from the current LP coefficients and
    // the autocorrelation sequence.
    float refl = pAC[order + 1];
    for (int k = 0; k < order; ++k) {
      refl += pLP[k] * pAC[order - k];
    }
    refl = refl / err;

    // Update the error; clamp so a constant signal cannot produce NaNs.
    float scale = 1 - refl * refl;
    if (scale < 1.0e-5) scale = 1.0e-5;
    err *= scale;

    // New LP coefficients for this order, computed into scratch and then
    // copied back.
    pTmp[order] = -refl;
    for (int k = 0; k < order; ++k) {
      pTmp[k] = pLP[k] - refl * pLP[order - k - 1];
    }
    for (int k = 0; k <= order; ++k) {
      pLP[k] = pTmp[k];
    }
  }

  return err;
}
// Computes LP coefficients from autocorrelation coefficients.
//
// @param [in]  autocorr_in  2-D float tensor of shape
//                           (num_frames, lpc_order + 1); each row holds one
//                           frame's autocorrelation coefficients.
// @param [out] lpc_out      On return, a 2-D float tensor of shape
//                           (num_frames, lpc_order) with the LP
//                           coefficients of each frame, on the same device
//                           as the input.
//
// @return A 1-D tensor with num_frames elements: the log residual energy
//         of each frame (forms the C0 value downstream).
torch::Tensor ComputeLpc(const torch::Tensor &autocorr_in,
                         torch::Tensor *lpc_out) {
  KALDIFEAT_ASSERT(autocorr_in.dim() == 2);
  int32_t num_frames = autocorr_in.size(0);
  int32_t lpc_order = autocorr_in.size(1) - 1;

  *lpc_out = torch::empty({num_frames, lpc_order}, torch::kFloat);
  torch::Tensor ans = torch::empty({num_frames}, torch::kFloat);

  // TODO(fangjun): Durbin runs only on CPU. Implement a CUDA version
  torch::Device saved_device = autocorr_in.device();
  torch::Device cpu("cpu");

  torch::Tensor in_cpu = autocorr_in.to(cpu);
  torch::Tensor tmp = torch::empty_like(*lpc_out);  // scratch for Durbin

  int32_t in_stride = in_cpu.stride(0);
  int32_t tmp_stride = tmp.stride(0);
  int32_t lpc_stride = lpc_out->stride(0);

  const float *in_data = in_cpu.data_ptr<float>();
  // ans is freshly allocated 1-D, hence contiguous: index it directly.
  // (The original computed an `ans_stride` it never used.)
  float *ans_data = ans.data_ptr<float>();
  float *tmp_data = tmp.data_ptr<float>();
  float *lpc_data = lpc_out->data_ptr<float>();

  // Frames are independent, so run Durbin on them in parallel. See
  // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Parallel.h#L58
  at::parallel_for(0, num_frames, 1, [&](int32_t begin, int32_t end) -> void {
    for (int32_t i = begin; i != end; ++i) {
      float ret = Durbin(lpc_order, in_data + i * in_stride,
                         lpc_data + i * lpc_stride, tmp_data + i * tmp_stride);

      if (ret <= 0.0) KALDIFEAT_WARN << "Zero energy in LPC computation";

      ans_data[i] = -logf(1.0 / ret);  // forms the C0 value
    }
  });

  *lpc_out = lpc_out->to(saved_device);
  return ans.to(saved_device);
}
// Converts one frame of LPC coefficients to cepstral coefficients.
//
// pLPC   - input LP coefficients [n]
// pCepst - output cepstral coefficients [n]
static void Lpc2CepstrumInternal(int n, const float *pLPC, float *pCepst) {
  for (int32_t k = 0; k < n; ++k) {
    // Each cepstral coefficient depends on all lower-order ones; the
    // accumulator is double to limit rounding error.
    double acc = 0.0;
    for (int32_t m = 0; m < k; ++m) {
      acc += (k - m) * pLPC[m] * pCepst[k - m - 1];
    }
    pCepst[k] = -pLPC[k] - acc / (k + 1);
  }
}
// Converts LPC coefficients to cepstral coefficients, one frame per row.
//
// @param lpc 2-D float tensor; the `lpc_coeffs` output of ComputeLpc().
// @return A 2-D float tensor of the same shape, on the input's device.
torch::Tensor Lpc2Cepstrum(const torch::Tensor &lpc) {
  KALDIFEAT_ASSERT(lpc.dim() == 2);

  // TODO(fangjun): support cuda
  torch::Device saved_device = lpc.device();
  torch::Device cpu("cpu");
  torch::Tensor in_cpu = lpc.to(cpu);

  int32_t num_frames = in_cpu.size(0);
  int32_t lpc_order = in_cpu.size(1);
  const float *in_data = in_cpu.data_ptr<float>();
  int32_t in_stride = in_cpu.stride(0);

  torch::Tensor ans = torch::zeros({num_frames, lpc_order}, torch::kFloat);
  float *ans_data = ans.data_ptr<float>();
  int32_t ans_stride = ans.stride(0);

  // Frames are independent, so convert them in parallel.
  at::parallel_for(0, num_frames, 1, [&](int32_t begin, int32_t end) -> void {
    for (int32_t i = begin; i != end; ++i) {
      Lpc2CepstrumInternal(lpc_order, in_data + i * in_stride,
                           ans_data + i * ans_stride);
    }
  });

  return ans.to(saved_device);
}
} // namespace kaldifeat

View File

@ -5,6 +5,7 @@
// This file is copied/modified from kaldi/src/feat/mel-computations.h
#include <cmath>
#include <string>
#include "kaldifeat/csrc/feature-window.h"
@ -32,8 +33,23 @@ struct MelBanksOptions {
// Enables more exact compatibility with HTK, for testing purposes. Affects
// mel-energy flooring and reproduces a bug in HTK.
bool htk_mode = false;
// Returns a one-line, Python-repr-like summary of all options.
std::string ToString() const {
  std::ostringstream ss;
  ss << "MelBanksOptions("
     << "num_bins=" << num_bins << ", "
     << "low_freq=" << low_freq << ", "
     << "high_freq=" << high_freq << ", "
     << "vtln_low=" << vtln_low << ", "
     << "vtln_high=" << vtln_high << ", "
     << "debug_mel=" << (debug_mel ? "True" : "False") << ", "
     << "htk_mode=" << (htk_mode ? "True" : "False") << ")";
  return ss.str();
}
};
std::ostream &operator<<(std::ostream &os, const MelBanksOptions &opts);
class MelBanks {
public:
static inline float InverseMelScale(float mel_freq) {
@ -57,18 +73,72 @@ class MelBanks {
float vtln_warp_factor, float mel_freq);
MelBanks(const MelBanksOptions &opts,
const FrameExtractionOptions &frame_opts, float vtln_warp_factor);
const FrameExtractionOptions &frame_opts, float vtln_warp_factor,
torch::Device device);
int32_t NumBins() const { return static_cast<int32_t>(bins_mat_.sizes()[0]); }
// Initialize with a 2-d weights matrix
//
// Note: This constructor is for Whisper. It does not initialize
// center_freqs_.
//
// @param weights Pointer to the start address of the matrix
// @param num_rows It equals to number of mel bins
// @param num_cols It equals to (number of fft bins)/2+1
MelBanks(const float *weights, int32_t num_rows, int32_t num_cols,
torch::Device device);
// CAUTION: we save a transposed version of bins_mat_, so return size(1) here
int32_t NumBins() const { return static_cast<int32_t>(bins_mat_.size(1)); }
// returns vector of central freq of each bin; needed by plp code.
const torch::Tensor &GetCenterFreqs() const { return center_freqs_; }
torch::Tensor Compute(const torch::Tensor &spectrum) const;
// for debug only
const torch::Tensor &GetBinsMat() const { return bins_mat_; }
private:
// A 2-D matrix of shape [num_bins, num_fft_bins]
// A 2-D matrix. Its shape is NOT [num_bins, num_fft_bins]
// Its shape is [num_fft_bins, num_bins] for non-whisper.
// For whisper, its shape is [num_fft_bins/2+1, num_bins]
torch::Tensor bins_mat_;
// center frequencies of bins, numbered from 0 ... num_bins-1.
// Needed by GetCenterFreqs().
torch::Tensor center_freqs_; // It's always on CPU
bool debug_;
bool htk_mode_;
};
// Compute liftering coefficients (scaling on cepstral coeffs)
// coeffs are numbered slightly differently from HTK: the zeroth
// index is C0, which is not affected.
//
// coeffs is a 1-D float tensor
void ComputeLifterCoeffs(float Q, torch::Tensor *coeffs);
void GetEqualLoudnessVector(const MelBanks &mel_banks, torch::Tensor *ans);
/* Compute LP coefficients from autocorrelation coefficients.
*
* @param [in] autocorr_in A 2-D tensor. Each row is a frame. Its number of
* columns is lpc_order + 1
* @param [out] lpc_coeffs A 2-D tensor. On return, it has as many rows as the
* input tensor. Its number of columns is lpc_order.
*
* @return Returns log energy of residual in a 1-D tensor. It has as many
* elements as the number of rows in `autocorr_in`.
*/
torch::Tensor ComputeLpc(const torch::Tensor &autocorr_in,
torch::Tensor *lpc_coeffs);
/*
* @param [in] lpc It is the output argument `lpc_coeffs` in ComputeLpc().
*/
torch::Tensor Lpc2Cepstrum(const torch::Tensor &lpc);
} // namespace kaldifeat
#endif // KALDIFEAT_CSRC_MEL_COMPUTATIONS_H_

View File

@ -0,0 +1,89 @@
// kaldifeat/csrc/online-feature-itf.h
//
// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
// This file is copied/modified from kaldi/src/itf/online-feature-itf.h
#ifndef KALDIFEAT_CSRC_ONLINE_FEATURE_ITF_H_
#define KALDIFEAT_CSRC_ONLINE_FEATURE_ITF_H_
#include <utility>
#include <vector>
#include "torch/script.h"
namespace kaldifeat {
/// Abstract interface for online (streaming) feature extraction.
/// Audio is fed incrementally with AcceptWaveform(); computed frames are
/// read back with GetFrame()/GetFrames(); InputFinished() flushes the tail.
class OnlineFeatureInterface {
 public:
  virtual ~OnlineFeatureInterface() = default;

  virtual int32_t Dim() const = 0;  /// returns the feature dimension.

  //
  // Returns frame shift in seconds. Helps to estimate duration from frame
  // counts.
  virtual float FrameShiftInSeconds() const = 0;

  /// Returns the total number of frames, since the start of the utterance, that
  /// are now available.  In an online-decoding context, this will likely
  /// increase with time as more data becomes available.
  virtual int32_t NumFramesReady() const = 0;

  /// Returns true if this is the last frame.  Frame indices are zero-based, so
  /// the first frame is zero.  IsLastFrame(-1) will return false, unless the
  /// file is empty (which is a case that I'm not sure all the code will handle,
  /// so be careful).  This function may return false for some frame if we
  /// haven't yet decided to terminate decoding, but later true if we decide to
  /// terminate decoding.  This function exists mainly to correctly handle end
  /// effects in feature extraction, and is not a mechanism to determine how
  /// many frames are in the decodable object (as it used to be, and for
  /// backward compatibility, still is, in the Decodable interface).
  virtual bool IsLastFrame(int32_t frame) const = 0;

  /// Gets the feature vector for this frame.  Before calling this for a given
  /// frame, it is assumed that you called NumFramesReady() and it returned a
  /// number greater than "frame".  Otherwise this call will likely crash with
  /// an assert failure.  This function is not declared const, in case there is
  /// some kind of caching going on, but most of the time it shouldn't modify
  /// the class.
  ///
  /// The returned tensor has shape (1, Dim()).
  virtual torch::Tensor GetFrame(int32_t frame) = 0;

  /// This is like GetFrame() but for a collection of frames.  There is a
  /// default implementation that just gets the frames one by one, but it
  /// may be overridden for efficiency by child classes (since sometimes
  /// it's more efficient to do things in a batch).
  ///
  /// The returned tensor has shape (frames.size(), Dim()).
  virtual std::vector<torch::Tensor> GetFrames(
      const std::vector<int32_t> &frames) {
    std::vector<torch::Tensor> features;
    features.reserve(frames.size());

    for (auto i : frames) {
      torch::Tensor f = GetFrame(i);
      features.push_back(std::move(f));
    }
    return features;
#if 0
    // NOTE(review): disabled alternative that returned the frames as a
    // single concatenated (frames.size(), Dim()) tensor.
    return torch::cat(features, /*dim*/ 0);
#endif
  }

  /// This would be called from the application, when you get more wave data.
  /// Note: the sampling_rate is typically only provided so the code can assert
  /// that it matches the sampling rate expected in the options.
  virtual void AcceptWaveform(float sampling_rate,
                              const torch::Tensor &waveform) = 0;

  /// InputFinished() tells the class you won't be providing any
  /// more waveform.  This will help flush out the last few frames
  /// of delta or LDA features (it will typically affect the return value
  /// of IsLastFrame.
  virtual void InputFinished() = 0;
};
} // namespace kaldifeat
#endif // KALDIFEAT_CSRC_ONLINE_FEATURE_ITF_H_

View File

@ -0,0 +1,49 @@
// kaldifeat/csrc/online-feature-test.h
//
// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
#include "kaldifeat/csrc/online-feature.h"
#include "gtest/gtest.h"
namespace kaldifeat {
// With items_to_hold = -1 (unlimited), all N pushed frames remain
// retrievable and unchanged.
TEST(RecyclingVector, TestUnlimited) {
  RecyclingVector v(-1);
  constexpr int32_t N = 100;
  for (int32_t i = 0; i != N; ++i) {
    torch::Tensor t = torch::tensor({i, i + 1, i + 2});
    v.PushBack(t);
  }
  ASSERT_EQ(v.Size(), N);

  // Every pushed frame can still be read back by its absolute index.
  for (int32_t i = 0; i != N; ++i) {
    torch::Tensor t = v.At(i);
    torch::Tensor expected = torch::tensor({i, i + 1, i + 2});
    EXPECT_TRUE(t.equal(expected));
  }
}
// With a capacity of K, only the last K of N pushed frames survive.
// Size() still reports the total number of pushes; accessing an evicted
// index aborts the process (checked with ASSERT_DEATH).
TEST(RecyclingVector, Testlimited) {
  constexpr int32_t K = 3;
  constexpr int32_t N = 10;
  RecyclingVector v(K);
  for (int32_t i = 0; i != N; ++i) {
    torch::Tensor t = torch::tensor({i, i + 1, i + 2});
    v.PushBack(t);
  }
  ASSERT_EQ(v.Size(), N);

  // Indices [0, N-K) were evicted and must abort on access.
  for (int32_t i = 0; i < N - K; ++i) {
    ASSERT_DEATH(v.At(i), "");
  }

  // The last K frames are intact.
  for (int32_t i = N - K; i != N; ++i) {
    torch::Tensor t = v.At(i);
    torch::Tensor expected = torch::tensor({i, i + 1, i + 2});
    EXPECT_TRUE(t.equal(expected));
  }
}
} // namespace kaldifeat

View File

@ -0,0 +1,133 @@
// kaldifeat/csrc/online-feature.cc
//
// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
// This file is copied/modified from kaldi/src/feat/online-feature.cc
#include "kaldifeat/csrc/online-feature.h"
#include "kaldifeat/csrc/feature-window.h"
#include "kaldifeat/csrc/log.h"
namespace kaldifeat {
// A requested capacity of 0 is treated the same as -1: hold everything.
RecyclingVector::RecyclingVector(int32_t items_to_hold)
    : items_to_hold_(items_to_hold != 0 ? items_to_hold : -1),
      first_available_index_(0) {}
// Returns the frame at absolute index `index` (indices are counted as if
// no recycling ever happened). Dies via KALDIFEAT_ERR if the frame was
// already evicted; deque::at() throws for indices past the end.
torch::Tensor RecyclingVector::At(int32_t index) const {
  if (index < first_available_index_) {
    KALDIFEAT_ERR << "Attempted to retrieve feature vector that was "
                     "already removed by the RecyclingVector (index = "
                  << index << "; "
                  << "first_available_index = " << first_available_index_
                  << "; "
                  << "size = " << Size() << ")";
  }
  // 'at' does size checking.
  return items_.at(index - first_available_index_);
}
// Appends `item`. When the buffer is at capacity, the oldest frame is
// dropped first and the index of the first retrievable frame advances.
void RecyclingVector::PushBack(torch::Tensor item) {
  // items_to_hold_ == -1 becomes a huge value as size_t, so an unlimited
  // buffer can never match items_.size() and never evicts.
  if (static_cast<size_t>(items_to_hold_) == items_.size()) {
    items_.pop_front();
    ++first_available_index_;
  }
  items_.push_back(item);
}
// Number of frames ever pushed: the evicted prefix plus what is stored.
int32_t RecyclingVector::Size() const {
  int32_t stored = static_cast<int32_t>(items_.size());
  return first_available_index_ + stored;
}
// Constructs the online wrapper around computer C (Mfcc, Plp or Fbank).
// The window function is precomputed once on opts.device; the feature
// history is capped at opts.frame_opts.max_feature_vectors entries.
template <class C>
OnlineGenericBaseFeature<C>::OnlineGenericBaseFeature(
    const typename C::Options &opts)
    : computer_(opts),
      window_function_(opts.frame_opts, opts.device),
      features_(opts.frame_opts.max_feature_vectors),
      input_finished_(false),
      waveform_offset_(0) {}
// Accepts a new chunk of 1-D waveform, appends it to the buffered
// remainder and immediately computes whatever new frames are possible.
// The sampling rate must match the rate configured in the options.
template <class C>
void OnlineGenericBaseFeature<C>::AcceptWaveform(
    float sampling_rate, const torch::Tensor &original_waveform) {
  if (original_waveform.numel() == 0) return;  // Nothing to do.

  KALDIFEAT_ASSERT(original_waveform.dim() == 1);
  KALDIFEAT_ASSERT(sampling_rate == computer_.GetFrameOptions().samp_freq);

  if (input_finished_)
    KALDIFEAT_ERR << "AcceptWaveform called after InputFinished() was called.";

  // Grow the buffered remainder; skip the concat when the buffer is empty.
  waveform_remainder_ =
      (waveform_remainder_.numel() == 0)
          ? original_waveform
          : torch::cat({waveform_remainder_, original_waveform}, /*dim*/ 0);

  ComputeFeatures();
}
// Marks the end of input and computes any frames that were waiting for
// more samples (relevant when snip_edges == false).
template <class C>
void OnlineGenericBaseFeature<C>::InputFinished() {
  input_finished_ = true;
  ComputeFeatures();
}
// Computes all feature frames the buffered samples currently allow,
// appends them to features_, then discards samples that no future frame
// can need (advancing waveform_offset_ by the same amount).
template <class C>
void OnlineGenericBaseFeature<C>::ComputeFeatures() {
  const FrameExtractionOptions &frame_opts = computer_.GetFrameOptions();

  // Total samples seen so far = discarded prefix + still-buffered samples.
  int64_t num_samples_total = waveform_offset_ + waveform_remainder_.numel();
  int32_t num_frames_old = features_.Size();
  // Passing input_finished_ lets NumFrames count trailing frames that
  // would otherwise wait for more samples (see InputFinished()).
  int32_t num_frames_new =
      NumFrames(num_samples_total, frame_opts, input_finished_);
  KALDIFEAT_ASSERT(num_frames_new >= num_frames_old);

  // note: this online feature-extraction code does not support VTLN.
  float vtln_warp = 1.0;

  for (int32_t frame = num_frames_old; frame < num_frames_new; ++frame) {
    torch::Tensor window =
        ExtractWindow(waveform_offset_, waveform_remainder_, frame, frame_opts);

    // TODO(fangjun): We can compute all feature frames at once
    torch::Tensor this_feature =
        computer_.ComputeFeatures(window.unsqueeze(0), vtln_warp);
    features_.PushBack(this_feature);
  }

  // OK, we will now discard any portion of the signal that will not be
  // necessary to compute frames in the future.
  int64_t first_sample_of_next_frame =
      FirstSampleOfFrame(num_frames_new, frame_opts);
  int32_t samples_to_discard = first_sample_of_next_frame - waveform_offset_;
  if (samples_to_discard > 0) {
    // discard the leftmost part of the waveform that we no longer need.
    int32_t new_num_samples = waveform_remainder_.numel() - samples_to_discard;
    if (new_num_samples <= 0) {
      // odd, but we'll try to handle it.
      waveform_offset_ += waveform_remainder_.numel();
      waveform_remainder_.resize_({0});
    } else {
      using torch::indexing::None;
      using torch::indexing::Slice;

      // Keep only waveform_remainder_[samples_to_discard:].
      waveform_remainder_ =
          waveform_remainder_.index({Slice(samples_to_discard, None)});
      waveform_offset_ += samples_to_discard;
    }
  }
}
// instantiate the templates defined here for MFCC, PLP and filterbank classes.
template class OnlineGenericBaseFeature<Mfcc>;
template class OnlineGenericBaseFeature<Plp>;
template class OnlineGenericBaseFeature<Fbank>;
} // namespace kaldifeat

View File

@ -0,0 +1,127 @@
// kaldifeat/csrc/online-feature.h
//
// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
// This file is copied/modified from kaldi/src/feat/online-feature.h
#ifndef KALDIFEAT_CSRC_ONLINE_FEATURE_H_
#define KALDIFEAT_CSRC_ONLINE_FEATURE_H_
#include <deque>
#include "kaldifeat/csrc/feature-fbank.h"
#include "kaldifeat/csrc/feature-mfcc.h"
#include "kaldifeat/csrc/feature-plp.h"
#include "kaldifeat/csrc/feature-window.h"
#include "kaldifeat/csrc/online-feature-itf.h"
namespace kaldifeat {
/// This class serves as a storage for feature vectors with an option to limit
/// the memory usage by removing old elements. The deleted frames indices are
/// "remembered" so that regardless of the MAX_ITEMS setting, the user always
/// provides the indices as if no deletion was being performed.
/// This is useful when processing very long recordings which would otherwise
/// cause the memory to eventually blow up when the features are not being
/// removed.
class RecyclingVector {
 public:
  /// By default it does not remove any elements.
  explicit RecyclingVector(int32_t items_to_hold = -1);

  ~RecyclingVector() = default;
  RecyclingVector(const RecyclingVector &) = delete;
  RecyclingVector &operator=(const RecyclingVector &) = delete;

  /// Returns the frame at absolute `index`; dies if it was already evicted.
  torch::Tensor At(int32_t index) const;

  /// Appends a frame, evicting the oldest one once the capacity is reached.
  void PushBack(torch::Tensor item);

  /// This method returns the size as if no "recycling" had happened,
  /// i.e. equivalent to the number of times the PushBack method has been
  /// called.
  int32_t Size() const;

 private:
  std::deque<torch::Tensor> items_;  // frames still held in memory
  int32_t items_to_hold_;            // capacity; -1 means unlimited
  int32_t first_available_index_;    // absolute index of items_.front()
};
/// This is a templated class for online feature extraction;
/// it's templated on a class like MfccComputer or PlpComputer
/// that does the basic feature extraction.
template <class C>
class OnlineGenericBaseFeature : public OnlineFeatureInterface {
 public:
  // Constructor from options class
  explicit OnlineGenericBaseFeature(const typename C::Options &opts);

  int32_t Dim() const override { return computer_.Dim(); }

  float FrameShiftInSeconds() const override {
    return computer_.GetFrameOptions().frame_shift_ms / 1000.0f;
  }

  int32_t NumFramesReady() const override { return features_.Size(); }

  // Note: IsLastFrame() will only ever return true if you have called
  // InputFinished() (and this frame is the last frame).
  bool IsLastFrame(int32_t frame) const override {
    return input_finished_ && frame == NumFramesReady() - 1;
  }

  torch::Tensor GetFrame(int32_t frame) override { return features_.At(frame); }

  // This would be called from the application, when you get
  // more wave data.  Note: the sampling_rate is only provided so
  // the code can assert that it matches the sampling rate
  // expected in the options.
  void AcceptWaveform(float sampling_rate,
                      const torch::Tensor &waveform) override;

  // InputFinished() tells the class you won't be providing any
  // more waveform.  This will help flush out the last frame or two
  // of features, in the case where snip-edges == false; it also
  // affects the return value of IsLastFrame().
  void InputFinished() override;

 private:
  // This function computes any additional feature frames that it is possible to
  // compute from 'waveform_remainder_', which at this point may contain more
  // than just a remainder-sized quantity (because AcceptWaveform() appends to
  // waveform_remainder_ before calling this function).  It adds these feature
  // frames to features_, and shifts off any now-unneeded samples of input from
  // waveform_remainder_ while incrementing waveform_offset_ by the same amount.
  void ComputeFeatures();

  C computer_;  // class that does the MFCC or PLP or filterbank computation

  // Precomputed windowing function, on the device given in the options.
  FeatureWindowFunction window_function_;

  // features_ is the Mfcc or Plp or Fbank features that we have already
  // computed.
  RecyclingVector features_;

  // True if the user has called "InputFinished()"
  bool input_finished_;

  // waveform_offset_ is the number of samples of waveform that we have
  // already discarded, i.e. that were prior to 'waveform_remainder_'.
  int64_t waveform_offset_;

  // waveform_remainder_ is a short piece of waveform that we may need to keep
  // after extracting all the whole frames we can (whatever length of feature
  // will be required for the next phase of computation).
  // It is a 1-D tensor
  torch::Tensor waveform_remainder_;
};
using OnlineMfcc = OnlineGenericBaseFeature<Mfcc>;
using OnlinePlp = OnlineGenericBaseFeature<Plp>;
using OnlineFbank = OnlineGenericBaseFeature<Fbank>;
} // namespace kaldifeat
#endif // KALDIFEAT_CSRC_ONLINE_FEATURE_H_

View File

@ -0,0 +1,154 @@
// kaldifeat/csrc/pitch-functions.h
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
// This file is copied/modified from kaldi/src/feat/pitch-functions.h
#ifndef KALDIFEAT_CSRC_PITCH_FUNCTIONS_H_
#define KALDIFEAT_CSRC_PITCH_FUNCTIONS_H_
// References
//
// Talkin, David, and W. Bastiaan Kleijn. "A robust algorithm for pitch
// tracking (RAPT)." coding and synthesis 495 (1995): 518.
// (https://www.ee.columbia.edu/~dpwe/papers/Talkin95-rapt.pdf)
//
// Ghahremani, Pegah, et al. "A pitch extraction algorithm tuned for
// automatic speech recognition." 2014 IEEE international conference on
// acoustics, speech and signal processing (ICASSP). IEEE, 2014.
// (http://danielpovey.com/files/2014_icassp_pitch.pdf)
#include <string>
#include "torch/script.h"
namespace kaldifeat {
// Options controlling Kaldi-style pitch extraction (see the references at
// the top of this file). Defaults match kaldi/src/feat/pitch-functions.h.
struct PitchExtractionOptions {
  // sample frequency in hertz
  // must match the waveform file
  float samp_freq = 16000;
  float frame_shift_ms = 10.0;   // in milliseconds.
  float frame_length_ms = 25.0;  // in milliseconds.

  // Preemphasis coefficient. [use is deprecated.]
  float preemph_coeff = 0.0;

  float min_f0 = 50;         // min f0 to search (Hz)
  float max_f0 = 400;        // max f0 to search (Hz)
  float soft_min_f0 = 10.0;  // Minimum f0, applied in soft way, must not
                             // exceed min-f0

  float penalty_factor = 0.1;   // cost factor for FO change
  float lowpass_cutoff = 1000;  // cutoff frequency for Low pass filter (Hz)

  // Frequency that we down-sample the signal to. Must be
  // more than twice lowpass-cutoff
  float resample_freq = 4000;

  float delta_pitch = 0.005;  // the pitch tolerance in pruning lags
  float nccf_ballast = 7000;  // Increasing this factor reduces NCCF for
                              // quiet frames, helping ensure pitch
                              // continuity in unvoiced region

  int32_t lowpass_filter_width = 1;   // Integer that determines filter width of
                                      // lowpass filter
  int32_t upsample_filter_width = 5;  // Integer that determines filter width
                                      // when upsampling NCCF

  // Below are newer config variables, not present in the original paper,
  // that relate to the online pitch extraction algorithm.

  // The maximum number of frames of latency that we allow the pitch-processing
  // to introduce, for online operation. If you set this to a large value,
  // there would be no inaccuracy from the Viterbi traceback (but it might make
  // you wait to see the pitch). This is not very relevant for the online
  // operation: normalization-right-context is more relevant, you
  // can just leave this value at zero.
  int32_t max_frames_latency = 0;

  // Only relevant for the function ComputeKaldiPitch which is called by
  // compute-kaldi-pitch-feats. If nonzero, we provide the input as chunks of
  // this size. This affects the energy normalization which has a small effect
  // on the resulting features, especially at the beginning of a file. For best
  // compatibility with online operation (e.g. if you plan to train models for
  // the online-deocding setup), you might want to set this to a small value,
  // like one frame.
  int32_t frames_per_chunk = 0;

  // Only relevant for the function ComputeKaldiPitch which is called by
  // compute-kaldi-pitch-feats, and only relevant if frames_per_chunk is
  // nonzero. If true, it will query the features as soon as they are
  // available, which simulates the first-pass features you would get in online
  // decoding. If false, the features you will get will be the same as those
  // available at the end of the utterance, after InputFinished() has been
  // called: e.g. during lattice rescoring.
  bool simulate_first_pass_online = false;

  // Only relevant for online operation or when emulating online operation
  // (e.g. when setting frames_per_chunk). This is the frame-index on which we
  // recompute the NCCF (e.g. frame-index 500 = after 5 seconds); if the
  // segment ends before this we do it when the segment ends. We do this by
  // re-computing the signal average energy, which affects the NCCF via the
  // "ballast term", scaling the resampled NCCF by a factor derived from the
  // average change in the "ballast term", and re-doing the backtrace
  // computation. Making this infinity would be the most exact, but would
  // introduce unwanted latency at the end of long utterances, for little
  // benefit.
  int32_t recompute_frame = 500;

  // This is a "hidden config" used only for testing the online pitch
  // extraction. If true, we compute the signal root-mean-squared for the
  // ballast term, only up to the current frame, rather than the end of the
  // current chunk of signal. This makes the output insensitive to the
  // chunking, which is useful for testing purposes.
  bool nccf_ballast_online = false;
  bool snip_edges = true;

  torch::Device device{"cpu"};

  PitchExtractionOptions() = default;

  /// Returns the window-size in samples, after resampling. This is the
  /// "basic window size", not the full window size after extending by max-lag.
  // Because of floating point representation, it is more reliable to divide
  // by 1000 instead of multiplying by 0.001, but it is a bit slower.
  int32_t NccfWindowSize() const {
    return static_cast<int32_t>(resample_freq * frame_length_ms / 1000.0);
  }

  /// Returns the window-shift in samples, after resampling.
  int32_t NccfWindowShift() const {
    return static_cast<int32_t>(resample_freq * frame_shift_ms / 1000.0);
  }

  // Returns a multi-line, human-readable dump of all options.
  std::string ToString() const {
    std::ostringstream os;
    os << "samp_freq: " << samp_freq << "\n";
    os << "frame_shift_ms: " << frame_shift_ms << "\n";
    os << "frame_length_ms: " << frame_length_ms << "\n";
    os << "preemph_coeff: " << preemph_coeff << "\n";
    os << "min_f0: " << min_f0 << "\n";
    os << "max_f0: " << max_f0 << "\n";
    os << "soft_min_f0: " << soft_min_f0 << "\n";
    os << "penalty_factor: " << penalty_factor << "\n";
    os << "lowpass_cutoff: " << lowpass_cutoff << "\n";
    os << "resample_freq: " << resample_freq << "\n";
    os << "delta_pitch: " << delta_pitch << "\n";
    os << "nccf_ballast: " << nccf_ballast << "\n";
    os << "lowpass_filter_width: " << lowpass_filter_width << "\n";
    os << "upsample_filter_width: " << upsample_filter_width << "\n";
    os << "max_frames_latency: " << max_frames_latency << "\n";
    os << "frames_per_chunk: " << frames_per_chunk << "\n";
    os << "simulate_first_pass_online: " << simulate_first_pass_online << "\n";
    os << "recompute_frame: " << recompute_frame << "\n";
    os << "nccf_ballast_online: " << nccf_ballast_online << "\n";
    os << "snip_edges: " << snip_edges << "\n";
    os << "device: " << device << "\n";
    // BUG FIX: the original streamed into `os` but never returned, which is
    // undefined behavior for a function declared to return std::string.
    return os.str();
  }
};
// TODO(fangjun): Implement it
} // namespace kaldifeat
#endif // KALDIFEAT_CSRC_PITCH_FUNCTIONS_H_

View File

@ -0,0 +1,82 @@
// kaldifeat/csrc/test_kaldifeat.cc
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#include "torch/all.h"
#include "torch/script.h"
// Sanity-checks the tensor-slicing formulation of preemphasis:
// a[:, 1:] = a[:, 1:] - 0.97 * a[:, :-1], and a[:, 0] *= 0.97.
// Prints intermediate tensors for manual inspection.
static void TestPreemph() {
  torch::Tensor a = torch::arange(0, 12).reshape({3, 4}).to(torch::kFloat);

  // b is a view of a[:, 1:]; c is a view of a[:, :-1].
  torch::Tensor b =
      a.index({"...", torch::indexing::Slice(1, torch::indexing::None,
                                             torch::indexing::None)});

  torch::Tensor c = a.index({"...", torch::indexing::Slice(0, -1, 1)});

  a.index({"...", torch::indexing::Slice(1, torch::indexing::None,
                                         torch::indexing::None)}) =
      b - 0.97 * c;

  a.index({"...", 0}) *= 0.97;

  std::cout << a << "\n";
  std::cout << b << "\n";
  std::cout << "c: \n" << c << "\n";
  torch::Tensor d = b - 0.97 * c;
  std::cout << d << "\n";
}
// Demonstrates right-padding the last dimension with 3 zeros
// (constant-mode torch::nn::functional::pad).
static void TestPad() {
  torch::Tensor a = torch::arange(0, 6).reshape({2, 3}).to(torch::kFloat);
  torch::Tensor b = torch::nn::functional::pad(
      a, torch::nn::functional::PadFuncOptions({0, 3})
             .mode(torch::kConstant)
             .value(0));
  std::cout << a << "\n";
  std::cout << b << "\n";
}
// Demonstrates extracting overlapping frames with as_strided:
// from [0 1 2 3 4 5], shape {2,4} with strides {2,1} yields
//   0 1 2 3
//   2 3 4 5
// and then subtracts the per-row mean. Note b is a VIEW of a, so
// writing to b would alias a (b = b - mean makes a new tensor though).
static void TestGetStrided() {
  // 0 1 2 3 4 5
  //
  //
  // 0 1 2 3
  // 2 3 4 5
  torch::Tensor a = torch::arange(0, 6).to(torch::kFloat);
  torch::Tensor b = a.as_strided({2, 4}, {2, 1});
  // b = b.clone();
  std::cout << a << "\n";
  std::cout << b << "\n";
  std::cout << b.mean(1).unsqueeze(1) << "\n";
  b = b - b.mean(1).unsqueeze(1);
  std::cout << a << "\n";
  std::cout << b << "\n";
}
// Demonstrates the elementwise a + b * scale pattern used for dithering.
static void TestDither() {
  torch::Tensor a = torch::arange(0, 6).reshape({2, 3}).to(torch::kFloat);
  torch::Tensor b = torch::arange(0, 6).reshape({2, 3}).to(torch::kFloat) * 0.1;
  std::cout << a << "\n";
  std::cout << b << "\n";
  std::cout << (a + b * 2) << "\n";
}
// Demonstrates torch::cat along dim 1 (columns) and dim 0 (rows).
static void TestCat() {
  torch::Tensor a = torch::arange(0, 6).reshape({2, 3}).to(torch::kFloat);
  torch::Tensor b = torch::arange(0, 2).reshape({2, 1}).to(torch::kFloat) * 0.1;
  torch::Tensor c = torch::cat({a, b}, 1);  // (2, 4): a's columns then b's
  torch::Tensor d = torch::cat({b, a}, 1);  // (2, 4): b's column first
  torch::Tensor e = torch::cat({a, a}, 0);  // (4, 3): a stacked on itself
  std::cout << a << "\n";
  std::cout << b << "\n";
  std::cout << c << "\n";
  std::cout << d << "\n";
  std::cout << e << "\n";
}
// Manual test driver: currently only exercises TestCat(); the other
// Test* helpers above are invoked ad hoc during development.
int main() {
  TestCat();
  return 0;
}

View File

@ -0,0 +1,88 @@
/**
* Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kaldifeat/csrc/whisper-fbank.h"
#include <cmath>
#include <vector>
#include "kaldifeat/csrc/mel-computations.h"
#include "kaldifeat/csrc/whisper-mel-bank.h"
#include "kaldifeat/csrc/whisper-v3-mel-bank.h"
#ifndef M_2PI
#define M_2PI 6.283185307179586476925286766559005
#endif
namespace kaldifeat {
// Selects the precompiled Whisper mel filterbank matching opts.num_mels
// (80, or 128 for large-v3 per whisper-v3-mel-bank.h) and aborts via
// KALDIFEAT_ERR for any other value.
//
// It then forces all frame-extraction fields to Whisper's fixed front-end
// (16 kHz, 25 ms Hann windows with a 10 ms shift, no dither, no
// pre-emphasis, no DC-offset removal, no rounding to a power of two,
// snip_edges = false), overriding whatever the caller put in
// opts.frame_opts.
WhisperFbankComputer::WhisperFbankComputer(const WhisperFbankOptions &opts)
    : opts_(opts) {
  if (opts.num_mels == 80) {
    mel_banks_ = std::make_unique<MelBanks>(kWhisperMelArray, kWhisperMelRows,
                                            kWhisperMelCols, opts.device);
  } else if (opts.num_mels == 128) {
    mel_banks_ = std::make_unique<MelBanks>(
        kWhisperV3MelArray, kWhisperV3MelRows, kWhisperV3MelCols, opts.device);
  } else {
    KALDIFEAT_ERR << "Unsupported num_mels: " << opts.num_mels
                  << ". Support only 80 and 128";
  }
  // Whisper's fixed front-end configuration; caller-supplied frame_opts
  // values are intentionally overwritten.
  opts_.frame_opts.samp_freq = 16000;
  opts_.frame_opts.frame_shift_ms = 10;
  opts_.frame_opts.frame_length_ms = 25;
  opts_.frame_opts.dither = 0;
  opts_.frame_opts.preemph_coeff = 0;
  opts_.frame_opts.remove_dc_offset = false;
  opts_.frame_opts.window_type = "hann";
  opts_.frame_opts.round_to_power_of_two = false;
  opts_.frame_opts.snip_edges = false;
}
// Computes Whisper log-mel features for a batch of windowed frames.
//
// signal_frame: 2-D tensor of shape [num_frames, PaddedWindowSize()].
// Returns: normalized log-mel features of shape [num_frames, num_mels].
//
// signal_raw_log_energy and vtln_warp are unused; see NeedRawLogEnergy().
torch::Tensor WhisperFbankComputer::Compute(
    torch::Tensor /*signal_raw_log_energy*/, float /*vtln_warp*/,
    const torch::Tensor &signal_frame) {
  KALDIFEAT_ASSERT(signal_frame.dim() == 2);
  KALDIFEAT_ASSERT(signal_frame.size(1) == opts_.frame_opts.PaddedWindowSize());
  // Power spectrum |rfft(frame)|^2 (note: squared, so this is power,
  // not magnitude, despite the abs() call below).
#if defined(KALDIFEAT_HAS_FFT_NAMESPACE)
  // signal_frame shape: [x, 512]
  // power shape [x, 257]
  torch::Tensor power = torch::fft::rfft(signal_frame).abs().pow(2);
#else
  // signal_frame shape [x, 512]
  // real_imag shape [x, 257, 2],
  // where [..., 0] is the real part
  //       [..., 1] is the imaginary part
  torch::Tensor real_imag = torch::rfft(signal_frame, 1);
  torch::Tensor real = real_imag.index({"...", 0});
  torch::Tensor imag = real_imag.index({"...", 1});
  torch::Tensor power = (real.square() + imag.square());
#endif
  torch::Tensor mel_energies = mel_banks_->Compute(power);
  // Whisper-style normalization: log10 with a 1e-10 floor, clamp to
  // within 8 of the global max, then apply (log_spec + 4) / 4.
  torch::Tensor log_spec = torch::clamp_min(mel_energies, 1e-10).log10();
  log_spec = torch::maximum(log_spec, log_spec.max() - 8.0);
  torch::Tensor mel = (log_spec + 4.0) / 4.0;
  return mel;
}
} // namespace kaldifeat

View File

@ -0,0 +1,78 @@
/**
* Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef KALDIFEAT_CSRC_WHISPER_FBANK_H_
#define KALDIFEAT_CSRC_WHISPER_FBANK_H_
#include <memory>
#include <string>
#include <vector>
#include "kaldifeat/csrc/feature-common.h"
#include "kaldifeat/csrc/feature-window.h"
#include "kaldifeat/csrc/mel-computations.h"
namespace kaldifeat {
// Configuration for WhisperFbankComputer.
struct WhisperFbankOptions {
  // Frame-extraction settings. NOTE: the WhisperFbankComputer constructor
  // overwrites these with Whisper's fixed front-end (16 kHz, 25 ms / 10 ms
  // Hann frames, snip_edges = false); see whisper-fbank.cc.
  FrameExtractionOptions frame_opts;
  // Number of mel bins: 80 (default); for large v3, please use 128.
  int32_t num_mels = 80;
  // Device on which the features are computed, e.g. "cpu" or "cuda:0".
  torch::Device device{"cpu"};
  // Returns a human-readable representation of these options.
  std::string ToString() const {
    std::ostringstream os;
    os << "WhisperFbankOptions(";
    os << "frame_opts=" << frame_opts.ToString() << ", ";
    os << "num_mels=" << num_mels << ", ";
    os << "device=\"" << device << "\")";
    return os.str();
  }
};
// Computes Whisper-compatible log-mel filterbank features; used with
// OfflineFeatureTpl (see the WhisperFbank alias below).
class WhisperFbankComputer {
 public:
  // Note: only opts.num_mels and opts.device take effect. Every field of
  // opts.frame_opts is overwritten with Whisper's fixed configuration by
  // the constructor (see whisper-fbank.cc).
  explicit WhisperFbankComputer(const WhisperFbankOptions &opts = {});
  // Output feature dimension (number of mel bins).
  int32_t Dim() const { return opts_.num_mels; }
  const FrameExtractionOptions &GetFrameOptions() const {
    return opts_.frame_opts;
  }
  const WhisperFbankOptions &GetOptions() const { return opts_; }
  torch::Tensor Compute(torch::Tensor /*signal_raw_log_energy*/,
                        float /*vtln_warp*/, const torch::Tensor &signal_frame);
  // Whisper features never use the raw log energy, so always false.
  bool NeedRawLogEnergy() const { return false; }
  using Options = WhisperFbankOptions;
 private:
  WhisperFbankOptions opts_;
  std::unique_ptr<MelBanks> mel_banks_;
};
using WhisperFbank = OfflineFeatureTpl<WhisperFbankComputer>;
} // namespace kaldifeat
#endif // KALDIFEAT_CSRC_WHISPER_FBANK_H_

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,5 @@
# Build the pybind11 extension sources; tests are only added when
# kaldifeat_BUILD_TESTS is enabled.
add_subdirectory(csrc)
if(kaldifeat_BUILD_TESTS)
add_subdirectory(tests)
endif()

View File

@ -0,0 +1,40 @@
# Build rules for the `_kaldifeat` Python extension module.
add_definitions(-DTORCH_API_INCLUDE_EXTENSION_H)
pybind11_add_module(_kaldifeat
  feature-fbank.cc
  feature-mfcc.cc
  feature-plp.cc
  feature-spectrogram.cc
  feature-window.cc
  kaldifeat.cc
  mel-computations.cc
  online-feature.cc
  utils.cc
  whisper-fbank.cc
)
if(APPLE)
  # Locate site-packages so the extension's rpath can find libtorch at
  # runtime. NOTE: distutils was removed in Python 3.12 (PEP 632), so use
  # sysconfig instead of distutils.sysconfig.get_python_lib().
  execute_process(
    COMMAND "${PYTHON_EXECUTABLE}" -c "import sysconfig; print(sysconfig.get_paths()['purelib'])"
    OUTPUT_STRIP_TRAILING_WHITESPACE
    OUTPUT_VARIABLE PYTHON_SITE_PACKAGE_DIR
  )
  message(STATUS "PYTHON_SITE_PACKAGE_DIR: ${PYTHON_SITE_PACKAGE_DIR}")
  target_link_libraries(_kaldifeat PRIVATE "-Wl,-rpath,${PYTHON_SITE_PACKAGE_DIR}")
endif()
if(NOT WIN32)
  # Resolve kaldifeat_core relative to the installed package location.
  target_link_libraries(_kaldifeat PRIVATE "-Wl,-rpath,${kaldifeat_rpath_origin}/kaldifeat/${CMAKE_INSTALL_LIBDIR}")
endif()
target_link_libraries(_kaldifeat PRIVATE kaldifeat_core)
if(UNIX AND NOT APPLE)
  target_link_libraries(_kaldifeat PUBLIC ${TORCH_DIR}/lib/libtorch_python.so)
  # target_link_libraries(_kaldifeat PUBLIC ${PYTHON_LIBRARY})
elseif(WIN32)
  target_link_libraries(_kaldifeat PUBLIC ${TORCH_DIR}/lib/torch_python.lib)
  # target_link_libraries(_kaldifeat PUBLIC ${PYTHON_LIBRARIES})
endif()
install(TARGETS _kaldifeat
  DESTINATION ../
)

View File

@ -0,0 +1 @@
filter=-runtime/references

View File

@ -0,0 +1,100 @@
// kaldifeat/python/csrc/feature-fbank.cc
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#include "kaldifeat/python/csrc/feature-fbank.h"
#include <memory>
#include <string>
#include "kaldifeat/csrc/feature-fbank.h"
#include "kaldifeat/python/csrc/utils.h"
namespace kaldifeat {
// Binds kaldifeat::FbankOptions as Python class `_kaldifeat.FbankOptions`.
//
// Exposes: a keyword constructor mirroring the C++ fields (device is given
// as a string such as "cpu" or "cuda:0"), read/write attributes for every
// field, a `device` property converting to/from torch.device, __str__,
// as_dict()/from_dict(), and pickling -- both implemented via the dict
// round-trip helpers in utils.h.
static void PybindFbankOptions(py::module &m) {
  using PyClass = FbankOptions;
  py::class_<PyClass>(m, "FbankOptions")
      .def(py::init<>())
      .def(py::init([](const MelBanksOptions &mel_opts,
                       const FrameExtractionOptions &frame_opts =
                           FrameExtractionOptions(),
                       bool use_energy = false, float energy_floor = 0.0f,
                       bool raw_energy = true, bool htk_compat = false,
                       bool use_log_fbank = true, bool use_power = true,
                       py::object device =
                           py::str("cpu")) -> std::unique_ptr<FbankOptions> {
             auto opts = std::make_unique<FbankOptions>();
             opts->frame_opts = frame_opts;
             opts->mel_opts = mel_opts;
             opts->use_energy = use_energy;
             opts->energy_floor = energy_floor;
             opts->raw_energy = raw_energy;
             opts->htk_compat = htk_compat;
             opts->use_log_fbank = use_log_fbank;
             opts->use_power = use_power;
             std::string s = static_cast<py::str>(device);
             opts->device = torch::Device(s);
             return opts;
           }),
           py::arg("mel_opts"),
           py::arg("frame_opts") = FrameExtractionOptions(),
           py::arg("use_energy") = false, py::arg("energy_floor") = 0.0f,
           py::arg("raw_energy") = true, py::arg("htk_compat") = false,
           py::arg("use_log_fbank") = true, py::arg("use_power") = true,
           py::arg("device") = py::str("cpu"))
      .def_readwrite("frame_opts", &PyClass::frame_opts)
      .def_readwrite("mel_opts", &PyClass::mel_opts)
      .def_readwrite("use_energy", &PyClass::use_energy)
      .def_readwrite("energy_floor", &PyClass::energy_floor)
      .def_readwrite("raw_energy", &PyClass::raw_energy)
      .def_readwrite("htk_compat", &PyClass::htk_compat)
      .def_readwrite("use_log_fbank", &PyClass::use_log_fbank)
      .def_readwrite("use_power", &PyClass::use_power)
      // Expose device as a real torch.device object on the Python side.
      .def_property(
          "device",
          [](const PyClass &self) -> py::object {
            py::object ans = py::module_::import("torch").attr("device");
            return ans(self.device.str());
          },
          [](PyClass &self, py::object obj) -> void {
            std::string s = static_cast<py::str>(obj);
            self.device = torch::Device(s);
          })
      .def("__str__",
           [](const PyClass &self) -> std::string { return self.ToString(); })
      .def("as_dict",
           [](const PyClass &self) -> py::dict { return AsDict(self); })
      .def_static(
          "from_dict",
          [](py::dict dict) -> PyClass { return FbankOptionsFromDict(dict); })
      // Pickle via the dict round-trip so options survive torch.save etc.
      .def(py::pickle(
          [](const PyClass &self) -> py::dict { return AsDict(self); },
          [](py::dict dict) -> PyClass { return FbankOptionsFromDict(dict); }));
}
// Binds kaldifeat::Fbank as Python class `_kaldifeat.Fbank`.
// compute_features releases the GIL so other Python threads can run while
// features are extracted; pickling round-trips through the options dict.
static void PybindFbank(py::module &m) {
  using PyClass = Fbank;
  py::class_<PyClass>(m, "Fbank")
      .def(py::init<const FbankOptions &>(), py::arg("opts"))
      .def("dim", &PyClass::Dim)
      .def_property_readonly("options", &PyClass::GetOptions)
      .def("compute_features", &PyClass::ComputeFeatures, py::arg("wave"),
           py::arg("vtln_warp"), py::call_guard<py::gil_scoped_release>())
      .def(py::pickle(
          [](const PyClass &self) -> py::dict {
            return AsDict(self.GetOptions());
          },
          [](py::dict dict) -> std::unique_ptr<PyClass> {
            return std::make_unique<PyClass>(FbankOptionsFromDict(dict));
          }));
}
// Registers FbankOptions and Fbank with the Python extension module `m`.
void PybindFeatureFbank(py::module &m) {
  PybindFbankOptions(m);
  PybindFbank(m);
}
} // namespace kaldifeat

View File

@ -0,0 +1,16 @@
// kaldifeat/python/csrc/feature-fbank.h
//
// Copyright (c)  2021  Xiaomi Corporation (authors: Fangjun Kuang)
#ifndef KALDIFEAT_PYTHON_CSRC_FEATURE_FBANK_H_
#define KALDIFEAT_PYTHON_CSRC_FEATURE_FBANK_H_
#include "kaldifeat/python/csrc/kaldifeat.h"
namespace kaldifeat {
// Registers FbankOptions and Fbank with the Python extension module `m`.
void PybindFeatureFbank(py::module &m);
}  // namespace kaldifeat
#endif  // KALDIFEAT_PYTHON_CSRC_FEATURE_FBANK_H_

View File

@ -0,0 +1,100 @@
// kaldifeat/python/csrc/feature-mfcc.cc
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#include "kaldifeat/python/csrc/feature-mfcc.h"
#include <memory>
#include <string>
#include "kaldifeat/csrc/feature-mfcc.h"
#include "kaldifeat/python/csrc/utils.h"
namespace kaldifeat {
// Binds kaldifeat::MfccOptions as Python class `_kaldifeat.MfccOptions`.
//
// Exposes: a keyword constructor mirroring the C++ fields (device is given
// as a string such as "cpu" or "cuda:0"), read/write attributes for every
// field, a `device` property converting to/from torch.device, __str__,
// as_dict()/from_dict(), and pickling via the dict round-trip in utils.h.
//
// Marked static for internal linkage, consistent with PybindFbankOptions
// and PybindSpectrogramOptions; only PybindFeatureMfcc is declared in
// feature-mfcc.h.
static void PybindMfccOptions(py::module &m) {
  using PyClass = MfccOptions;
  py::class_<PyClass>(m, "MfccOptions")
      .def(py::init<>())
      .def(py::init([](const MelBanksOptions &mel_opts,
                       const FrameExtractionOptions &frame_opts =
                           FrameExtractionOptions(),
                       int32_t num_ceps = 13, bool use_energy = true,
                       float energy_floor = 0.0, bool raw_energy = true,
                       float cepstral_lifter = 22.0, bool htk_compat = false,
                       py::object device =
                           py::str("cpu")) -> std::unique_ptr<MfccOptions> {
             auto opts = std::make_unique<MfccOptions>();
             opts->frame_opts = frame_opts;
             opts->mel_opts = mel_opts;
             opts->num_ceps = num_ceps;
             opts->use_energy = use_energy;
             opts->energy_floor = energy_floor;
             opts->raw_energy = raw_energy;
             opts->cepstral_lifter = cepstral_lifter;
             opts->htk_compat = htk_compat;
             std::string s = static_cast<py::str>(device);
             opts->device = torch::Device(s);
             return opts;
           }),
           py::arg("mel_opts"),
           py::arg("frame_opts") = FrameExtractionOptions(),
           py::arg("num_ceps") = 13, py::arg("use_energy") = true,
           py::arg("energy_floor") = 0.0f, py::arg("raw_energy") = true,
           py::arg("cepstral_lifter") = 22.0, py::arg("htk_compat") = false,
           py::arg("device") = py::str("cpu"))
      .def_readwrite("frame_opts", &PyClass::frame_opts)
      .def_readwrite("mel_opts", &PyClass::mel_opts)
      .def_readwrite("num_ceps", &PyClass::num_ceps)
      .def_readwrite("use_energy", &PyClass::use_energy)
      .def_readwrite("energy_floor", &PyClass::energy_floor)
      .def_readwrite("raw_energy", &PyClass::raw_energy)
      .def_readwrite("cepstral_lifter", &PyClass::cepstral_lifter)
      .def_readwrite("htk_compat", &PyClass::htk_compat)
      // Expose device as a real torch.device object on the Python side.
      .def_property(
          "device",
          [](const PyClass &self) -> py::object {
            py::object ans = py::module_::import("torch").attr("device");
            return ans(self.device.str());
          },
          [](PyClass &self, py::object obj) -> void {
            std::string s = static_cast<py::str>(obj);
            self.device = torch::Device(s);
          })
      .def("__str__",
           [](const PyClass &self) -> std::string { return self.ToString(); })
      .def("as_dict",
           [](const PyClass &self) -> py::dict { return AsDict(self); })
      .def_static(
          "from_dict",
          [](py::dict dict) -> PyClass { return MfccOptionsFromDict(dict); })
      .def(py::pickle(
          [](const PyClass &self) -> py::dict { return AsDict(self); },
          [](py::dict dict) -> PyClass { return MfccOptionsFromDict(dict); }));
}
// Binds kaldifeat::Mfcc as Python class `_kaldifeat.Mfcc`.
// compute_features releases the GIL during extraction; pickling
// round-trips through the options dict.
static void PybindMfcc(py::module &m) {
  using PyClass = Mfcc;
  py::class_<PyClass>(m, "Mfcc")
      .def(py::init<const MfccOptions &>(), py::arg("opts"))
      .def("dim", &PyClass::Dim)
      .def_property_readonly("options", &PyClass::GetOptions)
      .def("compute_features", &PyClass::ComputeFeatures, py::arg("wave"),
           py::arg("vtln_warp"), py::call_guard<py::gil_scoped_release>())
      .def(py::pickle(
          [](const PyClass &self) -> py::dict {
            return AsDict(self.GetOptions());
          },
          [](py::dict dict) -> std::unique_ptr<PyClass> {
            return std::make_unique<PyClass>(MfccOptionsFromDict(dict));
          }));
}
// Registers MfccOptions and Mfcc with the Python extension module `m`.
void PybindFeatureMfcc(py::module &m) {
  PybindMfccOptions(m);
  PybindMfcc(m);
}
} // namespace kaldifeat

View File

@ -0,0 +1,16 @@
// kaldifeat/python/csrc/feature-mfcc.h
//
// Copyright (c)  2021  Xiaomi Corporation (authors: Fangjun Kuang)
#ifndef KALDIFEAT_PYTHON_CSRC_FEATURE_MFCC_H_
#define KALDIFEAT_PYTHON_CSRC_FEATURE_MFCC_H_
#include "kaldifeat/python/csrc/kaldifeat.h"
namespace kaldifeat {
// Registers MfccOptions and Mfcc with the Python extension module `m`.
void PybindFeatureMfcc(py::module &m);
}  // namespace kaldifeat
#endif  // KALDIFEAT_PYTHON_CSRC_FEATURE_MFCC_H_

View File

@ -0,0 +1,109 @@
// kaldifeat/python/csrc/feature-plp.cc
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#include "kaldifeat/python/csrc/feature-plp.h"
#include <memory>
#include <string>
#include "kaldifeat/csrc/feature-plp.h"
#include "kaldifeat/python/csrc/utils.h"
namespace kaldifeat {
// Binds kaldifeat::PlpOptions as Python class `_kaldifeat.PlpOptions`.
//
// Exposes: a keyword constructor mirroring the C++ fields (device is given
// as a string such as "cpu" or "cuda:0"), read/write attributes for every
// field, a `device` property converting to/from torch.device, __str__,
// as_dict()/from_dict(), and pickling via the dict round-trip in utils.h.
//
// Marked static for internal linkage, consistent with PybindFbankOptions
// and PybindSpectrogramOptions; only PybindFeaturePlp is declared in
// feature-plp.h.
static void PybindPlpOptions(py::module &m) {
  using PyClass = PlpOptions;
  py::class_<PyClass>(m, "PlpOptions")
      .def(py::init<>())
      .def(py::init([](const MelBanksOptions &mel_opts,
                       const FrameExtractionOptions &frame_opts =
                           FrameExtractionOptions(),
                       int32_t lpc_order = 12, int32_t num_ceps = 13,
                       bool use_energy = true, float energy_floor = 0.0,
                       bool raw_energy = true, float compress_factor = 0.33333,
                       int32_t cepstral_lifter = 22, float cepstral_scale = 1.0,
                       bool htk_compat = false,
                       py::object device =
                           py::str("cpu")) -> std::unique_ptr<PlpOptions> {
             auto opts = std::make_unique<PlpOptions>();
             opts->frame_opts = frame_opts;
             opts->mel_opts = mel_opts;
             opts->lpc_order = lpc_order;
             opts->num_ceps = num_ceps;
             opts->use_energy = use_energy;
             opts->energy_floor = energy_floor;
             opts->raw_energy = raw_energy;
             opts->compress_factor = compress_factor;
             opts->cepstral_lifter = cepstral_lifter;
             opts->cepstral_scale = cepstral_scale;
             opts->htk_compat = htk_compat;
             std::string s = static_cast<py::str>(device);
             opts->device = torch::Device(s);
             return opts;
           }),
           py::arg("mel_opts"),
           py::arg("frame_opts") = FrameExtractionOptions(),
           py::arg("lpc_order") = 12, py::arg("num_ceps") = 13,
           py::arg("use_energy") = true, py::arg("energy_floor") = 0.0,
           py::arg("raw_energy") = true, py::arg("compress_factor") = 0.33333,
           py::arg("cepstral_lifter") = 22, py::arg("cepstral_scale") = 1.0,
           py::arg("htk_compat") = false, py::arg("device") = py::str("cpu"))
      .def_readwrite("frame_opts", &PyClass::frame_opts)
      .def_readwrite("mel_opts", &PyClass::mel_opts)
      .def_readwrite("lpc_order", &PyClass::lpc_order)
      .def_readwrite("num_ceps", &PyClass::num_ceps)
      .def_readwrite("use_energy", &PyClass::use_energy)
      .def_readwrite("energy_floor", &PyClass::energy_floor)
      .def_readwrite("raw_energy", &PyClass::raw_energy)
      .def_readwrite("compress_factor", &PyClass::compress_factor)
      .def_readwrite("cepstral_lifter", &PyClass::cepstral_lifter)
      .def_readwrite("cepstral_scale", &PyClass::cepstral_scale)
      .def_readwrite("htk_compat", &PyClass::htk_compat)
      // Expose device as a real torch.device object on the Python side.
      .def_property(
          "device",
          [](const PyClass &self) -> py::object {
            py::object ans = py::module_::import("torch").attr("device");
            return ans(self.device.str());
          },
          [](PyClass &self, py::object obj) -> void {
            std::string s = static_cast<py::str>(obj);
            self.device = torch::Device(s);
          })
      .def("__str__",
           [](const PyClass &self) -> std::string { return self.ToString(); })
      .def("as_dict",
           [](const PyClass &self) -> py::dict { return AsDict(self); })
      .def_static(
          "from_dict",
          [](py::dict dict) -> PyClass { return PlpOptionsFromDict(dict); })
      .def(py::pickle(
          [](const PyClass &self) -> py::dict { return AsDict(self); },
          [](py::dict dict) -> PyClass { return PlpOptionsFromDict(dict); }));
}
// Binds kaldifeat::Plp as Python class `_kaldifeat.Plp`.
// compute_features releases the GIL during extraction; pickling
// round-trips through the options dict.
static void PybindPlp(py::module &m) {
  using PyClass = Plp;
  py::class_<PyClass>(m, "Plp")
      .def(py::init<const PlpOptions &>(), py::arg("opts"))
      .def("dim", &PyClass::Dim)
      .def_property_readonly("options", &PyClass::GetOptions)
      .def("compute_features", &PyClass::ComputeFeatures, py::arg("wave"),
           py::arg("vtln_warp"), py::call_guard<py::gil_scoped_release>())
      .def(py::pickle(
          [](const PyClass &self) -> py::dict {
            return AsDict(self.GetOptions());
          },
          [](py::dict dict) -> std::unique_ptr<PyClass> {
            return std::make_unique<PyClass>(PlpOptionsFromDict(dict));
          }));
}
// Registers PlpOptions and Plp with the Python extension module `m`.
void PybindFeaturePlp(py::module &m) {
  PybindPlpOptions(m);
  PybindPlp(m);
}
} // namespace kaldifeat

View File

@ -0,0 +1,16 @@
// kaldifeat/python/csrc/feature-plp.h
//
// Copyright (c)  2021  Xiaomi Corporation (authors: Fangjun Kuang)
#ifndef KALDIFEAT_PYTHON_CSRC_FEATURE_PLP_H_
#define KALDIFEAT_PYTHON_CSRC_FEATURE_PLP_H_
#include "kaldifeat/python/csrc/kaldifeat.h"
namespace kaldifeat {
// Registers PlpOptions and Plp with the Python extension module `m`.
void PybindFeaturePlp(py::module &m);
}  // namespace kaldifeat
#endif  // KALDIFEAT_PYTHON_CSRC_FEATURE_PLP_H_

View File

@ -0,0 +1,91 @@
// kaldifeat/python/csrc/feature-spectrogram.cc
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#include "kaldifeat/python/csrc/feature-spectrogram.h"
#include <memory>
#include <string>
#include "kaldifeat/csrc/feature-spectrogram.h"
#include "kaldifeat/python/csrc/utils.h"
namespace kaldifeat {
// Binds kaldifeat::SpectrogramOptions as Python class
// `_kaldifeat.SpectrogramOptions`: a keyword constructor (device is given
// as a string such as "cpu" or "cuda:0"), read/write attributes, a
// `device` property converting to/from torch.device, __str__,
// as_dict()/from_dict(), and pickling via the dict round-trip in utils.h.
static void PybindSpectrogramOptions(py::module &m) {
  using PyClass = SpectrogramOptions;
  py::class_<PyClass>(m, "SpectrogramOptions")
      .def(py::init([](const FrameExtractionOptions &frame_opts =
                           FrameExtractionOptions(),
                       float energy_floor = 0.0, bool raw_energy = true,
                       bool return_raw_fft = false,
                       py::object device = py::str(
                           "cpu")) -> std::unique_ptr<SpectrogramOptions> {
             auto opts = std::make_unique<SpectrogramOptions>();
             opts->frame_opts = frame_opts;
             opts->energy_floor = energy_floor;
             opts->raw_energy = raw_energy;
             opts->return_raw_fft = return_raw_fft;
             std::string s = static_cast<py::str>(device);
             opts->device = torch::Device(s);
             return opts;
           }),
           py::arg("frame_opts") = FrameExtractionOptions(),
           py::arg("energy_floor") = 0.0, py::arg("raw_energy") = true,
           py::arg("return_raw_fft") = false,
           py::arg("device") = py::str("cpu"))
      .def_readwrite("frame_opts", &PyClass::frame_opts)
      .def_readwrite("energy_floor", &PyClass::energy_floor)
      .def_readwrite("raw_energy", &PyClass::raw_energy)
      // .def_readwrite("return_raw_fft", &PyClass::return_raw_fft) // not
      // implemented yet
      // Expose device as a real torch.device object on the Python side.
      .def_property(
          "device",
          [](const PyClass &self) -> py::object {
            py::object ans = py::module_::import("torch").attr("device");
            return ans(self.device.str());
          },
          [](PyClass &self, py::object obj) -> void {
            std::string s = static_cast<py::str>(obj);
            self.device = torch::Device(s);
          })
      .def("__str__",
           [](const PyClass &self) -> std::string { return self.ToString(); })
      .def("as_dict",
           [](const PyClass &self) -> py::dict { return AsDict(self); })
      .def_static("from_dict",
                  [](py::dict dict) -> PyClass {
                    return SpectrogramOptionsFromDict(dict);
                  })
      .def(py::pickle(
          [](const PyClass &self) -> py::dict { return AsDict(self); },
          [](py::dict dict) -> PyClass {
            return SpectrogramOptionsFromDict(dict);
          }));
}
// Binds kaldifeat::Spectrogram as Python class `_kaldifeat.Spectrogram`.
// compute_features releases the GIL during extraction; pickling
// round-trips through the options dict.
static void PybindSpectrogram(py::module &m) {
  using PyClass = Spectrogram;
  py::class_<PyClass>(m, "Spectrogram")
      .def(py::init<const SpectrogramOptions &>(), py::arg("opts"))
      .def("dim", &PyClass::Dim)
      .def_property_readonly("options", &PyClass::GetOptions)
      .def("compute_features", &PyClass::ComputeFeatures, py::arg("wave"),
           py::arg("vtln_warp"), py::call_guard<py::gil_scoped_release>())
      .def(py::pickle(
          [](const PyClass &self) -> py::dict {
            return AsDict(self.GetOptions());
          },
          [](py::dict dict) -> std::unique_ptr<PyClass> {
            return std::make_unique<PyClass>(SpectrogramOptionsFromDict(dict));
          }));
}
// Registers SpectrogramOptions and Spectrogram with the Python module `m`.
void PybindFeatureSpectrogram(py::module &m) {
  PybindSpectrogramOptions(m);
  PybindSpectrogram(m);
}
} // namespace kaldifeat

View File

@ -0,0 +1,16 @@
// kaldifeat/python/csrc/feature-spectrogram.h
//
// Copyright (c)  2021  Xiaomi Corporation (authors: Fangjun Kuang)
#ifndef KALDIFEAT_PYTHON_CSRC_FEATURE_SPECTROGRAM_H_
#define KALDIFEAT_PYTHON_CSRC_FEATURE_SPECTROGRAM_H_
#include "kaldifeat/python/csrc/kaldifeat.h"
namespace kaldifeat {
// Registers SpectrogramOptions and Spectrogram with the Python module `m`.
void PybindFeatureSpectrogram(py::module &m);
}  // namespace kaldifeat
#endif  // KALDIFEAT_PYTHON_CSRC_FEATURE_SPECTROGRAM_H_

View File

@ -0,0 +1,95 @@
// kaldifeat/python/csrc/feature-window.cc
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#include "kaldifeat/python/csrc/feature-window.h"
#include <memory>
#include <string>
#include "kaldifeat/csrc/feature-window.h"
#include "kaldifeat/python/csrc/utils.h"
namespace kaldifeat {
// Binds kaldifeat::FrameExtractionOptions as Python class
// `_kaldifeat.FrameExtractionOptions`: a keyword constructor, read/write
// attributes, __str__, as_dict()/from_dict(), and pickling via the dict
// round-trip in utils.h.
//
// The module-level helpers `num_frames` and `get_strided` are registered
// by PybindFeatureWindow(); they were previously registered here as well,
// which added duplicate overloads of the same functions to the module.
static void PybindFrameExtractionOptions(py::module &m) {
  using PyClass = FrameExtractionOptions;
  py::class_<PyClass>(m, "FrameExtractionOptions")
      .def(
          py::init([](float samp_freq = 16000, float frame_shift_ms = 10.0f,
                      float frame_length_ms = 25.0f, float dither = 1.0f,
                      float preemph_coeff = 0.97f, bool remove_dc_offset = true,
                      const std::string &window_type = "povey",
                      bool round_to_power_of_two = true,
                      float blackman_coeff = 0.42f, bool snip_edges = true,
                      int32_t max_feature_vectors =
                          -1) -> std::unique_ptr<FrameExtractionOptions> {
            auto opts = std::make_unique<FrameExtractionOptions>();
            opts->samp_freq = samp_freq;
            opts->frame_shift_ms = frame_shift_ms;
            opts->frame_length_ms = frame_length_ms;
            opts->dither = dither;
            opts->preemph_coeff = preemph_coeff;
            opts->remove_dc_offset = remove_dc_offset;
            opts->window_type = window_type;
            opts->round_to_power_of_two = round_to_power_of_two;
            opts->blackman_coeff = blackman_coeff;
            opts->snip_edges = snip_edges;
            opts->max_feature_vectors = max_feature_vectors;
            return opts;
          }),
          py::arg("samp_freq") = 16000, py::arg("frame_shift_ms") = 10.0f,
          py::arg("frame_length_ms") = 25.0f, py::arg("dither") = 1.0f,
          py::arg("preemph_coeff") = 0.97f, py::arg("remove_dc_offset") = true,
          py::arg("window_type") = "povey",
          py::arg("round_to_power_of_two") = true,
          py::arg("blackman_coeff") = 0.42f, py::arg("snip_edges") = true,
          py::arg("max_feature_vectors") = -1)
      .def_readwrite("samp_freq", &PyClass::samp_freq)
      .def_readwrite("frame_shift_ms", &PyClass::frame_shift_ms)
      .def_readwrite("frame_length_ms", &PyClass::frame_length_ms)
      .def_readwrite("dither", &PyClass::dither)
      .def_readwrite("preemph_coeff", &PyClass::preemph_coeff)
      .def_readwrite("remove_dc_offset", &PyClass::remove_dc_offset)
      .def_readwrite("window_type", &PyClass::window_type)
      .def_readwrite("round_to_power_of_two", &PyClass::round_to_power_of_two)
      .def_readwrite("blackman_coeff", &PyClass::blackman_coeff)
      .def_readwrite("snip_edges", &PyClass::snip_edges)
      .def_readwrite("max_feature_vectors", &PyClass::max_feature_vectors)
      .def("as_dict",
           [](const PyClass &self) -> py::dict { return AsDict(self); })
      .def_static("from_dict",
                  [](py::dict dict) -> PyClass {
                    return FrameExtractionOptionsFromDict(dict);
                  })
#if 0
      .def_readwrite("allow_downsample",
                     &PyClass::allow_downsample)
      .def_readwrite("allow_upsample", &PyClass::allow_upsample)
#endif
      .def("__str__",
           [](const PyClass &self) -> std::string { return self.ToString(); })
      .def(py::pickle(
          [](const PyClass &self) -> py::dict { return AsDict(self); },
          [](py::dict dict) -> PyClass {
            return FrameExtractionOptionsFromDict(dict);
          }));
}
// Registers FrameExtractionOptions plus the free helper functions
// `num_frames` (NumFrames) and `get_strided` (GetStrided) with the
// Python extension module `m`.
void PybindFeatureWindow(py::module &m) {
  PybindFrameExtractionOptions(m);
  m.def("num_frames", &NumFrames, py::arg("num_samples"), py::arg("opts"),
        py::arg("flush") = true);
  m.def("get_strided", &GetStrided, py::arg("wave"), py::arg("opts"));
}
} // namespace kaldifeat

View File

@ -0,0 +1,16 @@
// kaldifeat/python/csrc/feature-window.h
//
// Copyright (c)  2021  Xiaomi Corporation (authors: Fangjun Kuang)
#ifndef KALDIFEAT_PYTHON_CSRC_FEATURE_WINDOW_H_
#define KALDIFEAT_PYTHON_CSRC_FEATURE_WINDOW_H_
#include "kaldifeat/python/csrc/kaldifeat.h"
namespace kaldifeat {
// Registers FrameExtractionOptions and the num_frames/get_strided helpers
// with the Python extension module `m`.
void PybindFeatureWindow(py::module &m);
}  // namespace kaldifeat
#endif  // KALDIFEAT_PYTHON_CSRC_FEATURE_WINDOW_H_

View File

@ -0,0 +1,33 @@
// kaldifeat/python/csrc/kaldifeat.cc
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#include "kaldifeat/python/csrc/kaldifeat.h"
#include "kaldifeat/csrc/feature-fbank.h"
#include "kaldifeat/python/csrc/feature-fbank.h"
#include "kaldifeat/python/csrc/feature-mfcc.h"
#include "kaldifeat/python/csrc/feature-plp.h"
#include "kaldifeat/python/csrc/feature-spectrogram.h"
#include "kaldifeat/python/csrc/feature-window.h"
#include "kaldifeat/python/csrc/mel-computations.h"
#include "kaldifeat/python/csrc/online-feature.h"
#include "kaldifeat/python/csrc/whisper-fbank.h"
#include "torch/torch.h"
namespace kaldifeat {
// Entry point of the `_kaldifeat` Python extension: registers all
// feature-extraction bindings when the module is imported.
PYBIND11_MODULE(_kaldifeat, m) {
  m.doc() = "Python wrapper for kaldifeat";
  PybindFeatureWindow(m);
  PybindMelComputations(m);
  PybindFeatureFbank(m);
  // NOTE(review): PybindWhisperFbank takes py::module* while every other
  // Pybind* helper here takes py::module& -- consider unifying.
  PybindWhisperFbank(&m);
  PybindFeatureMfcc(m);
  PybindFeaturePlp(m);
  PybindFeatureSpectrogram(m);
  PybindOnlineFeature(m);
}
} // namespace kaldifeat

Some files were not shown because too many files have changed in this diff Show More